Merge pull request #1864 from terrelln/dict-fix

Fix 2 bugs in dictionary loading
diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c
index a8856cd..35346b9 100644
--- a/lib/compress/zstd_compress.c
+++ b/lib/compress/zstd_compress.c
@@ -2771,7 +2771,7 @@
 /*! ZSTD_loadZstdDictionary() :
  * @return : dictID, or an error code
  *  assumptions : magic number supposed already checked
- *                dictSize supposed > 8
+ *                dictSize supposed >= 8
  */
 static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
                                       ZSTD_matchState_t* ms,
@@ -2788,7 +2788,7 @@
     size_t dictID;
 
     ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog)));
-    assert(dictSize > 8);
+    assert(dictSize >= 8);
     assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY);
 
     dictPtr += 4;   /* skip magic number */
@@ -2890,7 +2890,10 @@
                                void* workspace)
 {
     DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize);
-    if ((dict==NULL) || (dictSize<=8)) return 0;
+    if ((dict==NULL) || (dictSize<8)) {
+        RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong);
+        return 0;
+    }
 
     ZSTD_reset_compressedBlockState(bs);
 
@@ -2942,7 +2945,7 @@
 
     FORWARD_IF_ERROR( ZSTD_resetCCtx_internal(cctx, *params, pledgedSrcSize,
                                      ZSTDcrp_makeClean, zbuff) );
