Optimize FP32 requantization in WAsm SIMD QS8/QC8/QU8 GEMM/IGEMM/DWCONV
PiperOrigin-RevId: 414178544
diff --git a/src/qc8-gemm/gen/3x4c8-minmax-fp32-wasmsimd-dot16x2-ld64.c b/src/qc8-gemm/gen/3x4c8-minmax-fp32-wasmsimd-dot16x2-ld64.c
index 9ec32b1..182e2f7 100644
--- a/src/qc8-gemm/gen/3x4c8-minmax-fp32-wasmsimd-dot16x2-ld64.c
+++ b/src/qc8-gemm/gen/3x4c8-minmax-fp32-wasmsimd-dot16x2-ld64.c
@@ -122,21 +122,16 @@
vacc1x0123 = wasm_f32x4_mul(vacc1x0123, vscale0123);
vacc2x0123 = wasm_f32x4_mul(vacc2x0123, vscale0123);
- const v128_t voutput_min_less_zero_point = wasm_v128_load(params->wasmsimd.output_min_less_zero_point);
- vacc0x0123 = wasm_f32x4_pmax(voutput_min_less_zero_point, vacc0x0123);
- vacc1x0123 = wasm_f32x4_pmax(voutput_min_less_zero_point, vacc1x0123);
- vacc2x0123 = wasm_f32x4_pmax(voutput_min_less_zero_point, vacc2x0123);
-
- const v128_t voutput_max_less_zero_point = wasm_v128_load(params->wasmsimd.output_max_less_zero_point);
- vacc0x0123 = wasm_f32x4_pmin(voutput_max_less_zero_point, vacc0x0123);
- vacc1x0123 = wasm_f32x4_pmin(voutput_max_less_zero_point, vacc1x0123);
- vacc2x0123 = wasm_f32x4_pmin(voutput_max_less_zero_point, vacc2x0123);
-
const v128_t vmagic_bias = wasm_v128_load(params->wasmsimd.magic_bias);
vacc0x0123 = wasm_f32x4_add(vacc0x0123, vmagic_bias);
vacc1x0123 = wasm_f32x4_add(vacc1x0123, vmagic_bias);
vacc2x0123 = wasm_f32x4_add(vacc2x0123, vmagic_bias);
+ const v128_t vmagic_min = wasm_v128_load(params->wasmsimd.magic_min);
+ vacc0x0123 = wasm_i32x4_max(vacc0x0123, vmagic_min);
+ vacc1x0123 = wasm_i32x4_max(vacc1x0123, vmagic_min);
+ vacc2x0123 = wasm_i32x4_max(vacc2x0123, vmagic_min);
+
const v128_t vmagic_bias_less_output_zero_point = wasm_v128_load(params->wasmsimd.magic_bias_less_output_zero_point);
vacc0x0123 = wasm_i32x4_sub(vacc0x0123, vmagic_bias_less_output_zero_point);
vacc1x0123 = wasm_i32x4_sub(vacc1x0123, vmagic_bias_less_output_zero_point);
@@ -147,6 +142,9 @@
v128_t vout = wasm_i8x16_narrow_i16x8(vacc01x0123, vacc22x0123);
+ const v128_t voutput_max = wasm_v128_load(params->wasmsimd.output_max);
+ vout = wasm_i8x16_min(vout, voutput_max);
+
if (nc >= 4) {
*((float*) c0) = (float) wasm_f32x4_extract_lane(vout, 0);
*((float*) c1) = (float) wasm_f32x4_extract_lane(vout, 1);