Merge pull request #1864 from terrelln/dict-fix
Fix 2 bugs in dictionary loading
diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c
index a8856cd..35346b9 100644
--- a/lib/compress/zstd_compress.c
+++ b/lib/compress/zstd_compress.c
@@ -2771,7 +2771,7 @@
/*! ZSTD_loadZstdDictionary() :
* @return : dictID, or an error code
* assumptions : magic number supposed already checked
- * dictSize supposed > 8
+ * dictSize supposed >= 8
*/
static size_t ZSTD_loadZstdDictionary(ZSTD_compressedBlockState_t* bs,
ZSTD_matchState_t* ms,
@@ -2788,7 +2788,7 @@
size_t dictID;
ZSTD_STATIC_ASSERT(HUF_WORKSPACE_SIZE >= (1<<MAX(MLFSELog,LLFSELog)));
- assert(dictSize > 8);
+ assert(dictSize >= 8);
assert(MEM_readLE32(dictPtr) == ZSTD_MAGIC_DICTIONARY);
dictPtr += 4; /* skip magic number */
@@ -2890,7 +2890,10 @@
void* workspace)
{
DEBUGLOG(4, "ZSTD_compress_insertDictionary (dictSize=%u)", (U32)dictSize);
- if ((dict==NULL) || (dictSize<=8)) return 0;
+ if ((dict==NULL) || (dictSize<8)) {
+ RETURN_ERROR_IF(dictContentType == ZSTD_dct_fullDict, dictionary_wrong);
+ return 0;
+ }
ZSTD_reset_compressedBlockState(bs);
@@ -2942,7 +2945,7 @@
FORWARD_IF_ERROR( ZSTD_resetCCtx_internal(cctx, *params, pledgedSrcSize,
ZSTDcrp_makeClean, zbuff) );
- { size_t const dictID = cdict ?
+ { size_t const dictID = cdict ?
ZSTD_compress_insertDictionary(
cctx->blockState.prevCBlock, &cctx->blockState.matchState,
&cctx->workspace, params, cdict->dictContent, cdict->dictContentSize,
@@ -3219,7 +3222,7 @@
ZSTDirp_reset,
ZSTD_resetTarget_CDict));
/* (Maybe) load the dictionary
- * Skips loading the dictionary if it is <= 8 bytes.
+ * Skips loading the dictionary if it is < 8 bytes.
*/
{ ZSTD_CCtx_params params;
memset(¶ms, 0, sizeof(params));
diff --git a/lib/decompress/zstd_decompress.c b/lib/decompress/zstd_decompress.c
index ca47a66..dd4591b 100644
--- a/lib/decompress/zstd_decompress.c
+++ b/lib/decompress/zstd_decompress.c
@@ -1096,7 +1096,7 @@
size_t const dictContentSize = (size_t)(dictEnd - (dictPtr+12));
for (i=0; i<3; i++) {
U32 const rep = MEM_readLE32(dictPtr); dictPtr += 4;
- RETURN_ERROR_IF(rep==0 || rep >= dictContentSize,
+ RETURN_ERROR_IF(rep==0 || rep > dictContentSize,
dictionary_corrupted);
entropy->rep[i] = rep;
} }
@@ -1265,7 +1265,7 @@
{
RETURN_ERROR_IF(dctx->streamStage != zdss_init, stage_wrong);
ZSTD_clearDict(dctx);
- if (dict && dictSize >= 8) {
+ if (dict && dictSize != 0) {
dctx->ddictLocal = ZSTD_createDDict_advanced(dict, dictSize, dictLoadMethod, dictContentType, dctx->customMem);
RETURN_ERROR_IF(dctx->ddictLocal == NULL, memory_allocation);
dctx->ddict = dctx->ddictLocal;
diff --git a/tests/fuzz/Makefile b/tests/fuzz/Makefile
index 83837e6..f66dade 100644
--- a/tests/fuzz/Makefile
+++ b/tests/fuzz/Makefile
@@ -73,7 +73,8 @@
dictionary_round_trip \
dictionary_decompress \
zstd_frame_info \
- simple_compress
+ simple_compress \
+ dictionary_loader
all: $(FUZZ_TARGETS)
@@ -110,6 +111,9 @@
zstd_frame_info: $(FUZZ_HEADERS) $(FUZZ_OBJ) zstd_frame_info.o
$(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) zstd_frame_info.o $(LIB_FUZZING_ENGINE) -o $@
+dictionary_loader: $(FUZZ_HEADERS) $(FUZZ_OBJ) dictionary_loader.o
+ $(CXX) $(FUZZ_TARGET_FLAGS) $(FUZZ_OBJ) dictionary_loader.o $(LIB_FUZZING_ENGINE) -o $@
+
libregression.a: $(FUZZ_HEADERS) $(PRGDIR)/util.h $(PRGDIR)/util.c regression_driver.o
$(AR) $(FUZZ_ARFLAGS) $@ regression_driver.o
diff --git a/tests/fuzz/dictionary_loader.c b/tests/fuzz/dictionary_loader.c
new file mode 100644
index 0000000..cb34f5d
--- /dev/null
+++ b/tests/fuzz/dictionary_loader.c
@@ -0,0 +1,115 @@
+/*
+ * Copyright (c) 2016-present, Facebook, Inc.
+ * All rights reserved.
+ *
+ * This source code is licensed under both the BSD-style license (found in the
+ * LICENSE file in the root directory of this source tree) and the GPLv2 (found
+ * in the COPYING file in the root directory of this source tree).
+ */
+
+/**
+ * This fuzz target makes sure that whenever a compression dictionary can be
+ * loaded, the data can be round tripped.
+ *
+ * Two parameters (the dictionary load method and the dictionary content type)
+ * are drawn from the input through the data producer; the remaining bytes are
+ * used as BOTH the dictionary and the data to round trip.
+ */
+
+#include <stddef.h>
+#include <stdlib.h>
+#include <stdio.h>
+#include <string.h>
+#include "fuzz_helpers.h"
+#include "zstd_helpers.h"
+#include "fuzz_data_producer.h"
+
+/**
+ * Compresses the data and returns the compressed size or an error.
+ * Creating the context and loading the dictionary are asserted to succeed;
+ * an error from the compression itself is returned for the caller to judge.
+ */
+static size_t compress(void* compressed, size_t compressedCapacity,
+                       void const* source, size_t sourceSize,
+                       void const* dict, size_t dictSize,
+                       ZSTD_dictLoadMethod_e dictLoadMethod,
+                       ZSTD_dictContentType_e dictContentType)
+{
+    ZSTD_CCtx* cctx = ZSTD_createCCtx();
+    FUZZ_ASSERT(cctx);
+    FUZZ_ZASSERT(ZSTD_CCtx_loadDictionary_advanced(
+            cctx, dict, dictSize, dictLoadMethod, dictContentType));
+    size_t const compressedSize = ZSTD_compress2(
+            cctx, compressed, compressedCapacity, source, sourceSize);
+    ZSTD_freeCCtx(cctx);
+    return compressedSize;
+}
+
+/**
+ * Decompresses the data with the same dictionary settings that were used for
+ * compression and returns the regenerated size.
+ * Creating the context, loading the dictionary and decompressing are all
+ * asserted to succeed: once compression accepted the dictionary, the round
+ * trip must work.
+ */
+static size_t decompress(void* result, size_t resultCapacity,
+                         void const* compressed, size_t compressedSize,
+                         void const* dict, size_t dictSize,
+                         ZSTD_dictLoadMethod_e dictLoadMethod,
+                         ZSTD_dictContentType_e dictContentType)
+{
+    ZSTD_DCtx* dctx = ZSTD_createDCtx();
+    FUZZ_ASSERT(dctx);
+    FUZZ_ZASSERT(ZSTD_DCtx_loadDictionary_advanced(
+            dctx, dict, dictSize, dictLoadMethod, dictContentType));
+    size_t const resultSize = ZSTD_decompressDCtx(
+            dctx, result, resultCapacity, compressed, compressedSize);
+    FUZZ_ZASSERT(resultSize);
+    ZSTD_freeDCtx(dctx);
+    return resultSize;
+}
+
+int LLVMFuzzerTestOneInput(const uint8_t *src, size_t size)
+{
+    FUZZ_dataProducer_t *producer = FUZZ_dataProducer_create(src, size);
+    /* Draw each parameter straight into its enum; `size` is only updated
+     * from the producer below. */
+    ZSTD_dictLoadMethod_e const dlm = (ZSTD_dictLoadMethod_e)
+            FUZZ_dataProducer_uint32Range(producer, 0, 1);
+    ZSTD_dictContentType_e const dct = (ZSTD_dictContentType_e)
+            FUZZ_dataProducer_uint32Range(producer, 0, 2);
+    /* `size` becomes the number of bytes the producer has not consumed;
+     * only the first `size` bytes of `src` are used below. */
+    size = FUZZ_dataProducer_remainingBytes(producer);
+
+    DEBUGLOG(2, "Dict load method %d", dlm);
+    DEBUGLOG(2, "Dict content type %d", dct);
+    DEBUGLOG(2, "Dict size %u", (unsigned)size);
+
+    /* malloc(0) may return NULL; that must not be reported as a failure. */
+    void* const rBuf = malloc(size > 0 ? size : 1);
+    FUZZ_ASSERT(rBuf);
+    size_t const cBufSize = ZSTD_compressBound(size);
+    void* const cBuf = malloc(cBufSize);
+    FUZZ_ASSERT(cBuf);
+
+    /* The same bytes serve as both the dictionary and the source data. */
+    size_t const cSize =
+            compress(cBuf, cBufSize, src, size, src, size, dlm, dct);
+    if (ZSTD_isError(cSize)) {
+        /* Compression failing is okay (presumably the bytes are not a valid
+         * dictionary of the requested content type), but raw content
+         * dictionaries must always be accepted. */
+        FUZZ_ASSERT_MSG(dct != ZSTD_dct_rawContent, "Raw must always succeed!");
+    } else {
+        size_t const rSize =
+                decompress(rBuf, size, cBuf, cSize, src, size, dlm, dct);
+        FUZZ_ASSERT_MSG(rSize == size, "Incorrect regenerated size");
+        FUZZ_ASSERT_MSG(!memcmp(src, rBuf, size), "Corruption!");
+    }
+
+    free(cBuf);
+    free(rBuf);
+    FUZZ_dataProducer_free(producer);
+    return 0;
+}
diff --git a/tests/fuzz/fuzz.py b/tests/fuzz/fuzz.py
index 9df68df..87f115a 100755
--- a/tests/fuzz/fuzz.py
+++ b/tests/fuzz/fuzz.py
@@ -27,6 +27,7 @@
class InputType(object):
RAW_DATA = 1
COMPRESSED_DATA = 2
+ DICTIONARY_DATA = 3
class FrameType(object):
@@ -54,6 +55,7 @@
'dictionary_decompress': TargetInfo(InputType.COMPRESSED_DATA),
'zstd_frame_info': TargetInfo(InputType.COMPRESSED_DATA),
'simple_compress': TargetInfo(InputType.RAW_DATA),
+ 'dictionary_loader': TargetInfo(InputType.DICTIONARY_DATA),
}
TARGETS = list(TARGET_INFO.keys())
ALL_TARGETS = TARGETS + ['all']
@@ -73,6 +75,7 @@
AFL_FUZZ = os.environ.get('AFL_FUZZ', 'afl-fuzz')
DECODECORPUS = os.environ.get('DECODECORPUS',
abs_join(FUZZ_DIR, '..', 'decodecorpus'))
+ZSTD = os.environ.get('ZSTD', abs_join(FUZZ_DIR, '..', '..', 'zstd'))
# Sanitizer environment variables
MSAN_EXTRA_CPPFLAGS = os.environ.get('MSAN_EXTRA_CPPFLAGS', '')
@@ -674,6 +677,11 @@
help="decodecorpus binary (default: $DECODECORPUS='{}')".format(
DECODECORPUS))
parser.add_argument(
+ '--zstd',
+ type=str,
+ default=ZSTD,
+ help="zstd binary (default: $ZSTD='{}')".format(ZSTD))
+ parser.add_argument(
'--fuzz-rng-seed-size',
type=int,
default=4,
@@ -707,46 +715,66 @@
return 1
seed = create(args.seed)
- with tmpdir() as compressed:
- with tmpdir() as decompressed:
- cmd = [
- args.decodecorpus,
- '-n{}'.format(args.number),
- '-p{}/'.format(compressed),
- '-o{}'.format(decompressed),
+ with tmpdir() as compressed, tmpdir() as decompressed, tmpdir() as dict:
+ info = TARGET_INFO[args.TARGET]
+
+ if info.input_type == InputType.DICTIONARY_DATA:
+ number = max(args.number, 1000)
+ else:
+ number = args.number
+ cmd = [
+ args.decodecorpus,
+ '-n{}'.format(args.number),
+ '-p{}/'.format(compressed),
+ '-o{}'.format(decompressed),
+ ]
+
+ if info.frame_type == FrameType.BLOCK:
+ cmd += [
+ '--gen-blocks',
+ '--max-block-size-log={}'.format(min(args.max_size_log, 17))
]
+ else:
+ cmd += ['--max-content-size-log={}'.format(args.max_size_log)]
- info = TARGET_INFO[args.TARGET]
- if info.frame_type == FrameType.BLOCK:
- cmd += [
- '--gen-blocks',
- '--max-block-size-log={}'.format(min(args.max_size_log, 17))
+ print(' '.join(cmd))
+ subprocess.check_call(cmd)
+
+ if info.input_type == InputType.RAW_DATA:
+ print('using decompressed data in {}'.format(decompressed))
+ samples = decompressed
+ elif info.input_type == InputType.COMPRESSED_DATA:
+ print('using compressed data in {}'.format(compressed))
+ samples = compressed
+ else:
+ assert info.input_type == InputType.DICTIONARY_DATA
+ print('making dictionary data from {}'.format(decompressed))
+ samples = dict
+ min_dict_size_log = 9
+ max_dict_size_log = max(min_dict_size_log + 1, args.max_size_log)
+ for dict_size_log in range(min_dict_size_log, max_dict_size_log):
+ dict_size = 1 << dict_size_log
+ cmd = [
+ args.zstd,
+ '--train',
+ '-r', decompressed,
+ '--maxdict={}'.format(dict_size),
+ '-o', abs_join(dict, '{}.zstd-dict'.format(dict_size))
]
- else:
- cmd += ['--max-content-size-log={}'.format(args.max_size_log)]
+ print(' '.join(cmd))
+ subprocess.check_call(cmd)
- print(' '.join(cmd))
- subprocess.check_call(cmd)
-
- if info.input_type == InputType.RAW_DATA:
- print('using decompressed data in {}'.format(decompressed))
- samples = decompressed
- else:
- assert info.input_type == InputType.COMPRESSED_DATA
- print('using compressed data in {}'.format(compressed))
- samples = compressed
-
- # Copy the samples over and prepend the RNG seeds
- for name in os.listdir(samples):
- samplename = abs_join(samples, name)
- outname = abs_join(seed, name)
- with open(samplename, 'rb') as sample:
- with open(outname, 'wb') as out:
- CHUNK_SIZE = 131072
+ # Copy the samples over and prepend the RNG seeds
+ for name in os.listdir(samples):
+ samplename = abs_join(samples, name)
+ outname = abs_join(seed, name)
+ with open(samplename, 'rb') as sample:
+ with open(outname, 'wb') as out:
+ CHUNK_SIZE = 131072
+ chunk = sample.read(CHUNK_SIZE)
+ while len(chunk) > 0:
+ out.write(chunk)
chunk = sample.read(CHUNK_SIZE)
- while len(chunk) > 0:
- out.write(chunk)
- chunk = sample.read(CHUNK_SIZE)
return 0