-    {   size_t const dictID = cdict ? 
+    {   size_t const dictID = cdict ?
                 ZSTD_compress_insertDictionary(
                         cctx->blockState.prevCBlock, &cctx->blockState.matchState,
                         &cctx->workspace, params, cdict->dictContent, cdict->dictContentSize,
@@ -3219,7 +3222,7 @@
         ZSTDirp_reset,
         ZSTD_resetTarget_CDict));
     /* (Maybe) load the dictionary
-     * Skips loading the dictionary if it is <= 8 bytes.
+     * Skips loading the dictionary if it is < 8 bytes.
      */
     {   ZSTD_CCtx_params params;
         memset(&params, 0, sizeof(params));
diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c
index ca47a66..dd4591b 100644
--- a/lib/decompress/zstd_decompress.c
+++ b/lib/decompress/zstd_decompress.c
@@ -1096,7 +1096,7 @@
         size_t const dictContentSize = (size_t)(dictEnd - (dictPtr+12));
         for (i=0; i<3; i++) {
             U32 const rep = MEM_readLE32(dictPtr); dictPtr += 4;
-            RETURN_ERROR_IF(rep==0 || rep >= dictContentSize,
+            RETURN_ERROR_IF(rep==0 || rep > dictContentSize,
                             dictionary_corrupted);
             entropy->rep[i] = rep;
     }   }
@@ -1265,7 +1265,7 @@
 {
     RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong);
     ZSTD_clearDict(dctx);
-    if (dict && dictSize >= 8) {
+    if (dict && dictSize != 0) {
         dctx->ddictLocal = ZSTD_createDDict_advanced(dict, dictSize, dictLoadMethod, dictContentType, dctx->customMem);
         RETURN_ERROR_IF(dctx->ddictLocal == NULL, memory_allocation);
         dctx->ddict = dctx->ddictLocal;
diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile
index 83837e6..f66dade 100644
--- a/tests/fuzz/Makefile
+++ b/tests/fuzz/Makefile
@@ -73,7 +73,8 @@
 	dictionary_round_trip \
 	dictionary_decompress \
 	zstd_frame_info \
-	simple_compress
+	simple_compress \
+	dictionary_loader
 
 all: $(FUZZ_TARGETS)
 
@@ -110,6 +111,9 @@
 zstd_frame_info: $(FUZZ_HEADERS) $(FUZZ_OBJ) zstd_frame_info.o
 	$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) zstd_frame_info.o $(LIB_FUZZING_ENGINE) -o $@
 
+dictionary_loader: $(FUZZ_HEADERS) $(FUZZ_OBJ) dictionary_loader.o
+	$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) dictionary_loader.o $(LIB_FUZZING_ENGINE) -o $@
+
 libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c regression_driver.o
 	$(AR) $(FUZZ_ARFLAGS) $@ regression_driver.o
 
diff --git a/tests/fuzz/dictionary_loader.c b/tests/fuzz/dictionary_loader.c
new file mode 100644
index 0000000..cb34f5d
--- /dev/null
+++ b/tests/fuzz/dictionary_loader.c
@@ -0,0 +1,106 @@
+/*
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+
+/**
+ * This fuzz target makes sure that whenever a compression dictionary can be
+ * loaded, the data can be round tripped.
+ */
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#include "fuzz_helpers.h"
+#include "zstd_helpers.h"
+#include "fuzz_data_producer.h"
+
+/**
+ * Compresses `source` into `compressed` with a fresh CCtx that has `dict`
+ * loaded. Loading the dictionary must succeed, but compression itself may
+ * fail, so the result is either the compressed size or a zstd error code.
+ */
+static size_t compress(void* compressed, size_t compressedCapacity,
+                       void const* source, size_t sourceSize,
+                       void const* dict, size_t dictSize,
+                       ZSTD_dictLoadMethod_e dictLoadMethod,
+                       ZSTD_dictContentType_e dictContentType)
+{
+    ZSTD_CCtx* const cctx = ZSTD_createCCtx();
+    FUZZ_ASSERT(cctx);
+    FUZZ_ZASSERT(ZSTD_CCtx_loadDictionary_advanced(
+            cctx, dict, dictSize, dictLoadMethod, dictContentType));
+    size_t const compressedSize = ZSTD_compress2(
+            cctx, compressed, compressedCapacity, source, sourceSize);
+    ZSTD_freeCCtx(cctx);
+    return compressedSize;
+}
+
+/**
+ * Decompresses `compressed` into `result` with a fresh DCtx that has `dict`
+ * loaded. Both loading the dictionary and decompressing must succeed; the
+ * regenerated size is returned.
+ */
+static size_t decompress(void* result, size_t resultCapacity,
+                         void const* compressed, size_t compressedSize,
+                         void const* dict, size_t dictSize,
+                         ZSTD_dictLoadMethod_e dictLoadMethod,
+                         ZSTD_dictContentType_e dictContentType)
+{
+    ZSTD_DCtx* const dctx = ZSTD_createDCtx();
+    FUZZ_ASSERT(dctx);
+    FUZZ_ZASSERT(ZSTD_DCtx_loadDictionary_advanced(
+            dctx, dict, dictSize, dictLoadMethod, dictContentType));
+    size_t const resultSize = ZSTD_decompressDCtx(
+            dctx, result, resultCapacity, compressed, compressedSize);
+    ZSTD_freeDCtx(dctx);
+    FUZZ_ZASSERT(resultSize);
+    return resultSize;
+}
+
+int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
+{
+    FUZZ_dataProducer_t *producer = FUZZ_dataProducer_create(src, size);
+    ZSTD_dictLoadMethod_e const dlm =
+            FUZZ_dataProducer_uint32Range(producer, 0, 1);
+    ZSTD_dictContentType_e const dct =
+            FUZZ_dataProducer_uint32Range(producer, 0, 2);
+    /* From here on only the first `size` bytes of src are used, as both the
+     * dictionary and the data to round trip. */
+    size = FUZZ_dataProducer_remainingBytes(producer);
+
+    DEBUGLOG(2, "Dict load method %d", dlm);
+    DEBUGLOG(2, "Dict content type %d", dct);
+    DEBUGLOG(2, "Dict size %u", (unsigned)size);
+
+    /* malloc(0) may return NULL, which must not trip the assert below. */
+    void* const rBuf = malloc(size > 0 ? size : 1);
+    FUZZ_ASSERT(rBuf);
+    size_t const cBufSize = ZSTD_compressBound(size);
+    void* const cBuf = malloc(cBufSize);
+    FUZZ_ASSERT(cBuf);
+
+    size_t const cSize =
+            compress(cBuf, cBufSize, src, size, src, size, dlm, dct);
+    /* Compression may fail when the dictionary is rejected, but a raw
+     * content dictionary must always be accepted. */
+    if (ZSTD_isError(cSize)) {
+        FUZZ_ASSERT_MSG(dct != ZSTD_dct_rawContent, "Raw must always succeed!");
+        goto out;
+    }
+    size_t const rSize =
+            decompress(rBuf, size, cBuf, cSize, src, size, dlm, dct);
+    FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size");
+    FUZZ_ASSERT_MSG(!memcmp(src, rBuf, size), "Corruption!");
+
+out:
+    free(cBuf);
+    free(rBuf);
+    FUZZ_dataProducer_free(producer);
+    return 0;
+}
diff --git a/tests/fuzz/fuzz.py b/tests/fuzz/fuzz.py
index 9df68df..87f115a 100755
--- a/tests/fuzz/fuzz.py
+++ b/tests/fuzz/fuzz.py
@@ -27,6 +27,7 @@
 class InputType(object):
     RAW_DATA = 1
     COMPRESSED_DATA = 2
+    DICTIONARY_DATA = 3
 
 
 class FrameType(object):
@@ -54,6 +55,7 @@
     'dictionary_decompress': TargetInfo(InputType.COMPRESSED_DATA),
     'zstd_frame_info': TargetInfo(InputType.COMPRESSED_DATA),
     'simple_compress': TargetInfo(InputType.RAW_DATA),
+    'dictionary_loader': TargetInfo(InputType.DICTIONARY_DATA),
 }
 TARGETS = list(TARGET_INFO.keys())
 ALL_TARGETS = TARGETS + ['all']
@@ -73,6 +75,7 @@
 AFL_FUZZ = os.environ.get('AFL_FUZZ', 'afl-fuzz')
 DECODECORPUS = os.environ.get('DECODECORPUS',
                               abs_join(FUZZ_DIR, '..', 'decodecorpus'))
+ZSTD = os.environ.get('ZSTD', abs_join(FUZZ_DIR, '..', '..', 'zstd'))
 
 # Sanitizer environment variables
 MSAN_EXTRA_CPPFLAGS = os.environ.get('MSAN_EXTRA_CPPFLAGS', '')
@@ -674,6 +677,11 @@
         help="decodecorpus binary (default: $DECODECORPUS='{}')".format(
             DECODECORPUS))
     parser.add_argument(
+        '--zstd',
+        type=str,
+        default=ZSTD,
+        help="zstd binary (default: $ZSTD='{}')".format(ZSTD))
+    parser.add_argument(
         '--fuzz-rng-seed-size',
         type=int,
         default=4,
@@ -707,46 +715,68 @@
         return 1
 
     seed = create(args.seed)
-    with tmpdir() as compressed:
-        with tmpdir() as decompressed:
-            cmd = [
-                args.decodecorpus,
-                '-n{}'.format(args.number),
-                '-p{}/'.format(compressed),
-                '-o{}'.format(decompressed),
+    with tmpdir() as compressed, tmpdir() as decompressed, tmpdir() as dict_dir:
+        info = TARGET_INFO[args.TARGET]
+
+        # Dictionary targets train on these samples, so generate at least 1000.
+        if info.input_type == InputType.DICTIONARY_DATA:
+            number = max(args.number, 1000)
+        else:
+            number = args.number
+        cmd = [
+            args.decodecorpus,
+            '-n{}'.format(number),
+            '-p{}/'.format(compressed),
+            '-o{}'.format(decompressed),
+        ]
+
+        if info.frame_type == FrameType.BLOCK:
+            cmd += [
+                '--gen-blocks',
+                '--max-block-size-log={}'.format(min(args.max_size_log, 17))
             ]
+        else:
+            cmd += ['--max-content-size-log={}'.format(args.max_size_log)]
 
-            info = TARGET_INFO[args.TARGET]
-            if info.frame_type == FrameType.BLOCK:
-                cmd += [
-                    '--gen-blocks',
-                    '--max-block-size-log={}'.format(min(args.max_size_log, 17))
+        print(' '.join(cmd))
+        subprocess.check_call(cmd)
+
+        if info.input_type == InputType.RAW_DATA:
+            print('using decompressed data in {}'.format(decompressed))
+            samples = decompressed
+        elif info.input_type == InputType.COMPRESSED_DATA:
+            print('using compressed data in {}'.format(compressed))
+            samples = compressed
+        else:
+            assert info.input_type == InputType.DICTIONARY_DATA
+            print('making dictionary data from {}'.format(decompressed))
+            samples = dict_dir
+            min_dict_size_log = 9
+            max_dict_size_log = max(min_dict_size_log + 1, args.max_size_log)
+            # Train one dictionary per --maxdict size 2**k, for k in [min, max).
+            for dict_size_log in range(min_dict_size_log, max_dict_size_log):
+                dict_size = 1 << dict_size_log
+                cmd = [
+                    args.zstd,
+                    '--train',
+                    '-r', decompressed,
+                    '--maxdict={}'.format(dict_size),
+                    '-o', abs_join(dict_dir, '{}.zstd-dict'.format(dict_size))
                 ]
-            else:
-                cmd += ['--max-content-size-log={}'.format(args.max_size_log)]
+                print(' '.join(cmd))
+                subprocess.check_call(cmd)
 
-            print(' '.join(cmd))
-            subprocess.check_call(cmd)
-
-            if info.input_type == InputType.RAW_DATA:
-                print('using decompressed data in {}'.format(decompressed))
-                samples = decompressed
-            else:
-                assert info.input_type == InputType.COMPRESSED_DATA
-                print('using compressed data in {}'.format(compressed))
-                samples = compressed
-
-            # Copy the samples over and prepend the RNG seeds
-            for name in os.listdir(samples):
-                samplename = abs_join(samples, name)
-                outname = abs_join(seed, name)
-                with open(samplename, 'rb') as sample:
-                    with open(outname, 'wb') as out:
-                        CHUNK_SIZE = 131072
+        # Copy each sample file into the seed directory, chunk by chunk
+        for name in os.listdir(samples):
+            samplename = abs_join(samples, name)
+            outname = abs_join(seed, name)
+            with open(samplename, 'rb') as sample:
+                with open(outname, 'wb') as out:
+                    CHUNK_SIZE = 131072
+                    chunk = sample.read(CHUNK_SIZE)
+                    while len(chunk) > 0:
+                        out.write(chunk)
                         chunk = sample.read(CHUNK_SIZE)
-                        while len(chunk) > 0:
-                            out.write(chunk)
-                            chunk = sample.read(CHUNK_SIZE)
     return 0