QU8 GEMM/IGEMM microkernels for AVX512
PiperOrigin-RevId: 383679359
diff --git a/BUILD.bazel b/BUILD.bazel
index 8d4eec5..00177a1 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -4275,6 +4275,14 @@
"src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c",
"src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c",
"src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c",
+ "src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c",
+ "src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c",
]
WASM32_ASM_UKERNELS = [
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 7dcfde3..dce5c92 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -3468,7 +3468,15 @@
src/qs8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
- src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c)
+ src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
+ src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
+ src/qu8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
+ src/qu8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
+ src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
+ src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
+ src/qu8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
+ src/qu8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
+ src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c)
SET(XNNPACK_AARCH32_ASM_MICROKERNEL_SRCS
src/f32-gemm/4x4-aarch32-vfp-ld64.S
diff --git a/scripts/generate-qs8-gemm.sh b/scripts/generate-qs8-gemm.sh
index ac2a709..5b13372 100755
--- a/scripts/generate-qs8-gemm.sh
+++ b/scripts/generate-qs8-gemm.sh
@@ -625,15 +625,20 @@
tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=GEMMLOWP -o src/qs8-gemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=GEMMLOWP -o src/qs8-gemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
+
tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-gemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
################################## Unit tests #################################
tools/generate-gemm-test.py --spec test/qs8-gemm-minmax-gemmlowp.yaml --output test/qs8-gemm-minmax-gemmlowp.cc
diff --git a/scripts/generate-qs8-igemm.sh b/scripts/generate-qs8-igemm.sh
index 4cba94e..bf63ee9 100755
--- a/scripts/generate-qs8-igemm.sh
+++ b/scripts/generate-qs8-igemm.sh
@@ -553,15 +553,20 @@
tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=GEMMLOWP -o src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=GEMMLOWP -o src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
+
tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QS8 -D REQUANTIZATION=FP32 -o src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
-tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QC8 -D REQUANTIZATION=FP32 -o src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=1 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=2 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=3 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
+tools/xngen src/qs8-igemm/MRx16c8-avx512skx.c.in -D MR=4 -D VARIANT=LD256 -D DATATYPE=QU8 -D REQUANTIZATION=FP32 -o src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
################################## Unit tests #################################
tools/generate-gemm-test.py --spec test/qs8-igemm-minmax-gemmlowp.yaml --output test/qs8-igemm-minmax-gemmlowp.cc
diff --git a/src/init.c b/src/init.c
index 3519c1f..a174c2e 100644
--- a/src/init.c
+++ b/src/init.c
@@ -2107,7 +2107,16 @@
#ifndef XNN_NO_QU8_OPERATORS
init_flags |= XNN_INIT_FLAG_QU8;
- if (cpuinfo_has_x86_xop()) {
+ if (cpuinfo_has_x86_avx512f() && cpuinfo_has_x86_avx512bw() && cpuinfo_has_x86_avx512dq() && cpuinfo_has_x86_avx512vl()) {
+ xnn_params.qu8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx);
+ xnn_params.qu8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx);
+ xnn_params.qu8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx);
+ xnn_params.qu8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx);
+ xnn_params.qu8.gemm.init.qu8 = xnn_init_qu8_conv_minmax_fp32_avx512_params;
+ xnn_params.qu8.gemm.mr = 4;
+ xnn_params.qu8.gemm.nr = 16;
+ xnn_params.qu8.gemm.log2_kr = 3;
+ } else if (cpuinfo_has_x86_xop()) {
// XOP should be checked before AVX2: AMD Excavator supports both, but performs better with XOP microkernels
xnn_params.qu8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qu8_gemm_minmax_fp32_ukernel_2x4c8__xop_ld64);
xnn_params.qu8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qu8_igemm_minmax_fp32_ukernel_2x4c8__xop_ld64);
diff --git a/src/params-init.c b/src/params-init.c
index 9a8bc9b..a490eae 100644
--- a/src/params-init.c
+++ b/src/params-init.c
@@ -143,6 +143,27 @@
params->fp32_avx2.output_max[i] = output_max;
}
}
+
+void xnn_init_qu8_conv_minmax_fp32_avx512_params(
+ union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
+ uint8_t kernel_zero_point,
+ float scale,
+ uint8_t output_zero_point,
+ uint8_t output_min,
+ uint8_t output_max)
+{
+ for (uint32_t i = 0; i < 16; i++) {
+ params->fp32_avx512.scale[i] = scale;
+ }
+ for (uint32_t i = 0; i < 32; i++) {
+ params->fp32_avx512.kernel_zero_point[i] = (int16_t) (uint16_t) kernel_zero_point;
+ params->fp32_avx512.output_zero_point[i] = (int16_t) (uint16_t) output_zero_point;
+ }
+ for (uint32_t i = 0; i < 64; i++) {
+ params->fp32_avx512.output_min[i] = output_min;
+ params->fp32_avx512.output_max[i] = output_max;
+ }
+}
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
diff --git a/src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c b/src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
index 459da1f..72a74af 100644
--- a/src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
@@ -47,10 +47,10 @@
const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -60,17 +60,17 @@
const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
@@ -82,7 +82,7 @@
__m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c b/src/qc8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
index d09d5a4..62bd1d2 100644
--- a/src/qc8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
@@ -53,14 +53,14 @@
const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
__m512i vacc1xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -73,20 +73,20 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
@@ -102,7 +102,7 @@
__m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c b/src/qc8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
index c7cbe61..a330288 100644
--- a/src/qc8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
@@ -59,9 +59,9 @@
const __m512i voutput_max = _mm512_load_si512(params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -70,7 +70,7 @@
__m512i vacc2x4567 = vacc0x4567;
__m512i vacc2x89AB = vacc0x89AB;
__m512i vacc2xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -86,23 +86,23 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
@@ -122,7 +122,7 @@
__m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c b/src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
index e863817..6d51602 100644
--- a/src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
@@ -65,9 +65,9 @@
const __m512i voutput_max = _mm512_load_si512(params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -80,7 +80,7 @@
__m512i vacc3x4567 = vacc0x4567;
__m512i vacc3x89AB = vacc0x89AB;
__m512i vacc3xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -99,26 +99,26 @@
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
@@ -142,7 +142,7 @@
__m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c b/src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
index 3cbbe01..412b2ee 100644
--- a/src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
@@ -48,10 +48,10 @@
const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -69,17 +69,17 @@
const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 1 * sizeof(void*);
@@ -93,7 +93,7 @@
__m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c b/src/qc8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
index 4157a88..725d4d4 100644
--- a/src/qc8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
@@ -52,14 +52,14 @@
const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
__m512i vacc1xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -84,20 +84,20 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 2 * sizeof(void*);
@@ -115,7 +115,7 @@
__m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c b/src/qc8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
index 0dae177..6bf576c 100644
--- a/src/qc8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
@@ -56,9 +56,9 @@
const __m512i voutput_max = _mm512_load_si512(params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -67,7 +67,7 @@
__m512i vacc2x4567 = vacc0x4567;
__m512i vacc2x89AB = vacc0x89AB;
__m512i vacc2xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -99,23 +99,23 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 3 * sizeof(void*);
@@ -137,7 +137,7 @@
__m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c b/src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
index ad2d565..8ef0305 100644
--- a/src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
+++ b/src/qc8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
@@ -60,9 +60,9 @@
const __m512i voutput_max = _mm512_load_si512(params->avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -75,7 +75,7 @@
__m512i vacc3x4567 = vacc0x4567;
__m512i vacc3x89AB = vacc0x89AB;
__m512i vacc3xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -114,26 +114,26 @@
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 4 * sizeof(void*);
@@ -159,7 +159,7 @@
__m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
diff --git a/src/qs8-gemm/MRx16c8-avx512skx.c.in b/src/qs8-gemm/MRx16c8-avx512skx.c.in
index 9331214..6c8b799 100644
--- a/src/qs8-gemm/MRx16c8-avx512skx.c.in
+++ b/src/qs8-gemm/MRx16c8-avx512skx.c.in
@@ -5,7 +5,7 @@
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert REQUANTIZATION in ["GEMMLOWP", "FP32"]
-$assert DATATYPE in ["QC8", "QS8"]
+$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"
$assert VARIANT in ["LD256", "EXTENDED"]
$assert MR <= 4
@@ -20,15 +20,16 @@
$PARAMS_STRUCT = "avx512" if DATATYPE == "QC8" else REQUANTIZATION.lower() + "_avx512"
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
-$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_qs8_conv_minmax_params"
+$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_%s_conv_minmax_params" % DATATYPE.lower()
+$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
void xnn_${DATATYPE.lower()}_gemm${GEMM_SUFFIX}_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x16c8__avx512skx(
size_t mr,
size_t nc,
size_t kc,
- const int8_t* restrict a,
+ const ${XINT8_T}* restrict a,
size_t a_stride,
const void* restrict w,
- int8_t* restrict c,
+ ${XINT8_T}* restrict c,
size_t cm_stride,
size_t cn_stride,
const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
@@ -37,17 +38,17 @@
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);
- assert(kc % sizeof(int8_t) == 0);
+ assert(kc % sizeof(${XINT8_T}) == 0);
assert(a != NULL);
assert(w != NULL);
assert(c != NULL);
kc = round_up_po2(kc, 8);
- const int8_t* a0 = a;
- int8_t* c0 = c;
+ const ${XINT8_T}* a0 = a;
+ ${XINT8_T}* c0 = c;
$for M in range(1, MR):
- const int8_t* a${M} = (const int8_t*) ((uintptr_t) a${M-1} + a_stride);
- int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
+ const ${XINT8_T}* a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M-1} + a_stride);
+ ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
$if M % 2 == 0:
if XNN_UNPREDICTABLE(mr <= ${M}) {
a${M} = a${M-1};
@@ -91,38 +92,49 @@
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
$for N in range(4, 16, 4):
- __m512i vacc0x${ABC[N:N+4]} = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + ${N} * sizeof(int32_t)));
+ __m512i vacc0x${ABC[N:N+4]} = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + ${N}));
$for M in range(1, MR):
$for N in range(0, 16, 4):
__m512i vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
+ $if DATATYPE == "QU8":
+ const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point);
while (k < kc) {
$for M in range(MR):
- const __m512i va${M} = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a${M})));
+ $if DATATYPE == "QU8":
+ const __m512i va${M} = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a${M})));
+ $else:
+ const __m512i va${M} = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a${M})));
a${M} += 8;
$for N in range(0, 16, 4):
$if VARIANT == "EXTENDED":
$if N == 0:
- const __m512i vb${ABC[N:N+4]} = _mm512_load_si512((const __m512i*) w);
+ const __m512i vb${ABC[N:N+4]} = _mm512_load_si512(w);
$else:
- const __m512i vb${ABC[N:N+4]} = _mm512_load_si512((const __m512i*) ((uintptr_t) w + ${N * 8} * sizeof(int16_t)));
+ const __m512i vb${ABC[N:N+4]} = _mm512_load_si512((const void*) ((const int16_t*) w + ${N * 8}));
$else:
- $if N == 0:
- const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
+ $if DATATYPE == "QU8":
+ $if N == 0:
+ const __m512i vb${ABC[N:N+4]} = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+ $else:
+ const __m512i vb${ABC[N:N+4]} = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const ${XINT8_T}*) w + ${N * 8}))), vb_zero_point);
$else:
- const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + ${N * 8} * sizeof(int8_t))));
+ $if N == 0:
+ const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
+ $else:
+ const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const ${XINT8_T}*) w + ${N * 8})));
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = _mm512_add_epi32(vacc${M}x${ABC[N:N+4]}, _mm512_madd_epi16(va${M}, vb${ABC[N:N+4]}));
$if VARIANT == "EXTENDED":
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int16_t));
+ w = (const void*) ((const int16_t*) w + 128);
$else:
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
- k += 8 * sizeof(int8_t);
+ w = (const void*) ((const ${XINT8_T}*) w + 128);
+ k += 8 * sizeof(${XINT8_T});
}
$for M in range(MR):
@@ -166,7 +178,7 @@
$if DATATYPE == "QC8":
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
$for M in range(MR):
vscaled${M}x084C195D2A6E3B7F = _mm512_mul_ps(vscaled${M}x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
@@ -184,22 +196,43 @@
const __m512i vacc${M}${min(M+1, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc${M}x084C195D2A6E3B7F, vacc${min(M+1, MR-1)}x084C195D2A6E3B7F), voutput_zero_point);
$if MR > 2:
- __m512i vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc2${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
+ $if DATATYPE == "QU8":
+ __m512i vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc2${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
+ $else:
+ __m512i vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc2${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
vout012${min(M+3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
__m512i vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_shuffle_epi8(vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
- vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_max_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_min);
- vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_min_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_max);
+ $if DATATYPE == "QU8":
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_max_epu8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_min);
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_min_epu8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_max);
+ $else:
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_max_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_min);
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_min_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_max);
$elif MR == 2:
- const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packs_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
+ $if DATATYPE == "QU8":
+ const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packus_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
+ $else:
+ const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packs_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
const __m256i vout01x084C2A6E195D3B7F = _mm256_permutevar8x32_epi32(vout01x084Cx2A6Ex195Dx3B7F, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0));
__m256i vout01x0123456789ABCDEF = _mm256_shuffle_epi8(vout01x084C2A6E195D3B7F, _mm256_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0, 15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
- vout01x0123456789ABCDEF = _mm256_max_epi8(vout01x0123456789ABCDEF, voutput_min);
- vout01x0123456789ABCDEF = _mm256_min_epi8(vout01x0123456789ABCDEF, voutput_max);
+ $if DATATYPE == "QU8":
+ vout01x0123456789ABCDEF = _mm256_max_epu8(vout01x0123456789ABCDEF, voutput_min);
+ vout01x0123456789ABCDEF = _mm256_min_epu8(vout01x0123456789ABCDEF, voutput_max);
+ $else:
+ vout01x0123456789ABCDEF = _mm256_max_epi8(vout01x0123456789ABCDEF, voutput_min);
+ vout01x0123456789ABCDEF = _mm256_min_epi8(vout01x0123456789ABCDEF, voutput_max);
$elif MR == 1:
- const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
+ $if DATATYPE == "QU8":
+ const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
+ $else:
+ const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
__m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
- vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
- vout0x0123456789ABCDEF = _mm_min_epi8(vout0x0123456789ABCDEF, voutput_max);
+ $if DATATYPE == "QU8":
+ vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = _mm_min_epu8(vout0x0123456789ABCDEF, voutput_max);
+ $else:
+ vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = _mm_min_epi8(vout0x0123456789ABCDEF, voutput_max);
$if MR > 2:
if (nc >= 16) {
@@ -208,10 +241,10 @@
_mm_storeu_si128((__m128i*) c${M}, _mm512_extracti32x4_epi32(vout012${min(M+3, MR-1)}x0123456789ABCDEF, ${M}));
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - k);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
nc -= 16;
} else {
@@ -231,10 +264,10 @@
_mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1));
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - kc);
+ a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - kc);
nc -= 16;
} else {
@@ -252,10 +285,10 @@
_mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
$for M in range(MR):
- a${M} = (const int8_t*) ((uintptr_t) a${M} - k);
+ a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} - k);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
nc -= 16;
} else {
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
index 5273ed5..89154ee 100644
--- a/src/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
@@ -48,10 +48,10 @@
const __m128i voutput_max = _mm_load_si128((const __m128i*) params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -61,17 +61,17 @@
const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-gemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c
index da22fae..57fc687 100644
--- a/src/qs8-gemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-gemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c
@@ -53,10 +53,10 @@
const __m128i voutput_max = _mm_load_si128((const __m128i*) params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -66,17 +66,17 @@
const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
index 528c215..8a3f860 100644
--- a/src/qs8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
@@ -54,14 +54,14 @@
const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
__m512i vacc1xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -74,20 +74,20 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-gemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c
index aecefec..4a85591 100644
--- a/src/qs8-gemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-gemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c
@@ -59,14 +59,14 @@
const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
__m512i vacc1xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -79,20 +79,20 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
index 29861c2..7b1054f 100644
--- a/src/qs8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
@@ -60,9 +60,9 @@
const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -71,7 +71,7 @@
__m512i vacc2x4567 = vacc0x4567;
__m512i vacc2x89AB = vacc0x89AB;
__m512i vacc2xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -87,23 +87,23 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-gemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
index ec7e3bd..b5b1567 100644
--- a/src/qs8-gemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-gemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
@@ -65,9 +65,9 @@
const __m512i voutput_max = _mm512_load_si512(params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -76,7 +76,7 @@
__m512i vacc2x4567 = vacc0x4567;
__m512i vacc2x89AB = vacc0x89AB;
__m512i vacc2xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -92,23 +92,23 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c b/src/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
index 9af948e..764d8cf 100644
--- a/src/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
@@ -66,9 +66,9 @@
const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -81,7 +81,7 @@
__m512i vacc3x4567 = vacc0x4567;
__m512i vacc3x89AB = vacc0x89AB;
__m512i vacc3xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -100,26 +100,26 @@
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-gemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-gemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
index af29bdd..d6ffaf4 100644
--- a/src/qs8-gemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-gemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
@@ -71,9 +71,9 @@
const __m512i voutput_max = _mm512_load_si512(params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -86,7 +86,7 @@
__m512i vacc3x4567 = vacc0x4567;
__m512i vacc3x89AB = vacc0x89AB;
__m512i vacc3xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t k = 0;
while (k < kc) {
@@ -105,26 +105,26 @@
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
diff --git a/src/qs8-igemm/MRx16c8-avx512skx.c.in b/src/qs8-igemm/MRx16c8-avx512skx.c.in
index 6237ab0..80f6c60 100644
--- a/src/qs8-igemm/MRx16c8-avx512skx.c.in
+++ b/src/qs8-igemm/MRx16c8-avx512skx.c.in
@@ -5,7 +5,7 @@
$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert REQUANTIZATION in ["GEMMLOWP", "FP32"]
-$assert DATATYPE in ["QC8", "QS8"]
+$assert DATATYPE in ["QC8", "QS8", "QU8"]
$assert DATATYPE != "QC8" or REQUANTIZATION == "FP32"
$assert VARIANT in ["LD256", "EXTENDED"]
$assert MR <= 4
@@ -20,34 +20,35 @@
$PARAMS_STRUCT = "avx512" if DATATYPE == "QC8" else REQUANTIZATION.lower() + "_avx512"
$GEMM_SUFFIX = "_xw" if VARIANT == "EXTENDED" else ""
-$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_qs8_conv_minmax_params"
+$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_%s_conv_minmax_params" % DATATYPE.lower()
+$XINT8_T = "uint8_t" if DATATYPE == "QU8" else "int8_t"
void xnn_${DATATYPE.lower()}_igemm${GEMM_SUFFIX}_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x16c8__avx512skx(
size_t mr,
size_t nc,
size_t kc,
size_t ks,
- const int8_t** restrict a,
+ const ${XINT8_T}** restrict a,
const void* restrict w,
- int8_t* restrict c,
+ ${XINT8_T}* restrict c,
size_t cm_stride,
size_t cn_stride,
size_t a_offset,
- const int8_t* zero,
+ const ${XINT8_T}* zero,
const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
{
assert(mr != 0);
assert(mr <= ${MR});
assert(nc != 0);
assert(kc != 0);
- assert(kc % sizeof(int8_t) == 0);
+ assert(kc % sizeof(${XINT8_T}) == 0);
assert(a != NULL);
assert(w != NULL);
assert(c != NULL);
kc = round_up_po2(kc, 8);
- int8_t* c0 = c;
+ ${XINT8_T}* c0 = c;
$for M in range(1, MR):
- int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
+ ${XINT8_T}* c${M} = (${XINT8_T}*) ((uintptr_t) c${M-1} + cm_stride);
$if M % 2 == 0:
if XNN_UNPREDICTABLE(mr <= ${M}) {
c${M} = c${M-1};
@@ -88,25 +89,30 @@
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
$for N in range(4, 16, 4):
- __m512i vacc0x${ABC[N:N+4]} = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + ${N} * sizeof(int32_t)));
+ __m512i vacc0x${ABC[N:N+4]} = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + ${N}));
$for M in range(1, MR):
$for N in range(0, 16, 4):
__m512i vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
$for M in range(MR):
- const int8_t* restrict a${M} = a[${M}];
+ const ${XINT8_T}* restrict a${M} = a[${M}];
if XNN_UNPREDICTABLE(a${M} != zero) {
- a${M} = (const int8_t*) ((uintptr_t) a${M} + a_offset);
+ a${M} = (const ${XINT8_T}*) ((uintptr_t) a${M} + a_offset);
}
a += ${MR};
size_t k = 0;
+ $if DATATYPE == "QU8":
+ const __m512i vb_zero_point = _mm512_load_si512(params->${PARAMS_STRUCT}.kernel_zero_point);
while (k < kc) {
$for M in range(MR):
- const __m512i va${M} = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a${M})));
+ $if DATATYPE == "QU8":
+ const __m512i va${M} = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a${M})));
+ $else:
+ const __m512i va${M} = _mm512_broadcast_i32x4(_mm_cvtepi8_epi16(_mm_loadl_epi64((const __m128i*) a${M})));
a${M} += 8;
$for N in range(0, 16, 4):
@@ -114,21 +120,27 @@
$if N == 0:
const __m512i vb${ABC[N:N+4]} = _mm512_load_si512((const __m512i*) w);
$else:
- const __m512i vb${ABC[N:N+4]} = _mm512_load_si512((const __m512i*) ((uintptr_t) w + ${N * 8} * sizeof(int16_t)));
+ const __m512i vb${ABC[N:N+4]} = _mm512_load_si512((const __m512i*) ((const int16_t*) w + ${N * 8}));
$else:
- $if N == 0:
- const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
+ $if DATATYPE == "QU8":
+ $if N == 0:
+ const __m512i vb${ABC[N:N+4]} = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+ $else:
+ const __m512i vb${ABC[N:N+4]} = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const ${XINT8_T}*) w + ${N * 8}))), vb_zero_point);
$else:
- const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + ${N * 8} * sizeof(int8_t))));
+ $if N == 0:
+ const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
+ $else:
+ const __m512i vb${ABC[N:N+4]} = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const ${XINT8_T}*) w + ${N * 8})));
$for M in range(MR):
vacc${M}x${ABC[N:N+4]} = _mm512_add_epi32(vacc${M}x${ABC[N:N+4]}, _mm512_madd_epi16(va${M}, vb${ABC[N:N+4]}));
$if VARIANT == "EXTENDED":
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int16_t));
+ w = (const void*) ((const int16_t*) w + 128);
$else:
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
- k += 8 * sizeof(int8_t);
+ w = (const void*) ((const ${XINT8_T}*) w + 128);
+ k += 8 * sizeof(${XINT8_T});
}
p -= ${MR} * sizeof(void*);
} while (p != 0);
@@ -174,7 +186,7 @@
$if DATATYPE == "QC8":
const __m512 vscale012345678ABCDEF = _mm512_load_ps(w);
- w = (const void*) ((uintptr_t) w + 16 * sizeof(float));
+ w = (const void*) ((const float*) w + 16);
const __m512 vscale084C195D2A6E3B7F = _mm512_permutexvar_ps(_mm512_set_epi32(15, 7, 11, 3, 14, 6, 10, 2, 13, 5, 9, 1, 12, 4, 8, 0), vscale012345678ABCDEF);
$for M in range(MR):
vscaled${M}x084C195D2A6E3B7F = _mm512_mul_ps(vscaled${M}x084C195D2A6E3B7F, vscale084C195D2A6E3B7F);
@@ -192,22 +204,43 @@
const __m512i vacc${M}${min(M+1, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc${M}x084C195D2A6E3B7F, vacc${min(M+1, MR-1)}x084C195D2A6E3B7F), voutput_zero_point);
$if MR > 2:
- __m512i vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc2${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
+ $if DATATYPE == "QU8":
+ __m512i vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc2${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
+ $else:
+ __m512i vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_packs_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc2${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
vout012${min(M+3, MR-1)}x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F);
__m512i vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_shuffle_epi8(vout012${min(3, MR-1)}x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
- vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_max_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_min);
- vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_min_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_max);
+ $if DATATYPE == "QU8":
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_max_epu8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_min);
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_min_epu8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_max);
+ $else:
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_max_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_min);
+ vout012${min(3, MR-1)}x0123456789ABCDEF = _mm512_min_epi8(vout012${min(3, MR-1)}x0123456789ABCDEF, voutput_max);
$elif MR == 2:
- const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packs_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
+ $if DATATYPE == "QU8":
+ const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packus_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
+ $else:
+ const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packs_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
const __m256i vout01x084C2A6E195D3B7F = _mm256_permutevar8x32_epi32(vout01x084Cx2A6Ex195Dx3B7F, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0));
__m256i vout01x0123456789ABCDEF = _mm256_shuffle_epi8(vout01x084C2A6E195D3B7F, _mm256_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0, 15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
- vout01x0123456789ABCDEF = _mm256_max_epi8(vout01x0123456789ABCDEF, voutput_min);
- vout01x0123456789ABCDEF = _mm256_min_epi8(vout01x0123456789ABCDEF, voutput_max);
+ $if DATATYPE == "QU8":
+ vout01x0123456789ABCDEF = _mm256_max_epu8(vout01x0123456789ABCDEF, voutput_min);
+ vout01x0123456789ABCDEF = _mm256_min_epu8(vout01x0123456789ABCDEF, voutput_max);
+ $else:
+ vout01x0123456789ABCDEF = _mm256_max_epi8(vout01x0123456789ABCDEF, voutput_min);
+ vout01x0123456789ABCDEF = _mm256_min_epi8(vout01x0123456789ABCDEF, voutput_max);
$elif MR == 1:
- const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
+ $if DATATYPE == "QU8":
+ const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
+ $else:
+ const __m128i vout0x084C2A6E195D3B7F = _mm_packs_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
__m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
- vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
- vout0x0123456789ABCDEF = _mm_min_epi8(vout0x0123456789ABCDEF, voutput_max);
+ $if DATATYPE == "QU8":
+ vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = _mm_min_epu8(vout0x0123456789ABCDEF, voutput_max);
+ $else:
+ vout0x0123456789ABCDEF = _mm_max_epi8(vout0x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = _mm_min_epi8(vout0x0123456789ABCDEF, voutput_max);
$if MR > 2:
if (nc >= 16) {
@@ -216,9 +249,9 @@
_mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout012${min(M+3, MR-1)}x0123456789ABCDEF));
$for M in reversed(range(MR)):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
- a = (const int8_t**restrict) ((uintptr_t) a - ks);
+ a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);
nc -= 16;
} else {
@@ -238,9 +271,9 @@
_mm_storeu_si128((__m128i*) c0, _mm256_castsi256_si128(vout01x0123456789ABCDEF));
$for M in reversed(range(MR)):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
- a = (const int8_t**restrict) ((uintptr_t) a - ks);
+ a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);
nc -= 16;
} else {
@@ -258,9 +291,9 @@
_mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
$for M in range(MR):
- c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);
+ c${M} = (${XINT8_T}*) ((uintptr_t) c${M} + cn_stride);
- a = (const int8_t**restrict) ((uintptr_t) a - ks);
+ a = (const ${XINT8_T}**restrict) ((uintptr_t) a - ks);
nc -= 16;
} else {
diff --git a/src/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
index 5005b48..755a1ba 100644
--- a/src/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
@@ -49,10 +49,10 @@
const __m128i voutput_max = _mm_load_si128((const __m128i*) params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -70,17 +70,17 @@
const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 1 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-igemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c
index d13cbea..08f5db0 100644
--- a/src/qs8-igemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-igemm/gen/1x16c8-minmax-gemmlowp-avx512skx.c
@@ -54,10 +54,10 @@
const __m128i voutput_max = _mm_load_si128((const __m128i*) params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -75,17 +75,17 @@
const __m512i vb0123 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) w));
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 1 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
index 270a685..a60a077 100644
--- a/src/qs8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
@@ -53,14 +53,14 @@
const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
__m512i vacc1xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -85,20 +85,20 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 2 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-igemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c
index 5156446..6d9a99a 100644
--- a/src/qs8-igemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-igemm/gen/2x16c8-minmax-gemmlowp-avx512skx.c
@@ -58,14 +58,14 @@
const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
__m512i vacc1xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -90,20 +90,20 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 2 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
index 92f4894..443da58 100644
--- a/src/qs8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
@@ -57,9 +57,9 @@
const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -68,7 +68,7 @@
__m512i vacc2x4567 = vacc0x4567;
__m512i vacc2x89AB = vacc0x89AB;
__m512i vacc2xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -100,23 +100,23 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 3 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
index c0aa798..1aaea12 100644
--- a/src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-igemm/gen/3x16c8-minmax-gemmlowp-avx512skx.c
@@ -62,9 +62,9 @@
const __m512i voutput_max = _mm512_load_si512(params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -73,7 +73,7 @@
__m512i vacc2x4567 = vacc0x4567;
__m512i vacc2x89AB = vacc0x89AB;
__m512i vacc2xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -105,23 +105,23 @@
vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 3 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c b/src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
index 7c63429..05a94b1 100644
--- a/src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
+++ b/src/qs8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
@@ -61,9 +61,9 @@
const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -76,7 +76,7 @@
__m512i vacc3x4567 = vacc0x4567;
__m512i vacc3x89AB = vacc0x89AB;
__m512i vacc3xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -115,26 +115,26 @@
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 4 * sizeof(void*);
diff --git a/src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c b/src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
index 8452156..6ee2e1f 100644
--- a/src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
+++ b/src/qs8-igemm/gen/4x16c8-minmax-gemmlowp-avx512skx.c
@@ -66,9 +66,9 @@
const __m512i voutput_max = _mm512_load_si512(params->gemmlowp_avx512.output_max);
do {
__m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
- __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 4 * sizeof(int32_t)));
- __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 8 * sizeof(int32_t)));
- __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((uintptr_t) w + 12 * sizeof(int32_t)));
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
__m512i vacc1x0123 = vacc0x0123;
__m512i vacc1x4567 = vacc0x4567;
__m512i vacc1x89AB = vacc0x89AB;
@@ -81,7 +81,7 @@
__m512i vacc3x4567 = vacc0x4567;
__m512i vacc3x89AB = vacc0x89AB;
__m512i vacc3xCDEF = vacc0xCDEF;
- w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
+ w = (const void*) ((const int32_t*) w + 16);
size_t p = ks;
do {
@@ -120,26 +120,26 @@
vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
- const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 32 * sizeof(int8_t))));
+ const __m512i vb4567 = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 32)));
vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
- const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 64 * sizeof(int8_t))));
+ const __m512i vb89AB = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 64)));
vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
- const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((uintptr_t) w + 96 * sizeof(int8_t))));
+ const __m512i vbCDEF = _mm512_cvtepi8_epi16(_mm256_load_si256((const __m256i*) ((const int8_t*) w + 96)));
vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
- w = (const void*) ((uintptr_t) w + 128 * sizeof(int8_t));
+ w = (const void*) ((const int8_t*) w + 128);
k += 8 * sizeof(int8_t);
}
p -= 4 * sizeof(void*);
diff --git a/src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..cf5f497
--- /dev/null
+++ b/src/qu8-gemm/gen/1x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,114 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 1);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
+ const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
+ const __m128i voutput_max = _mm_load_si128((const __m128i*) params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+
+ const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
+
+ const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
+ __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
+ vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = _mm_min_epu8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - k);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
+
+ _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..32912e1
--- /dev/null
+++ b/src/qu8-gemm/gen/2x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,142 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
+ const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
+ const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ __m512i vacc1x0123 = vacc0x0123;
+ __m512i vacc1x4567 = vacc0x4567;
+ __m512i vacc1x89AB = vacc0x89AB;
+ __m512i vacc1xCDEF = vacc0xCDEF;
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+ const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
+ a1 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+ vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+ const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
+ const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+ __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+ __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+ vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+ vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
+
+ const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
+
+ const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packus_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
+ const __m256i vout01x084C2A6E195D3B7F = _mm256_permutevar8x32_epi32(vout01x084Cx2A6Ex195Dx3B7F, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0));
+ __m256i vout01x0123456789ABCDEF = _mm256_shuffle_epi8(vout01x084C2A6E195D3B7F, _mm256_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0, 15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
+ vout01x0123456789ABCDEF = _mm256_max_epu8(vout01x0123456789ABCDEF, voutput_min);
+ vout01x0123456789ABCDEF = _mm256_min_epu8(vout01x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c0, _mm256_castsi256_si128(vout01x0123456789ABCDEF));
+ _mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1));
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
+
+ _mm256_mask_storeu_epi8(c0, vmask, vout01x0123456789ABCDEF);
+ vmask = _kshiftli_mask64(vmask, 16);
+ _mm256_mask_storeu_epi8(c1 - 16, vmask, vout01x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..306fee6
--- /dev/null
+++ b/src/qu8-gemm/gen/3x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,170 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 3);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
+ const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
+ const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ __m512i vacc1x0123 = vacc0x0123;
+ __m512i vacc1x4567 = vacc0x4567;
+ __m512i vacc1x89AB = vacc0x89AB;
+ __m512i vacc1xCDEF = vacc0xCDEF;
+ __m512i vacc2x0123 = vacc0x0123;
+ __m512i vacc2x4567 = vacc0x4567;
+ __m512i vacc2x89AB = vacc0x89AB;
+ __m512i vacc2xCDEF = vacc0xCDEF;
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+ const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
+ a1 += 8;
+ const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
+ a2 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
+ vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
+ vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
+ vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+ vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
+ vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+ const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
+ const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
+ const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
+ const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+ __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
+ __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+ __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
+ __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+ vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
+ vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+ vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
+ vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
+
+ const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
+ const __m512i vacc22x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc2x084C195D2A6E3B7F), voutput_zero_point);
+
+ __m512i vout0122x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc22x084Cx195Dx2A6Ex3B7F);
+ vout0122x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0122x084Cx195Dx2A6Ex3B7F);
+ __m512i vout0122x0123456789ABCDEF = _mm512_shuffle_epi8(vout0122x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
+ vout0122x0123456789ABCDEF = _mm512_max_epu8(vout0122x0123456789ABCDEF, voutput_min);
+ vout0122x0123456789ABCDEF = _mm512_min_epu8(vout0122x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0122x0123456789ABCDEF));
+ _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0122x0123456789ABCDEF, 1));
+ _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0122x0123456789ABCDEF, 2));
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - k);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - k);
+ a2 = (const uint8_t*) ((uintptr_t) a2 - k);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
+
+ _mm512_mask_storeu_epi8(c0, vmask, vout0122x0123456789ABCDEF);
+ vmask = _kshiftli_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0122x0123456789ABCDEF);
+ vmask = _kshiftli_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0122x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c b/src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..1a8c017
--- /dev/null
+++ b/src/qu8-gemm/gen/4x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,197 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-gemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/gemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ const uint8_t* restrict a,
+ size_t a_stride,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ const uint8_t* a0 = a;
+ uint8_t* c0 = c;
+ const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ a1 = a0;
+ c1 = c0;
+ }
+ const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ a2 = a1;
+ c2 = c1;
+ }
+ const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
+ uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ a3 = a2;
+ c3 = c2;
+ }
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
+ const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
+ const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ __m512i vacc1x0123 = vacc0x0123;
+ __m512i vacc1x4567 = vacc0x4567;
+ __m512i vacc1x89AB = vacc0x89AB;
+ __m512i vacc1xCDEF = vacc0xCDEF;
+ __m512i vacc2x0123 = vacc0x0123;
+ __m512i vacc2x4567 = vacc0x4567;
+ __m512i vacc2x89AB = vacc0x89AB;
+ __m512i vacc2xCDEF = vacc0xCDEF;
+ __m512i vacc3x0123 = vacc0x0123;
+ __m512i vacc3x4567 = vacc0x4567;
+ __m512i vacc3x89AB = vacc0x89AB;
+ __m512i vacc3xCDEF = vacc0xCDEF;
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+ const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
+ a1 += 8;
+ const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
+ a2 += 8;
+ const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
+ a3 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
+ vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
+ vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
+ vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
+ vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
+ vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
+ vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+ vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
+ vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
+ vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+ const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
+ const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
+ const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
+ const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
+ const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
+ const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+ __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
+ __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
+ __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+ __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
+ __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
+ __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+ vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
+ vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
+ vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+ vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
+ vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
+ vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
+
+ const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
+ const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
+
+ __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
+ vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
+ __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
+ vout0123x0123456789ABCDEF = _mm512_max_epu8(vout0123x0123456789ABCDEF, voutput_min);
+ vout0123x0123456789ABCDEF = _mm512_min_epu8(vout0123x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
+ _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
+ _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
+ _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
+
+ a0 = (const uint8_t*) ((uintptr_t) a0 - k);
+ a1 = (const uint8_t*) ((uintptr_t) a1 - k);
+ a2 = (const uint8_t*) ((uintptr_t) a2 - k);
+ a3 = (const uint8_t*) ((uintptr_t) a3 - k);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+ c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
+
+ _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
+ vmask = _kshiftli_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
+ vmask = _kshiftli_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
+ vmask = _kshiftli_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..38fbbbc
--- /dev/null
+++ b/src/qu8-igemm/gen/1x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,125 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const uint8_t** restrict a,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const uint8_t* zero,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 1);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ uint8_t* c0 = c;
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_zero_point);
+ const __m128i voutput_min = _mm_load_si128((const __m128i*) params->fp32_avx512.output_min);
+ const __m128i voutput_max = _mm_load_si128((const __m128i*) params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t p = ks;
+ do {
+ const uint8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ a += 1;
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+ p -= 1 * sizeof(void*);
+ } while (p != 0);
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+
+ const __m256i vacc0x084C2A6E195D3B7F = _mm256_adds_epi16(_mm256_packs_epi32(_mm512_castsi512_si256(vacc0x084C195D2A6E3B7F), _mm512_extracti32x8_epi32(vacc0x084C195D2A6E3B7F, 1)), voutput_zero_point);
+
+ const __m128i vout0x084C2A6E195D3B7F = _mm_packus_epi16(_mm256_castsi256_si128(vacc0x084C2A6E195D3B7F), _mm256_extracti128_si256(vacc0x084C2A6E195D3B7F, 1));
+ __m128i vout0x0123456789ABCDEF = _mm_shuffle_epi8(vout0x084C2A6E195D3B7F, _mm_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
+ vout0x0123456789ABCDEF = _mm_max_epu8(vout0x0123456789ABCDEF, voutput_min);
+ vout0x0123456789ABCDEF = _mm_min_epu8(vout0x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c0, vout0x0123456789ABCDEF);
+
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ const __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << nc) - UINT32_C(1)));
+
+ _mm_mask_storeu_epi8(c0, vmask, vout0x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..e3321bb
--- /dev/null
+++ b/src/qu8-igemm/gen/2x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,154 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const uint8_t** restrict a,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const uint8_t* zero,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 2);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ uint8_t* c0 = c;
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 2) {
+ c1 = c0;
+ }
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
+ const __m256i voutput_min = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_min);
+ const __m256i voutput_max = _mm256_load_si256((const __m256i*) params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ __m512i vacc1x0123 = vacc0x0123;
+ __m512i vacc1x4567 = vacc0x4567;
+ __m512i vacc1x89AB = vacc0x89AB;
+ __m512i vacc1xCDEF = vacc0xCDEF;
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t p = ks;
+ do {
+ const uint8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const uint8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ a += 2;
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+ const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
+ a1 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+ vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+ p -= 2 * sizeof(void*);
+ } while (p != 0);
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+ const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
+ const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+ __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+ __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+ vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+ vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
+
+ const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
+
+ const __m256i vout01x084Cx2A6Ex195Dx3B7F = _mm256_packus_epi16(_mm512_castsi512_si256(vacc01x084Cx195Dx2A6Ex3B7F), _mm512_extracti32x8_epi32(vacc01x084Cx195Dx2A6Ex3B7F, 1));
+ const __m256i vout01x084C2A6E195D3B7F = _mm256_permutevar8x32_epi32(vout01x084Cx2A6Ex195Dx3B7F, _mm256_set_epi32(7, 5, 3, 1, 6, 4, 2, 0));
+ __m256i vout01x0123456789ABCDEF = _mm256_shuffle_epi8(vout01x084C2A6E195D3B7F, _mm256_set_epi8(15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0, 15, 7, 11, 3, 13, 5, 9, 1, 14, 6, 10, 2, 12, 4, 8, 0));
+ vout01x0123456789ABCDEF = _mm256_max_epu8(vout01x0123456789ABCDEF, voutput_min);
+ vout01x0123456789ABCDEF = _mm256_min_epu8(vout01x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c1, _mm256_extracti128_si256(vout01x0123456789ABCDEF, 1));
+ _mm_storeu_si128((__m128i*) c0, _mm256_castsi256_si128(vout01x0123456789ABCDEF));
+
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT32_C(1) << (nc + 16)) - (UINT32_C(1) << 16)));
+
+ _mm256_mask_storeu_epi8(c1 - 16, vmask, vout01x0123456789ABCDEF);
+ vmask = _kshiftri_mask64(vmask, 16);
+ _mm256_mask_storeu_epi8(c0, vmask, vout01x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..d1fa2c3
--- /dev/null
+++ b/src/qu8-igemm/gen/3x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,183 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const uint8_t** restrict a,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const uint8_t* zero,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 3);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ uint8_t* c0 = c;
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
+ const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
+ const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ __m512i vacc1x0123 = vacc0x0123;
+ __m512i vacc1x4567 = vacc0x4567;
+ __m512i vacc1x89AB = vacc0x89AB;
+ __m512i vacc1xCDEF = vacc0xCDEF;
+ __m512i vacc2x0123 = vacc0x0123;
+ __m512i vacc2x4567 = vacc0x4567;
+ __m512i vacc2x89AB = vacc0x89AB;
+ __m512i vacc2xCDEF = vacc0xCDEF;
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t p = ks;
+ do {
+ const uint8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const uint8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const uint8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ a += 3;
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+ const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
+ a1 += 8;
+ const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
+ a2 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
+ vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
+ vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
+ vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+ vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
+ vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+ p -= 3 * sizeof(void*);
+ } while (p != 0);
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+ const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
+ const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
+ const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
+ const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+ __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
+ __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+ __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
+ __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+ vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
+ vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+ vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
+ vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
+
+ const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
+ const __m512i vacc22x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc2x084C195D2A6E3B7F), voutput_zero_point);
+
+ __m512i vout0122x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc22x084Cx195Dx2A6Ex3B7F);
+ vout0122x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0122x084Cx195Dx2A6Ex3B7F);
+ __m512i vout0122x0123456789ABCDEF = _mm512_shuffle_epi8(vout0122x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
+ vout0122x0123456789ABCDEF = _mm512_max_epu8(vout0122x0123456789ABCDEF, voutput_min);
+ vout0122x0123456789ABCDEF = _mm512_min_epu8(vout0122x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0122x0123456789ABCDEF, 2));
+ _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0122x0123456789ABCDEF, 1));
+ _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0122x0123456789ABCDEF));
+
+ c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 32)) - (UINT64_C(1) << 32)));
+
+ _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0122x0123456789ABCDEF);
+ vmask = _kshiftri_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0122x0123456789ABCDEF);
+ vmask = _kshiftri_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c0, vmask, vout0122x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c b/src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
new file mode 100644
index 0000000..e076fef
--- /dev/null
+++ b/src/qu8-igemm/gen/4x16c8-minmax-fp32-avx512skx.c
@@ -0,0 +1,211 @@
+// Auto-generated file. Do not edit!
+// Template: src/qs8-igemm/MRx16c8-avx512skx.c.in
+// Generator: tools/xngen
+//
+// Copyright 2020 Google LLC
+//
+// This source code is licensed under the BSD-style license found in the
+// LICENSE file in the root directory of this source tree.
+
+#include <assert.h>
+
+#include <immintrin.h>
+
+#include <xnnpack/igemm.h>
+#include <xnnpack/intrinsics-polyfill.h>
+#include <xnnpack/math.h>
+
+
+void xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx(
+ size_t mr,
+ size_t nc,
+ size_t kc,
+ size_t ks,
+ const uint8_t** restrict a,
+ const void* restrict w,
+ uint8_t* restrict c,
+ size_t cm_stride,
+ size_t cn_stride,
+ size_t a_offset,
+ const uint8_t* zero,
+ const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
+{
+ assert(mr != 0);
+ assert(mr <= 4);
+ assert(nc != 0);
+ assert(kc != 0);
+ assert(kc % sizeof(uint8_t) == 0);
+ assert(a != NULL);
+ assert(w != NULL);
+ assert(c != NULL);
+
+ kc = round_up_po2(kc, 8);
+ uint8_t* c0 = c;
+ uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
+ if XNN_UNPREDICTABLE(mr < 2) {
+ c1 = c0;
+ }
+ uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
+ if XNN_UNPREDICTABLE(mr <= 2) {
+ c2 = c1;
+ }
+ uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
+ if XNN_UNPREDICTABLE(mr != 4) {
+ c3 = c2;
+ }
+
+ const __mmask16 vbias_mask = _cvtu32_mask16(0x1111);
+ const __m512 vscale = _mm512_load_ps(params->fp32_avx512.scale);
+ const __m512i voutput_zero_point = _mm512_load_si512(params->fp32_avx512.output_zero_point);
+ const __m512i voutput_min = _mm512_load_si512(params->fp32_avx512.output_min);
+ const __m512i voutput_max = _mm512_load_si512(params->fp32_avx512.output_max);
+ do {
+ __m512i vacc0x0123 = _mm512_maskz_expandloadu_epi32(vbias_mask, w);
+ __m512i vacc0x4567 = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 4));
+ __m512i vacc0x89AB = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 8));
+ __m512i vacc0xCDEF = _mm512_maskz_expandloadu_epi32(vbias_mask, (const void*) ((const int32_t*) w + 12));
+ __m512i vacc1x0123 = vacc0x0123;
+ __m512i vacc1x4567 = vacc0x4567;
+ __m512i vacc1x89AB = vacc0x89AB;
+ __m512i vacc1xCDEF = vacc0xCDEF;
+ __m512i vacc2x0123 = vacc0x0123;
+ __m512i vacc2x4567 = vacc0x4567;
+ __m512i vacc2x89AB = vacc0x89AB;
+ __m512i vacc2xCDEF = vacc0xCDEF;
+ __m512i vacc3x0123 = vacc0x0123;
+ __m512i vacc3x4567 = vacc0x4567;
+ __m512i vacc3x89AB = vacc0x89AB;
+ __m512i vacc3xCDEF = vacc0xCDEF;
+ w = (const void*) ((const int32_t*) w + 16);
+
+ size_t p = ks;
+ do {
+ const uint8_t* restrict a0 = a[0];
+ if XNN_UNPREDICTABLE(a0 != zero) {
+ a0 = (const uint8_t*) ((uintptr_t) a0 + a_offset);
+ }
+ const uint8_t* restrict a1 = a[1];
+ if XNN_UNPREDICTABLE(a1 != zero) {
+ a1 = (const uint8_t*) ((uintptr_t) a1 + a_offset);
+ }
+ const uint8_t* restrict a2 = a[2];
+ if XNN_UNPREDICTABLE(a2 != zero) {
+ a2 = (const uint8_t*) ((uintptr_t) a2 + a_offset);
+ }
+ const uint8_t* restrict a3 = a[3];
+ if XNN_UNPREDICTABLE(a3 != zero) {
+ a3 = (const uint8_t*) ((uintptr_t) a3 + a_offset);
+ }
+ a += 4;
+
+ size_t k = 0;
+ const __m512i vb_zero_point = _mm512_load_si512(params->fp32_avx512.kernel_zero_point);
+ while (k < kc) {
+ const __m512i va0 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a0)));
+ a0 += 8;
+ const __m512i va1 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a1)));
+ a1 += 8;
+ const __m512i va2 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a2)));
+ a2 += 8;
+ const __m512i va3 = _mm512_broadcast_i32x4(_mm_cvtepu8_epi16(_mm_loadl_epi64((const __m128i*) a3)));
+ a3 += 8;
+
+ const __m512i vb0123 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) w)), vb_zero_point);
+
+ vacc0x0123 = _mm512_add_epi32(vacc0x0123, _mm512_madd_epi16(va0, vb0123));
+ vacc1x0123 = _mm512_add_epi32(vacc1x0123, _mm512_madd_epi16(va1, vb0123));
+ vacc2x0123 = _mm512_add_epi32(vacc2x0123, _mm512_madd_epi16(va2, vb0123));
+ vacc3x0123 = _mm512_add_epi32(vacc3x0123, _mm512_madd_epi16(va3, vb0123));
+ const __m512i vb4567 = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 32))), vb_zero_point);
+
+ vacc0x4567 = _mm512_add_epi32(vacc0x4567, _mm512_madd_epi16(va0, vb4567));
+ vacc1x4567 = _mm512_add_epi32(vacc1x4567, _mm512_madd_epi16(va1, vb4567));
+ vacc2x4567 = _mm512_add_epi32(vacc2x4567, _mm512_madd_epi16(va2, vb4567));
+ vacc3x4567 = _mm512_add_epi32(vacc3x4567, _mm512_madd_epi16(va3, vb4567));
+ const __m512i vb89AB = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 64))), vb_zero_point);
+
+ vacc0x89AB = _mm512_add_epi32(vacc0x89AB, _mm512_madd_epi16(va0, vb89AB));
+ vacc1x89AB = _mm512_add_epi32(vacc1x89AB, _mm512_madd_epi16(va1, vb89AB));
+ vacc2x89AB = _mm512_add_epi32(vacc2x89AB, _mm512_madd_epi16(va2, vb89AB));
+ vacc3x89AB = _mm512_add_epi32(vacc3x89AB, _mm512_madd_epi16(va3, vb89AB));
+ const __m512i vbCDEF = _mm512_sub_epi16(_mm512_cvtepu8_epi16(_mm256_load_si256((const __m256i*) ((const uint8_t*) w + 96))), vb_zero_point);
+
+ vacc0xCDEF = _mm512_add_epi32(vacc0xCDEF, _mm512_madd_epi16(va0, vbCDEF));
+ vacc1xCDEF = _mm512_add_epi32(vacc1xCDEF, _mm512_madd_epi16(va1, vbCDEF));
+ vacc2xCDEF = _mm512_add_epi32(vacc2xCDEF, _mm512_madd_epi16(va2, vbCDEF));
+ vacc3xCDEF = _mm512_add_epi32(vacc3xCDEF, _mm512_madd_epi16(va3, vbCDEF));
+
+ w = (const void*) ((const uint8_t*) w + 128);
+ k += 8 * sizeof(uint8_t);
+ }
+ p -= 4 * sizeof(void*);
+ } while (p != 0);
+
+ const __m512i vacc0x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x0123, vacc0x4567), _mm512_unpackhi_epi32(vacc0x0123, vacc0x4567));
+ const __m512i vacc0x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x89AB, vacc0xCDEF), _mm512_unpackhi_epi32(vacc0x89AB, vacc0xCDEF));
+ const __m512i vacc1x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x0123, vacc1x4567), _mm512_unpackhi_epi32(vacc1x0123, vacc1x4567));
+ const __m512i vacc1x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x89AB, vacc1xCDEF), _mm512_unpackhi_epi32(vacc1x89AB, vacc1xCDEF));
+ const __m512i vacc2x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x0123, vacc2x4567), _mm512_unpackhi_epi32(vacc2x0123, vacc2x4567));
+ const __m512i vacc2x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x89AB, vacc2xCDEF), _mm512_unpackhi_epi32(vacc2x89AB, vacc2xCDEF));
+ const __m512i vacc3x04152637 = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x0123, vacc3x4567), _mm512_unpackhi_epi32(vacc3x0123, vacc3x4567));
+ const __m512i vacc3x8C9DAEBF = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x89AB, vacc3xCDEF), _mm512_unpackhi_epi32(vacc3x89AB, vacc3xCDEF));
+
+ __m512i vacc0x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc0x04152637, vacc0x8C9DAEBF), _mm512_unpackhi_epi32(vacc0x04152637, vacc0x8C9DAEBF));
+ __m512i vacc1x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc1x04152637, vacc1x8C9DAEBF), _mm512_unpackhi_epi32(vacc1x04152637, vacc1x8C9DAEBF));
+ __m512i vacc2x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc2x04152637, vacc2x8C9DAEBF), _mm512_unpackhi_epi32(vacc2x04152637, vacc2x8C9DAEBF));
+ __m512i vacc3x084C195D2A6E3B7F = _mm512_add_epi32(_mm512_unpacklo_epi32(vacc3x04152637, vacc3x8C9DAEBF), _mm512_unpackhi_epi32(vacc3x04152637, vacc3x8C9DAEBF));
+
+ __m512 vscaled0x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc0x084C195D2A6E3B7F);
+ __m512 vscaled1x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc1x084C195D2A6E3B7F);
+ __m512 vscaled2x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc2x084C195D2A6E3B7F);
+ __m512 vscaled3x084C195D2A6E3B7F = _mm512_cvtepi32_ps(vacc3x084C195D2A6E3B7F);
+
+ vscaled0x084C195D2A6E3B7F = _mm512_mul_ps(vscaled0x084C195D2A6E3B7F, vscale);
+ vscaled1x084C195D2A6E3B7F = _mm512_mul_ps(vscaled1x084C195D2A6E3B7F, vscale);
+ vscaled2x084C195D2A6E3B7F = _mm512_mul_ps(vscaled2x084C195D2A6E3B7F, vscale);
+ vscaled3x084C195D2A6E3B7F = _mm512_mul_ps(vscaled3x084C195D2A6E3B7F, vscale);
+
+ vacc0x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled0x084C195D2A6E3B7F);
+ vacc1x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled1x084C195D2A6E3B7F);
+ vacc2x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled2x084C195D2A6E3B7F);
+ vacc3x084C195D2A6E3B7F = _mm512_cvtps_epi32(vscaled3x084C195D2A6E3B7F);
+
+ const __m512i vacc01x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc0x084C195D2A6E3B7F, vacc1x084C195D2A6E3B7F), voutput_zero_point);
+ const __m512i vacc23x084Cx195Dx2A6Ex3B7F = _mm512_adds_epi16(_mm512_packs_epi32(vacc2x084C195D2A6E3B7F, vacc3x084C195D2A6E3B7F), voutput_zero_point);
+
+ __m512i vout0123x084Cx195Dx2A6Ex3B7F = _mm512_packus_epi16(vacc01x084Cx195Dx2A6Ex3B7F, vacc23x084Cx195Dx2A6Ex3B7F);
+ vout0123x084Cx195Dx2A6Ex3B7F = _mm512_permutexvar_epi32(_mm512_set_epi32(15, 11, 7, 3, 14, 10, 6, 2, 13, 9, 5, 1, 12, 8, 4, 0), vout0123x084Cx195Dx2A6Ex3B7F);
+ __m512i vout0123x0123456789ABCDEF = _mm512_shuffle_epi8(vout0123x084Cx195Dx2A6Ex3B7F, _mm512_set_epi8(15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0, 15, 11, 7, 3, 13, 9, 5, 1, 14, 10, 6, 2, 12, 8, 4, 0));
+ vout0123x0123456789ABCDEF = _mm512_max_epu8(vout0123x0123456789ABCDEF, voutput_min);
+ vout0123x0123456789ABCDEF = _mm512_min_epu8(vout0123x0123456789ABCDEF, voutput_max);
+
+ if (nc >= 16) {
+ _mm_storeu_si128((__m128i*) c3, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 3));
+ _mm_storeu_si128((__m128i*) c2, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 2));
+ _mm_storeu_si128((__m128i*) c1, _mm512_extracti32x4_epi32(vout0123x0123456789ABCDEF, 1));
+ _mm_storeu_si128((__m128i*) c0, _mm512_castsi512_si128(vout0123x0123456789ABCDEF));
+
+ c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
+ c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
+ c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
+ c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
+
+ a = (const uint8_t**restrict) ((uintptr_t) a - ks);
+
+ nc -= 16;
+ } else {
+ // Prepare mask for valid 8-bit elements (depends on nc).
+ __mmask64 vmask = _cvtu64_mask64((uint64_t) ((UINT64_C(1) << (nc + 48)) - (UINT64_C(1) << 48)));
+
+ _mm512_mask_storeu_epi8(c3 - 48, vmask, vout0123x0123456789ABCDEF);
+ vmask = _kshiftri_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c2 - 32, vmask, vout0123x0123456789ABCDEF);
+ vmask = _kshiftri_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c1 - 16, vmask, vout0123x0123456789ABCDEF);
+ vmask = _kshiftri_mask64(vmask, 16);
+ _mm512_mask_storeu_epi8(c0, vmask, vout0123x0123456789ABCDEF);
+
+ nc = 0;
+ }
+ } while (nc != 0);
+}
diff --git a/src/xnnpack/gemm.h b/src/xnnpack/gemm.h
index f46aee6..77136e8 100644
--- a/src/xnnpack/gemm.h
+++ b/src/xnnpack/gemm.h
@@ -586,6 +586,11 @@
DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_fp32_ukernel_2x8c8__avx2)
DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_fp32_ukernel_3x8c8__avx2)
+DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx)
+DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx)
+DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx)
+DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx)
+
DECLARE_QU8_GEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_gemm_minmax_gemmlowp_ukernel_2x2__scalar)
diff --git a/src/xnnpack/igemm.h b/src/xnnpack/igemm.h
index 39afb40..f4483e1 100644
--- a/src/xnnpack/igemm.h
+++ b/src/xnnpack/igemm.h
@@ -395,6 +395,11 @@
DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_fp32_ukernel_2x8c8__avx2)
DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_fp32_ukernel_3x8c8__avx2)
+DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx)
+DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx)
+DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx)
+DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx)
+
DECLARE_QU8_IGEMM_MINMAX_UKERNEL_FUNCTION(xnn_qu8_igemm_minmax_gemmlowp_ukernel_2x2__scalar)
diff --git a/src/xnnpack/params-init.h b/src/xnnpack/params-init.h
index 3f87fef..f61b55a 100644
--- a/src/xnnpack/params-init.h
+++ b/src/xnnpack/params-init.h
@@ -49,6 +49,14 @@
uint8_t output_zero_point,
uint8_t output_min,
uint8_t output_max);
+
+XNN_INTERNAL void xnn_init_qu8_conv_minmax_fp32_avx512_params(
+ union xnn_qu8_conv_minmax_params params[XNN_MIN_ELEMENTS(1)],
+ uint8_t kernel_zero_point,
+ float scale,
+ uint8_t output_zero_point,
+ uint8_t output_min,
+ uint8_t output_max);
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
#if XNN_ARCH_ARM || XNN_ARCH_ARM64
diff --git a/src/xnnpack/params.h b/src/xnnpack/params.h
index 14860da..e7c3fd9 100644
--- a/src/xnnpack/params.h
+++ b/src/xnnpack/params.h
@@ -288,6 +288,13 @@
XNN_ALIGN(32) uint8_t output_min[32];
XNN_ALIGN(32) uint8_t output_max[32];
} fp32_avx2;
+ struct {
+ XNN_ALIGN(64) int16_t kernel_zero_point[32];
+ XNN_ALIGN(64) float scale[16];
+ XNN_ALIGN(64) int16_t output_zero_point[32];
+ XNN_ALIGN(64) uint8_t output_min[64];
+ XNN_ALIGN(64) uint8_t output_max[64];
+ } fp32_avx512;
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
};
diff --git a/test/qu8-gemm-minmax-fp32.cc b/test/qu8-gemm-minmax-fp32.cc
index 9507af1..35eca5e 100644
--- a/test/qu8-gemm-minmax-fp32.cc
+++ b/test/qu8-gemm-minmax-fp32.cc
@@ -29815,3 +29815,2023 @@
}
}
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_lt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_gt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_div_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_1X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_lt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_gt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_div_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_2X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_lt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_gt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_div_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_3X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_lt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_stride(11)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_gt_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_div_8_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_stride(83)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16_strided_a) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .a_stride(43)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_GEMM_MINMAX_FP32_4X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
diff --git a/test/qu8-gemm-minmax-fp32.yaml b/test/qu8-gemm-minmax-fp32.yaml
index 64a9400..ad6bd1f 100644
--- a/test/qu8-gemm-minmax-fp32.yaml
+++ b/test/qu8-gemm-minmax-fp32.yaml
@@ -180,3 +180,15 @@
- name: xnn_qu8_gemm_minmax_fp32_ukernel_3x8c8__avx2
init: xnn_init_qu8_conv_minmax_fp32_avx2_params
k-block: 8
+- name: xnn_qu8_gemm_minmax_fp32_ukernel_1x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
+- name: xnn_qu8_gemm_minmax_fp32_ukernel_2x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
+- name: xnn_qu8_gemm_minmax_fp32_ukernel_3x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
+- name: xnn_qu8_gemm_minmax_fp32_ukernel_4x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
diff --git a/test/qu8-igemm-minmax-fp32.cc b/test/qu8-igemm-minmax-fp32.cc
index 8d773a4..7f3f919 100644
--- a/test/qu8-igemm-minmax-fp32.cc
+++ b/test/qu8-igemm-minmax-fp32.cc
@@ -30523,3 +30523,2071 @@
}
}
#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, small_kernel_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_gt_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, n_div_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 1; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, a_offset) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(43)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, zero) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t mz = 0; mz < 1; mz++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(43)
+ .zero_index(mz)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_1X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(1)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(1)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, small_kernel_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_gt_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, n_div_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 2; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, a_offset) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, zero) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t mz = 0; mz < 2; mz++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(83)
+ .zero_index(mz)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_2X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(2)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(2)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, small_kernel_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_gt_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, n_div_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 3; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, a_offset) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(127)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, zero) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t mz = 0; mz < 3; mz++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(127)
+ .zero_index(mz)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_3X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(3)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(3)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
+
+
+#if XNN_ARCH_X86 || XNN_ARCH_X86_64
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_subtile_m) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(16)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_eq_8_subtile_n) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(8)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_lt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_lt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k < 8; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_gt_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_gt_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 9; k < 16; k++) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_div_8) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, k_div_8_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 16; k <= 80; k += 8) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16_strided_cn) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(n)
+ .k(k)
+ .cn_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, small_kernel_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .ks(3)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_gt_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 17; n < 32; n++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, n_div_16_small_kernel) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t n = 32; n <= 48; n += 16) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, strided_cm_subtile) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ for (uint32_t m = 1; m <= 4; m++) {
+ for (uint32_t n = 1; n <= 16; n++) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(m)
+ .n(n)
+ .k(k)
+ .cm_stride(19)
+ .iterations(1)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, a_offset) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(163)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, zero) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (uint32_t mz = 0; mz < 4; mz++) {
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .ks(3)
+ .a_offset(163)
+ .zero_index(mz)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, qmin) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .qmin(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, qmax) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .qmax(128)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, strided_cm) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(8)
+ .cm_stride(19)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, no_a_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, no_b_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+
+ TEST(QU8_IGEMM_MINMAX_FP32_4X16C8__AVX512SKX, no_zero_point) {
+ TEST_REQUIRES_X86_AVX512SKX;
+ for (size_t k = 1; k <= 40; k += 9) {
+ GemmMicrokernelTester()
+ .mr(4)
+ .nr(16)
+ .kr(8)
+ .sr(1)
+ .m(4)
+ .n(16)
+ .k(k)
+ .a_zero_point(0)
+ .b_zero_point(0)
+ .Test(xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx, xnn_init_qu8_conv_minmax_fp32_avx512_params, xnn_init_qu8_requantization_fp32_params, xnn_qu8_requantize_fp32);
+ }
+ }
+#endif // XNN_ARCH_X86 || XNN_ARCH_X86_64
diff --git a/test/qu8-igemm-minmax-fp32.yaml b/test/qu8-igemm-minmax-fp32.yaml
index da8f89c..bd0ec13 100644
--- a/test/qu8-igemm-minmax-fp32.yaml
+++ b/test/qu8-igemm-minmax-fp32.yaml
@@ -180,3 +180,15 @@
- name: xnn_qu8_igemm_minmax_fp32_ukernel_3x8c8__avx2
init: xnn_init_qu8_conv_minmax_fp32_avx2_params
k-block: 8
+- name: xnn_qu8_igemm_minmax_fp32_ukernel_1x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
+- name: xnn_qu8_igemm_minmax_fp32_ukernel_2x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
+- name: xnn_qu8_igemm_minmax_fp32_ukernel_3x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8
+- name: xnn_qu8_igemm_minmax_fp32_ukernel_4x16c8__avx512skx
+ init: xnn_init_qu8_conv_minmax_fp32_avx512_params
+ k-block: 8