Merge pull request #1144 from terrelln/fse-entropy

Approximate FSE encoding costs for selection
diff --git a/lib/common/fse.h b/lib/common/fse.h
index 5a23444..e88a5ef 100644
--- a/lib/common/fse.h
+++ b/lib/common/fse.h
@@ -581,12 +581,13 @@
     return (symbolTT[symbolValue].deltaNbBits + ((1<<16)-1)) >> 16;
 }
 
-/* FSE_bitCost_b256() :
+/* FSE_bitCost() :
  * Approximate symbol cost,
  * provide fractional value, using fixed-point format (accuracyLog fractional bits)
  * note: assume symbolValue is valid */
-MEM_STATIC U32 FSE_bitCost(const FSE_symbolCompressionTransform* symbolTT, U32 tableLog, U32 symbolValue, U32 accuracyLog)
+MEM_STATIC U32 FSE_bitCost(const void* symbolTTPtr, U32 tableLog, U32 symbolValue, U32 accuracyLog)
 {
+    const FSE_symbolCompressionTransform* symbolTT = (const FSE_symbolCompressionTransform*) symbolTTPtr;
     U32 const minNbBits = symbolTT[symbolValue].deltaNbBits >> 16;
     U32 const threshold = (minNbBits+1) << 16;
     assert(tableLog < 16);
diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c
index d8420a8..22c704f 100644
--- a/lib/compress/zstd_compress.c
+++ b/lib/compress/zstd_compress.c
@@ -946,10 +946,10 @@
     int i;
     for (i = 0; i < ZSTD_REP_NUM; ++i)
         bs->rep[i] = repStartValue[i];
-    bs->entropy.hufCTable_repeatMode = HUF_repeat_none;
-    bs->entropy.offcode_repeatMode = FSE_repeat_none;
-    bs->entropy.matchlength_repeatMode = FSE_repeat_none;
-    bs->entropy.litlength_repeatMode = FSE_repeat_none;
+    bs->entropy.huf.repeatMode = HUF_repeat_none;
+    bs->entropy.fse.offcode_repeatMode = FSE_repeat_none;
+    bs->entropy.fse.matchlength_repeatMode = FSE_repeat_none;
+    bs->entropy.fse.litlength_repeatMode = FSE_repeat_none;
 }
 
 /*! ZSTD_invalidateMatchState()
@@ -1455,8 +1455,8 @@
 
 static size_t ZSTD_minGain(size_t srcSize) { return (srcSize >> 6) + 2; }
 
-static size_t ZSTD_compressLiterals (ZSTD_entropyCTables_t const* prevEntropy,
-                                     ZSTD_entropyCTables_t* nextEntropy,
+static size_t ZSTD_compressLiterals (ZSTD_hufCTables_t const* prevHuf,
+                                     ZSTD_hufCTables_t* nextHuf,
                                      ZSTD_strategy strategy, int disableLiteralCompression,
                                      void* dst, size_t dstCapacity,
                                const void* src, size_t srcSize,
@@ -1473,27 +1473,25 @@
                 disableLiteralCompression);
 
     /* Prepare nextEntropy assuming reusing the existing table */
-    nextEntropy->hufCTable_repeatMode = prevEntropy->hufCTable_repeatMode;
-    memcpy(nextEntropy->hufCTable, prevEntropy->hufCTable,
-           sizeof(prevEntropy->hufCTable));
+    memcpy(nextHuf, prevHuf, sizeof(*prevHuf));
 
     if (disableLiteralCompression)
         return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize);
 
     /* small ? don't even attempt compression (speed opt) */
 #   define COMPRESS_LITERALS_SIZE_MIN 63
