Unify implementation of NWC Global Average Pooling across data types
PiperOrigin-RevId: 325386782
diff --git a/src/xnnpack/operator.h b/src/xnnpack/operator.h
index 6412693..5562fc9 100644
--- a/src/xnnpack/operator.h
+++ b/src/xnnpack/operator.h
@@ -242,8 +242,7 @@
float input_scale;
float output_scale;
- uint8_t input_zero_point;
- uint8_t kernel_zero_point;
+ int32_t input_zero_point;
uint8_t output_zero_point;
uint8_t output_min;
uint8_t output_max;
diff --git a/src/xnnpack/params-init.h b/src/xnnpack/params-init.h
index e83bc8d..6df7225 100644
--- a/src/xnnpack/params-init.h
+++ b/src/xnnpack/params-init.h
@@ -392,6 +392,53 @@
return params;
}
+static inline void xnn_update_qu8_avgpool_params(
+ union xnn_qu8_avgpool_params* params,  // updated in place; only the variant chosen by the XNN_ARCH_* checks below is written
+ int32_t bias,  // copied unchanged into the bias field(s) of the selected variant
+ float scale)  // asserted below to satisfy 2^-32 <= scale < 256
+{
+ // Compute requantization parameters: scale is decomposed into an integer multiplier and a shift amount taken from its bit pattern.
+ assert(scale >= 0x1.0p-32f);
+ assert(scale < 256.0f);
+ const uint32_t scale_bits = fp32_to_bits(scale);  // presumably the raw IEEE-754 binary32 bits of scale; confirm against fp32_to_bits
+
+ // Multiplier is in [0x00800000, 0x00FFFFFF] range: the low 23 bits of scale_bits with bit 23 forced to 1.
+ const int32_t multiplier = ((int32_t) scale_bits & INT32_C(0x007FFFFF)) | INT32_C(0x00800000);
+ assert(multiplier >= INT32_C(0x00800000));
+ assert(multiplier <= INT32_C(0x00FFFFFF));
+
+ // Shift is in [16, 55] range for the asserted scale range: 127 + 23 minus bits 23 and up of scale_bits. The asserts below only enforce [16, 63].
+ const int32_t shift = 127 + 23 - (scale_bits >> 23);
+ assert(shift >= 16);
+ assert(shift < 64);
+
+ #if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ const uint64_t rounding = UINT64_C(1) << ((uint32_t) shift - 1);  // 2^(shift - 1), i.e. half of 2^shift
+ params->sse2.bias[0] = bias;  // the same bias is replicated into all four bias elements
+ params->sse2.bias[1] = bias;
+ params->sse2.bias[2] = bias;
+ params->sse2.bias[3] = bias;
+ params->sse2.multiplier[0] = (uint32_t) multiplier;  // the same multiplier is replicated into all four elements
+ params->sse2.multiplier[1] = (uint32_t) multiplier;
+ params->sse2.multiplier[2] = (uint32_t) multiplier;
+ params->sse2.multiplier[3] = (uint32_t) multiplier;
+ params->sse2.rounding[0] = rounding;  // rounding and right_shift are each written to two elements
+ params->sse2.rounding[1] = rounding;
+ params->sse2.right_shift[0] = (uint64_t) (uint32_t) shift;
+ params->sse2.right_shift[1] = (uint64_t) (uint32_t) shift;
+ #elif XNN_ARCH_ARM || XNN_ARCH_ARM64
+ params->neon.bias = bias;  // no rounding field is written in this branch
+ params->neon.multiplier = multiplier;
+ params->neon.left_shift = (int64_t) -shift;  // shift is stored negated; presumably consumed as a negative left shift (confirm against the NEON kernels)
+ #else
+ const int64_t rounding = INT64_C(1) << ((uint32_t) shift - 1);  // 2^(shift - 1), i.e. half of 2^shift
+ params->scalar.bias = bias;
+ params->scalar.multiplier = multiplier;
+ params->scalar.rounding = rounding;
+ params->scalar.right_shift = (uint32_t) shift;
+ #endif
+}
+
static inline union xnn_qs8_avgpool_params xnn_init_qs8_avgpool_params(
int32_t bias,
float scale,