[lib] Ensure that multithreaded compression always makes some progress
diff --git a/lib/compress/zstd_compress.c b/lib/compress/zstd_compress.c
index 6830020..ed142b2 100644
--- a/lib/compress/zstd_compress.c
+++ b/lib/compress/zstd_compress.c
@@ -4441,26 +4441,42 @@
/* compression stage */
#ifdef ZSTD_MULTITHREAD
if (cctx->appliedParams.nbWorkers > 0) {
- int const forceMaxProgress = (endOp == ZSTD_e_flush || endOp == ZSTD_e_end);
size_t flushMin;
- assert(forceMaxProgress || endOp == ZSTD_e_continue /* Protection for a new flush type */);
if (cctx->cParamsChanged) {
ZSTDMT_updateCParams_whileCompressing(cctx->mtctx, &cctx->requestedParams);
cctx->cParamsChanged = 0;
}
- do {
+ for (;;) {
+ size_t const ipos = input->pos;
+ size_t const opos = output->pos;
flushMin = ZSTDMT_compressStream_generic(cctx->mtctx, output, input, endOp);
if ( ZSTD_isError(flushMin)
|| (endOp == ZSTD_e_end && flushMin == 0) ) { /* compression completed */
ZSTD_CCtx_reset(cctx, ZSTD_reset_session_only);
}
FORWARD_IF_ERROR(flushMin, "ZSTDMT_compressStream_generic failed");
- } while (forceMaxProgress && flushMin != 0 && output->pos < output->size);
+
+ if (endOp == ZSTD_e_continue) {
+ /* We only require some progress with ZSTD_e_continue, not maximal progress.
+ * We're done if we've consumed or produced any bytes, or either buffer is
+ * full.
+ */
+ if (input->pos != ipos || output->pos != opos || input->pos == input->size || output->pos == output->size)
+ break;
+ } else {
+ assert(endOp == ZSTD_e_flush || endOp == ZSTD_e_end);
+ /* We require maximal progress. We're done when the flush is complete or the
+ * output buffer is full.
+ */
+ if (flushMin == 0 || output->pos == output->size)
+ break;
+ }
+ }
DEBUGLOG(5, "completed ZSTD_compressStream2 delegating to ZSTDMT_compressStream_generic");
/* Either we don't require maximum forward progress, we've finished the
* flush, or we are out of output space.
*/
- assert(!forceMaxProgress || flushMin == 0 || output->pos == output->size);
+ assert(endOp == ZSTD_e_continue || flushMin == 0 || output->pos == output->size);
ZSTD_setBufferExpectations(cctx, output, input);
return flushMin;
}
diff --git a/tests/zstreamtest.c b/tests/zstreamtest.c
index 1855b4d..fa18ea4 100644
--- a/tests/zstreamtest.c
+++ b/tests/zstreamtest.c
@@ -2299,6 +2299,7 @@
/* Ensure maximal forward progress for determinism */
forwardProgress = (inBuff.pos != ipos) || (outBuff.pos != opos);
} while (forwardProgress);
+ assert(inBuff.pos == inBuff.size);
XXH64_update(&xxhState, srcBuffer+srcStart, inBuff.pos);
memcpy(copyBuffer+totalTestSize, srcBuffer+srcStart, inBuff.pos);