-    {   size_t const minLitSize = (prevEntropy->hufCTable_repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN;
+    {   size_t const minLitSize = (prevHuf->repeatMode == HUF_repeat_valid) ? 6 : COMPRESS_LITERALS_SIZE_MIN;
         if (srcSize <= minLitSize) return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize);
     }
 
     if (dstCapacity < lhSize+1) return ERROR(dstSize_tooSmall);   /* not enough space for compression */
-    {   HUF_repeat repeat = prevEntropy->hufCTable_repeatMode;
+    {   HUF_repeat repeat = prevHuf->repeatMode;
         int const preferRepeat = strategy < ZSTD_lazy ? srcSize <= 1024 : 0;
         if (repeat == HUF_repeat_valid && lhSize == 3) singleStream = 1;
         cLitSize = singleStream ? HUF_compress1X_repeat(ostart+lhSize, dstCapacity-lhSize, src, srcSize, 255, 11,
-                                      workspace, HUF_WORKSPACE_SIZE, (HUF_CElt*)nextEntropy->hufCTable, &repeat, preferRepeat, bmi2)
+                                      workspace, HUF_WORKSPACE_SIZE, (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2)
                                 : HUF_compress4X_repeat(ostart+lhSize, dstCapacity-lhSize, src, srcSize, 255, 11,
-                                      workspace, HUF_WORKSPACE_SIZE, (HUF_CElt*)nextEntropy->hufCTable, &repeat, preferRepeat, bmi2);
+                                      workspace, HUF_WORKSPACE_SIZE, (HUF_CElt*)nextHuf->CTable, &repeat, preferRepeat, bmi2);
         if (repeat != HUF_repeat_none) {
             /* reused the existing table */
             hType = set_repeat;
@@ -1501,17 +1499,17 @@
     }
 
     if ((cLitSize==0) | (cLitSize >= srcSize - minGain) | ERR_isError(cLitSize)) {
-        memcpy(nextEntropy->hufCTable, prevEntropy->hufCTable, sizeof(prevEntropy->hufCTable));
+        memcpy(nextHuf, prevHuf, sizeof(*prevHuf));
         return ZSTD_noCompressLiterals(dst, dstCapacity, src, srcSize);
     }
     if (cLitSize==1) {
-        memcpy(nextEntropy->hufCTable, prevEntropy->hufCTable, sizeof(prevEntropy->hufCTable));
+        memcpy(nextHuf, prevHuf, sizeof(*prevHuf));
         return ZSTD_compressRleLiteralsBlock(dst, dstCapacity, src, srcSize);
     }
 
     if (hType == set_compressed) {
         /* using a newly constructed table */
-        nextEntropy->hufCTable_repeatMode = HUF_repeat_check;
+        nextHuf->repeatMode = HUF_repeat_check;
     }
 
     /* Build header */
@@ -1561,6 +1559,137 @@
         mlCodeTable[seqStorePtr->longLengthPos] = MaxML;
 }
 
+
+/**
+ * -log2(x / 256) lookup table for x in [0, 256).
+ * If x == 0: Return 0
+ * Else: Return floor(-log2(x / 256) * 256)
+ */
+static unsigned const kInverseProbabiltyLog256[256] = {
+    0,    2048, 1792, 1642, 1536, 1453, 1386, 1329, 1280, 1236, 1197, 1162,
+    1130, 1100, 1073, 1047, 1024, 1001, 980,  960,  941,  923,  906,  889,
+    874,  859,  844,  830,  817,  804,  791,  779,  768,  756,  745,  734,
+    724,  714,  704,  694,  685,  676,  667,  658,  650,  642,  633,  626,
+    618,  610,  603,  595,  588,  581,  574,  567,  561,  554,  548,  542,
+    535,  529,  523,  517,  512,  506,  500,  495,  489,  484,  478,  473,
+    468,  463,  458,  453,  448,  443,  438,  434,  429,  424,  420,  415,
+    411,  407,  402,  398,  394,  390,  386,  382,  377,  373,  370,  366,
+    362,  358,  354,  350,  347,  343,  339,  336,  332,  329,  325,  322,
+    318,  315,  311,  308,  305,  302,  298,  295,  292,  289,  286,  282,
+    279,  276,  273,  270,  267,  264,  261,  258,  256,  253,  250,  247,
+    244,  241,  239,  236,  233,  230,  228,  225,  222,  220,  217,  215,
+    212,  209,  207,  204,  202,  199,  197,  194,  192,  190,  187,  185,
+    182,  180,  178,  175,  173,  171,  168,  166,  164,  162,  159,  157,
+    155,  153,  151,  149,  146,  144,  142,  140,  138,  136,  134,  132,
+    130,  128,  126,  123,  121,  119,  117,  115,  114,  112,  110,  108,
+    106,  104,  102,  100,  98,   96,   94,   93,   91,   89,   87,   85,
+    83,   82,   80,   78,   76,   74,   73,   71,   69,   67,   66,   64,
+    62,   61,   59,   57,   55,   54,   52,   50,   49,   47,   46,   44,
+    42,   41,   39,   37,   36,   34,   33,   31,   30,   28,   26,   25,
+    23,   22,   20,   19,   17,   16,   14,   13,   11,   10,   8,    7,
+    5,    4,    2,    1,
+};
+
+
+/**
+ * Returns the cost in bits of encoding the distribution described by count
+ * using the entropy bound.
+ */
+static size_t ZSTD_entropyCost(unsigned const* count, unsigned const max, size_t const total)
+{
+    unsigned cost = 0;
+    unsigned s;
+    for (s = 0; s <= max; ++s) {
+        unsigned norm = (unsigned)((256 * count[s]) / total);
+        if (count[s] != 0 && norm == 0)
+            norm = 1;
+        assert(count[s] < total);
+        cost += count[s] * kInverseProbabiltyLog256[norm];
+    }
+    return cost >> 8;
+}
+
+
+/**
+ * Returns the cost in bits of encoding the distribution in count using the
+ * table described by norm. The max symbol support by norm is assumed >= max.
+ * norm must be valid for every symbol with non-zero probability in count.
+ */
+static size_t ZSTD_crossEntropyCost(short const* norm, unsigned accuracyLog,
+                                    unsigned const* count, unsigned const max)
+{
+    unsigned const shift = 8 - accuracyLog;
+    size_t cost = 0;
+    unsigned s;
+    assert(accuracyLog <= 8);
+    for (s = 0; s <= max; ++s) {
+        unsigned const normAcc = norm[s] != -1 ? norm[s] : 1;
+        unsigned const norm256 = normAcc << shift;
+        assert(norm256 > 0);
+        assert(norm256 < 256);
+        cost += count[s] * kInverseProbabiltyLog256[norm256];
+    }
+    return cost >> 8;
+}
+
+
+static unsigned ZSTD_getFSEMaxSymbolValue(FSE_CTable const* ctable) {
+  void const* ptr = ctable;
+  U16 const* u16ptr = (U16 const*)ptr;
+  U32 const maxSymbolValue = MEM_read16(u16ptr + 1);
+  return maxSymbolValue;
+}
+
+
+/**
+ * Returns the cost in bits of encoding the distribution in count using ctable.
+ * Returns an error if ctable cannot represent all the symbols in count.
+ */
+static size_t ZSTD_fseBitCost(
+    FSE_CTable const* ctable,
+    unsigned const* count,
+    unsigned const max)
+{
+    unsigned const kAccuracyLog = 8;
+    size_t cost = 0;
+    unsigned s;
+    FSE_CState_t cstate;
+    FSE_initCState(&cstate, ctable);
+    if (ZSTD_getFSEMaxSymbolValue(ctable) < max) {
+        DEBUGLOG(5, "Repeat FSE_CTable has maxSymbolValue %u < %u",
+                    ZSTD_getFSEMaxSymbolValue(ctable), max);
+        return ERROR(GENERIC);
+    }
+    for (s = 0; s <= max; ++s) {
+        unsigned const tableLog = cstate.stateLog;
+        unsigned const badCost = (tableLog + 1) << kAccuracyLog;
+        unsigned const bitCost = FSE_bitCost(cstate.symbolTT, tableLog, s, kAccuracyLog);
+        if (count[s] == 0)
+            continue;
+        if (bitCost >= badCost) {
+            DEBUGLOG(5, "Repeat FSE_CTable has Prob[%u] == 0", s);
+            return ERROR(GENERIC);
+        }
+        cost += count[s] * bitCost;
+    }
+    return cost >> kAccuracyLog;
+}
+
+/**
+ * Returns the cost in bytes of encoding the normalized count header.
+ * Returns an error if any of the helper functions return an error.
+ */
+static size_t ZSTD_NCountCost(unsigned const* count, unsigned const max,
+                              size_t const nbSeq, unsigned const FSELog)
+{
+    BYTE wksp[FSE_NCOUNTBOUND];
+    S16 norm[MaxSeq + 1];
+    const U32 tableLog = FSE_optimalTableLog(FSELog, nbSeq, max);
+    CHECK_F(FSE_normalizeCount(norm, tableLog, count, nbSeq, max));
+    return FSE_writeNCount(wksp, sizeof(wksp), norm, max, tableLog);
+}
+
+
 typedef enum {
     ZSTD_defaultDisallowed = 0,
     ZSTD_defaultAllowed = 1
@@ -1568,37 +1697,73 @@
 
 MEM_STATIC
 symbolEncodingType_e ZSTD_selectEncodingType(
-        FSE_repeat* repeatMode, size_t const mostFrequent, size_t nbSeq,
-        U32 defaultNormLog, ZSTD_defaultPolicy_e const isDefaultAllowed)
+        FSE_repeat* repeatMode, unsigned const* count, unsigned const max,
+        size_t const mostFrequent, size_t nbSeq, unsigned const FSELog,
+        FSE_CTable const* prevCTable,
+        short const* defaultNorm, U32 defaultNormLog,
+        ZSTD_defaultPolicy_e const isDefaultAllowed,
+        ZSTD_strategy const strategy)
 {
 #define MIN_SEQ_FOR_DYNAMIC_FSE   64
 #define MAX_SEQ_FOR_STATIC_FSE  1000
     ZSTD_STATIC_ASSERT(ZSTD_defaultDisallowed == 0 && ZSTD_defaultAllowed != 0);
-    if ((mostFrequent == nbSeq) && (!isDefaultAllowed || nbSeq > 2)) {
+    if (mostFrequent == nbSeq) {
+        *repeatMode = FSE_repeat_none;
+        if (isDefaultAllowed && nbSeq <= 2) {
+            /* Prefer set_basic over set_rle when there are 2 or less symbols,
+             * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
+             * If basic encoding isn't possible, always choose RLE.
+             */
+            DEBUGLOG(5, "Selected set_basic");
+            return set_basic;
+        }
         DEBUGLOG(5, "Selected set_rle");
-        /* Prefer set_basic over set_rle when there are 2 or less symbols,
-         * since RLE uses 1 byte, but set_basic uses 5-6 bits per symbol.
-         * If basic encoding isn't possible, always choose RLE.
-         */
-        *repeatMode = FSE_repeat_check;
         return set_rle;
     }
-    if ( isDefaultAllowed
-      && (*repeatMode == FSE_repeat_valid) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) {
-        DEBUGLOG(5, "Selected set_repeat");
-        return set_repeat;
-    }
-    if ( isDefaultAllowed
-      && ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (defaultNormLog-1)))) ) {
-        DEBUGLOG(5, "Selected set_basic");
-        /* The format allows default tables to be repeated, but it isn't useful.
-         * When using simple heuristics to select encoding type, we don't want
-         * to confuse these tables with dictionaries. When running more careful
-         * analysis, we don't need to waste time checking both repeating tables
-         * and default tables.
-         */
-        *repeatMode = FSE_repeat_none;
-        return set_basic;
+    if (strategy < ZSTD_lazy) {
+        if (isDefaultAllowed) {
+            if ((*repeatMode == FSE_repeat_valid) && (nbSeq < MAX_SEQ_FOR_STATIC_FSE)) {
+                DEBUGLOG(5, "Selected set_repeat");
+                return set_repeat;
+            }
+            if ((nbSeq < MIN_SEQ_FOR_DYNAMIC_FSE) || (mostFrequent < (nbSeq >> (defaultNormLog-1)))) {
+                DEBUGLOG(5, "Selected set_basic");
+                /* The format allows default tables to be repeated, but it isn't useful.
+                 * When using simple heuristics to select encoding type, we don't want
+                 * to confuse these tables with dictionaries. When running more careful
+                 * analysis, we don't need to waste time checking both repeating tables
+                 * and default tables.
+                 */
+                *repeatMode = FSE_repeat_none;
+                return set_basic;
+            }
+        }
+    } else {
+        size_t const basicCost = isDefaultAllowed ? ZSTD_crossEntropyCost(defaultNorm, defaultNormLog, count, max) : ERROR(GENERIC);
+        size_t const repeatCost = *repeatMode != FSE_repeat_none ? ZSTD_fseBitCost(prevCTable, count, max) : ERROR(GENERIC);
+        size_t const NCountCost = ZSTD_NCountCost(count, max, nbSeq, FSELog);
+        size_t const compressedCost = (NCountCost << 3) + ZSTD_entropyCost(count, max, nbSeq);
+
+        if (isDefaultAllowed) {
+            assert(!ZSTD_isError(basicCost));
+            assert(!(*repeatMode == FSE_repeat_valid && ZSTD_isError(repeatCost)));
+        }
+        assert(!ZSTD_isError(NCountCost));
+        assert(compressedCost < ERROR(maxCode));
+        DEBUGLOG(5, "Estimated bit costs: basic=%u\trepeat=%u\tcompressed=%u",
+                    (U32)basicCost, (U32)repeatCost, (U32)compressedCost);
+        if (basicCost <= repeatCost && basicCost <= compressedCost) {
+            DEBUGLOG(5, "Selected set_basic");
+            assert(isDefaultAllowed);
+            *repeatMode = FSE_repeat_none;
+            return set_basic;
+        }
+        if (repeatCost <= compressedCost) {
+            DEBUGLOG(5, "Selected set_repeat");
+            assert(!ZSTD_isError(repeatCost));
+            return set_repeat;
+        }
+        assert(compressedCost < basicCost && compressedCost < repeatCost);
     }
     DEBUGLOG(5, "Selected set_compressed");
     *repeatMode = FSE_repeat_check;
@@ -1803,10 +1968,11 @@
                               const int bmi2)
 {
     const int longOffsets = cctxParams->cParams.windowLog > STREAM_ACCUMULATOR_MIN;
+    ZSTD_strategy const strategy = cctxParams->cParams.strategy;
     U32 count[MaxSeq+1];
-    FSE_CTable* CTable_LitLength = nextEntropy->litlengthCTable;
-    FSE_CTable* CTable_OffsetBits = nextEntropy->offcodeCTable;
-    FSE_CTable* CTable_MatchLength = nextEntropy->matchlengthCTable;
+    FSE_CTable* CTable_LitLength = nextEntropy->fse.litlengthCTable;
+    FSE_CTable* CTable_OffsetBits = nextEntropy->fse.offcodeCTable;
+    FSE_CTable* CTable_MatchLength = nextEntropy->fse.matchlengthCTable;
     U32 LLtype, Offtype, MLtype;   /* compressed, raw or rle */
     const seqDef* const sequences = seqStorePtr->sequencesStart;
     const BYTE* const ofCodeTable = seqStorePtr->ofCode;
@@ -1824,7 +1990,7 @@
     {   const BYTE* const literals = seqStorePtr->litStart;
         size_t const litSize = seqStorePtr->lit - literals;
         size_t const cSize = ZSTD_compressLiterals(
-                                    prevEntropy, nextEntropy,
+                                    &prevEntropy->huf, &nextEntropy->huf,
                                     cctxParams->cParams.strategy, cctxParams->disableLiteralCompression,
                                     op, dstCapacity,
                                     literals, litSize,
@@ -1844,13 +2010,9 @@
     else
         op[0]=0xFF, MEM_writeLE16(op+1, (U16)(nbSeq - LONGNBSEQ)), op+=3;
     if (nbSeq==0) {
-      memcpy(nextEntropy->litlengthCTable, prevEntropy->litlengthCTable, sizeof(prevEntropy->litlengthCTable));
-      nextEntropy->litlength_repeatMode = prevEntropy->litlength_repeatMode;
-      memcpy(nextEntropy->offcodeCTable, prevEntropy->offcodeCTable, sizeof(prevEntropy->offcodeCTable));
-      nextEntropy->offcode_repeatMode = prevEntropy->offcode_repeatMode;
-      memcpy(nextEntropy->matchlengthCTable, prevEntropy->matchlengthCTable, sizeof(prevEntropy->matchlengthCTable));
-      nextEntropy->matchlength_repeatMode = prevEntropy->matchlength_repeatMode;
-      return op - ostart;
+        /* Copy the old tables over as if we repeated them */
+        memcpy(&nextEntropy->fse, &prevEntropy->fse, sizeof(prevEntropy->fse));
+        return op - ostart;
     }
 
     /* seqHead : flags for FSE encoding type */
@@ -1862,11 +2024,13 @@
     {   U32 max = MaxLL;
         size_t const mostFrequent = FSE_countFast_wksp(count, &max, llCodeTable, nbSeq, workspace);
         DEBUGLOG(5, "Building LL table");
-        nextEntropy->litlength_repeatMode = prevEntropy->litlength_repeatMode;
-        LLtype = ZSTD_selectEncodingType(&nextEntropy->litlength_repeatMode, mostFrequent, nbSeq, LL_defaultNormLog, ZSTD_defaultAllowed);
+        nextEntropy->fse.litlength_repeatMode = prevEntropy->fse.litlength_repeatMode;
+        LLtype = ZSTD_selectEncodingType(&nextEntropy->fse.litlength_repeatMode, count, max, mostFrequent, nbSeq, LLFSELog, prevEntropy->fse.litlengthCTable, LL_defaultNorm, LL_defaultNormLog, ZSTD_defaultAllowed, strategy);
+        assert(set_basic < set_compressed && set_rle < set_compressed);
+        assert(!(LLtype < set_compressed && nextEntropy->fse.litlength_repeatMode != FSE_repeat_none)); /* We don't copy tables */
         {   size_t const countSize = ZSTD_buildCTable(op, oend - op, CTable_LitLength, LLFSELog, (symbolEncodingType_e)LLtype,
                     count, max, llCodeTable, nbSeq, LL_defaultNorm, LL_defaultNormLog, MaxLL,
-                    prevEntropy->litlengthCTable, sizeof(prevEntropy->litlengthCTable),
+                    prevEntropy->fse.litlengthCTable, sizeof(prevEntropy->fse.litlengthCTable),
                     workspace, HUF_WORKSPACE_SIZE);
             if (ZSTD_isError(countSize)) return countSize;
             op += countSize;
@@ -1877,11 +2041,12 @@
         /* We can only use the basic table if max <= DefaultMaxOff, otherwise the offsets are too large */
         ZSTD_defaultPolicy_e const defaultPolicy = (max <= DefaultMaxOff) ? ZSTD_defaultAllowed : ZSTD_defaultDisallowed;
         DEBUGLOG(5, "Building OF table");
-        nextEntropy->offcode_repeatMode = prevEntropy->offcode_repeatMode;
-        Offtype = ZSTD_selectEncodingType(&nextEntropy->offcode_repeatMode, mostFrequent, nbSeq, OF_defaultNormLog, defaultPolicy);
+        nextEntropy->fse.offcode_repeatMode = prevEntropy->fse.offcode_repeatMode;
+        Offtype = ZSTD_selectEncodingType(&nextEntropy->fse.offcode_repeatMode, count, max, mostFrequent, nbSeq, OffFSELog, prevEntropy->fse.offcodeCTable, OF_defaultNorm, OF_defaultNormLog, defaultPolicy, strategy);
+        assert(!(Offtype < set_compressed && nextEntropy->fse.offcode_repeatMode != FSE_repeat_none)); /* We don't copy tables */
         {   size_t const countSize = ZSTD_buildCTable(op, oend - op, CTable_OffsetBits, OffFSELog, (symbolEncodingType_e)Offtype,
                     count, max, ofCodeTable, nbSeq, OF_defaultNorm, OF_defaultNormLog, DefaultMaxOff,
-                    prevEntropy->offcodeCTable, sizeof(prevEntropy->offcodeCTable),
+                    prevEntropy->fse.offcodeCTable, sizeof(prevEntropy->fse.offcodeCTable),
                     workspace, HUF_WORKSPACE_SIZE);
             if (ZSTD_isError(countSize)) return countSize;
             op += countSize;
@@ -1890,11 +2055,12 @@
     {   U32 max = MaxML;
         size_t const mostFrequent = FSE_countFast_wksp(count, &max, mlCodeTable, nbSeq, workspace);
         DEBUGLOG(5, "Building ML table");
-        nextEntropy->matchlength_repeatMode = prevEntropy->matchlength_repeatMode;
-        MLtype = ZSTD_selectEncodingType(&nextEntropy->matchlength_repeatMode, mostFrequent, nbSeq, ML_defaultNormLog, ZSTD_defaultAllowed);
+        nextEntropy->fse.matchlength_repeatMode = prevEntropy->fse.matchlength_repeatMode;
+        MLtype = ZSTD_selectEncodingType(&nextEntropy->fse.matchlength_repeatMode, count, max, mostFrequent, nbSeq, MLFSELog, prevEntropy->fse.matchlengthCTable, ML_defaultNorm, ML_defaultNormLog, ZSTD_defaultAllowed, strategy);
+        assert(!(MLtype < set_compressed && nextEntropy->fse.matchlength_repeatMode != FSE_repeat_none)); /* We don't copy tables */
         {   size_t const countSize = ZSTD_buildCTable(op, oend - op, CTable_MatchLength, MLFSELog, (symbolEncodingType_e)MLtype,
                     count, max, mlCodeTable, nbSeq, ML_defaultNorm, ML_defaultNormLog, MaxML,
-                    prevEntropy->matchlengthCTable, sizeof(prevEntropy->matchlengthCTable),
+                    prevEntropy->fse.matchlengthCTable, sizeof(prevEntropy->fse.matchlengthCTable),
                     workspace, HUF_WORKSPACE_SIZE);
             if (ZSTD_isError(countSize)) return countSize;
             op += countSize;
@@ -1942,8 +2108,8 @@
      * block. After the first block, the offcode table might not have large
      * enough codes to represent the offsets in the data.
      */
-    if (nextEntropy->offcode_repeatMode == FSE_repeat_valid)
-        nextEntropy->offcode_repeatMode = FSE_repeat_check;
+    if (nextEntropy->fse.offcode_repeatMode == FSE_repeat_valid)
+        nextEntropy->fse.offcode_repeatMode = FSE_repeat_check;
 
     return cSize;
 }
@@ -2384,7 +2550,7 @@
     dictPtr += 4;
 
     {   unsigned maxSymbolValue = 255;
-        size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.hufCTable, &maxSymbolValue, dictPtr, dictEnd-dictPtr);
+        size_t const hufHeaderSize = HUF_readCTable((HUF_CElt*)bs->entropy.huf.CTable, &maxSymbolValue, dictPtr, dictEnd-dictPtr);
         if (HUF_isError(hufHeaderSize)) return ERROR(dictionary_corrupted);
         if (maxSymbolValue < 255) return ERROR(dictionary_corrupted);
         dictPtr += hufHeaderSize;
@@ -2396,7 +2562,7 @@
         if (offcodeLog > OffFSELog) return ERROR(dictionary_corrupted);
         /* Defer checking offcodeMaxValue because we need to know the size of the dictionary content */
         /* fill all offset symbols to avoid garbage at end of table */
-        CHECK_E( FSE_buildCTable_wksp(bs->entropy.offcodeCTable, offcodeNCount, MaxOff, offcodeLog, workspace, HUF_WORKSPACE_SIZE),
+        CHECK_E( FSE_buildCTable_wksp(bs->entropy.fse.offcodeCTable, offcodeNCount, MaxOff, offcodeLog, workspace, HUF_WORKSPACE_SIZE),
                  dictionary_corrupted);
         dictPtr += offcodeHeaderSize;
     }
@@ -2408,7 +2574,7 @@
         if (matchlengthLog > MLFSELog) return ERROR(dictionary_corrupted);
         /* Every match length code must have non-zero probability */
         CHECK_F( ZSTD_checkDictNCount(matchlengthNCount, matchlengthMaxValue, MaxML));
-        CHECK_E( FSE_buildCTable_wksp(bs->entropy.matchlengthCTable, matchlengthNCount, matchlengthMaxValue, matchlengthLog, workspace, HUF_WORKSPACE_SIZE),
+        CHECK_E( FSE_buildCTable_wksp(bs->entropy.fse.matchlengthCTable, matchlengthNCount, matchlengthMaxValue, matchlengthLog, workspace, HUF_WORKSPACE_SIZE),
                  dictionary_corrupted);
         dictPtr += matchlengthHeaderSize;
     }
@@ -2420,7 +2586,7 @@
         if (litlengthLog > LLFSELog) return ERROR(dictionary_corrupted);
         /* Every literal length code must have non-zero probability */
         CHECK_F( ZSTD_checkDictNCount(litlengthNCount, litlengthMaxValue, MaxLL));
-        CHECK_E( FSE_buildCTable_wksp(bs->entropy.litlengthCTable, litlengthNCount, litlengthMaxValue, litlengthLog, workspace, HUF_WORKSPACE_SIZE),
+        CHECK_E( FSE_buildCTable_wksp(bs->entropy.fse.litlengthCTable, litlengthNCount, litlengthMaxValue, litlengthLog, workspace, HUF_WORKSPACE_SIZE),
                  dictionary_corrupted);
         dictPtr += litlengthHeaderSize;
     }
@@ -2446,10 +2612,10 @@
                 if (bs->rep[u] > dictContentSize) return ERROR(dictionary_corrupted);
         }   }
 
-        bs->entropy.hufCTable_repeatMode = HUF_repeat_valid;
-        bs->entropy.offcode_repeatMode = FSE_repeat_valid;
-        bs->entropy.matchlength_repeatMode = FSE_repeat_valid;
-        bs->entropy.litlength_repeatMode = FSE_repeat_valid;
+        bs->entropy.huf.repeatMode = HUF_repeat_valid;
+        bs->entropy.fse.offcode_repeatMode = FSE_repeat_valid;
+        bs->entropy.fse.matchlength_repeatMode = FSE_repeat_valid;
+        bs->entropy.fse.litlength_repeatMode = FSE_repeat_valid;
         CHECK_F(ZSTD_loadDictionaryContent(ms, params, dictPtr, dictContentSize, dtlm));
         return dictID;
     }
diff --git a/lib/compress/zstd_compress_internal.h b/lib/compress/zstd_compress_internal.h
index 0f1830a..937234c 100644
--- a/lib/compress/zstd_compress_internal.h
+++ b/lib/compress/zstd_compress_internal.h
@@ -53,14 +53,22 @@
 } ZSTD_prefixDict;
 
 typedef struct {
-    U32 hufCTable[HUF_CTABLE_SIZE_U32(255)];
+    U32 CTable[HUF_CTABLE_SIZE_U32(255)];
+    HUF_repeat repeatMode;
+} ZSTD_hufCTables_t;
+
+typedef struct {
     FSE_CTable offcodeCTable[FSE_CTABLE_SIZE_U32(OffFSELog, MaxOff)];
     FSE_CTable matchlengthCTable[FSE_CTABLE_SIZE_U32(MLFSELog, MaxML)];
     FSE_CTable litlengthCTable[FSE_CTABLE_SIZE_U32(LLFSELog, MaxLL)];
-    HUF_repeat hufCTable_repeatMode;
     FSE_repeat offcode_repeatMode;
     FSE_repeat matchlength_repeatMode;
     FSE_repeat litlength_repeatMode;
+} ZSTD_fseCTables_t;
+
+typedef struct {
+    ZSTD_hufCTables_t huf;
+    ZSTD_fseCTables_t fse;
 } ZSTD_entropyCTables_t;
 
 typedef struct {
diff --git a/lib/compress/zstd_opt.c b/lib/compress/zstd_opt.c
index 3a48187..521fbbf 100644
--- a/lib/compress/zstd_opt.c
+++ b/lib/compress/zstd_opt.c
@@ -39,7 +39,7 @@
             optPtr->priceType = zop_predef;
 
         assert(optPtr->symbolCosts != NULL);
-        if (optPtr->symbolCosts->hufCTable_repeatMode == HUF_repeat_valid) { /* huffman table presumed generated by dictionary */
+        if (optPtr->symbolCosts->huf.repeatMode == HUF_repeat_valid) { /* huffman table presumed generated by dictionary */
             if (srcSize <= 8192)   /* heuristic */
                 optPtr->priceType = zop_static;
             else {
@@ -52,7 +52,7 @@
             {   unsigned lit;
                 for (lit=0; lit<=MaxLit; lit++) {
                     U32 const scaleLog = 11;   /* scale to 2K */
-                    U32 const bitCost = HUF_getNbBits(optPtr->symbolCosts->hufCTable, lit);
+                    U32 const bitCost = HUF_getNbBits(optPtr->symbolCosts->huf.CTable, lit);
                     assert(bitCost <= scaleLog);
                     optPtr->litFreq[lit] = bitCost ? 1 << (scaleLog-bitCost) : 1 /*minimum to calculate cost*/;
                     optPtr->litSum += optPtr->litFreq[lit];
@@ -60,7 +60,7 @@
 
             {   unsigned ll;
                 FSE_CState_t llstate;
-                FSE_initCState(&llstate, optPtr->symbolCosts->litlengthCTable);
+                FSE_initCState(&llstate, optPtr->symbolCosts->fse.litlengthCTable);
                 optPtr->litLengthSum = 0;
                 for (ll=0; ll<=MaxLL; ll++) {
                     U32 const scaleLog = 10;   /* scale to 1K */
@@ -72,7 +72,7 @@
 
             {   unsigned ml;
                 FSE_CState_t mlstate;
-                FSE_initCState(&mlstate, optPtr->symbolCosts->matchlengthCTable);
+                FSE_initCState(&mlstate, optPtr->symbolCosts->fse.matchlengthCTable);
                 optPtr->matchLengthSum = 0;
                 for (ml=0; ml<=MaxML; ml++) {
                     U32 const scaleLog = 10;
@@ -84,7 +84,7 @@
 
             {   unsigned of;
                 FSE_CState_t ofstate;
-                FSE_initCState(&ofstate, optPtr->symbolCosts->offcodeCTable);
+                FSE_initCState(&ofstate, optPtr->symbolCosts->fse.offcodeCTable);
                 optPtr->offCodeSum = 0;
                 for (of=0; of<=MaxOff; of++) {
                     U32 const scaleLog = 10;
@@ -180,9 +180,9 @@
     if (optPtr->priceType == zop_static) {
         U32 u, cost;
         assert(optPtr->symbolCosts != NULL);
-        assert(optPtr->symbolCosts->hufCTable_repeatMode == HUF_repeat_valid);
+        assert(optPtr->symbolCosts->huf.repeatMode == HUF_repeat_valid);
         for (u=0, cost=0; u < litLength; u++)
-            cost += HUF_getNbBits(optPtr->symbolCosts->hufCTable, literals[u]);
+            cost += HUF_getNbBits(optPtr->symbolCosts->huf.CTable, literals[u]);
         return cost * BITCOST_MULTIPLIER;
     }
 
@@ -202,7 +202,7 @@
     if (optPtr->priceType == zop_static) {
         U32 const llCode = ZSTD_LLcode(litLength);
         FSE_CState_t cstate;
-        FSE_initCState(&cstate, optPtr->symbolCosts->litlengthCTable);
+        FSE_initCState(&cstate, optPtr->symbolCosts->fse.litlengthCTable);
         {   U32 const price = LL_bits[llCode]*BITCOST_MULTIPLIER + BITCOST_SYMBOL(cstate.symbolTT, cstate.stateLog, llCode);
             DEBUGLOG(8, "ZSTD_litLengthPrice: ll=%u, bitCost=%.2f", litLength, (double)price / BITCOST_MULTIPLIER);
             return price;
@@ -234,7 +234,7 @@
     if (optPtr->priceType == zop_static) {
         U32 const llCode = ZSTD_LLcode(litLength);
         FSE_CState_t cstate;
-        FSE_initCState(&cstate, optPtr->symbolCosts->litlengthCTable);
+        FSE_initCState(&cstate, optPtr->symbolCosts->fse.litlengthCTable);
         return (int)(LL_bits[llCode] * BITCOST_MULTIPLIER)
              + BITCOST_SYMBOL(cstate.symbolTT, cstate.stateLog, llCode)
              - BITCOST_SYMBOL(cstate.symbolTT, cstate.stateLog, 0);
@@ -284,8 +284,8 @@
     if (optPtr->priceType == zop_static) {
         U32 const mlCode = ZSTD_MLcode(mlBase);
         FSE_CState_t mlstate, offstate;
-        FSE_initCState(&mlstate, optPtr->symbolCosts->matchlengthCTable);
-        FSE_initCState(&offstate, optPtr->symbolCosts->offcodeCTable);
+        FSE_initCState(&mlstate, optPtr->symbolCosts->fse.matchlengthCTable);
+        FSE_initCState(&offstate, optPtr->symbolCosts->fse.offcodeCTable);
         return BITCOST_SYMBOL(offstate.symbolTT, offstate.stateLog, offCode) + offCode*BITCOST_MULTIPLIER
              + BITCOST_SYMBOL(mlstate.symbolTT, mlstate.stateLog, mlCode) + ML_bits[mlCode]*BITCOST_MULTIPLIER;
     }