Switch from C2 to S4C2 for qs8 microkernels on 32-bit ARM
PiperOrigin-RevId: 407217899
diff --git a/BUILD.bazel b/BUILD.bazel
index 8ad5841..b4e12cd 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -133,8 +133,8 @@
"src/f32-conv-hwc/3x3s2p1c3x4-scalar-1x1.c",
"src/f32-conv-hwc2chw/3x3s2p1c3x4-scalar-1x1.c",
"src/f32-dwconv/gen/up1x3-minmax-scalar-acc2.c",
- "src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c",
"src/f32-dwconv/gen/up1x3-scalar-acc2.c",
+ "src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c",
"src/f32-dwconv/gen/up1x4-scalar-acc2.c",
"src/f32-dwconv/gen/up1x9-minmax-scalar-acc2.c",
"src/f32-dwconv/gen/up1x9-scalar-acc2.c",
@@ -1153,11 +1153,11 @@
"src/f32-dwconv/gen/up4x3-minmax-wasmsimd-arm.c",
"src/f32-dwconv/gen/up4x3-minmax-wasmsimd-x86-acc2.c",
"src/f32-dwconv/gen/up4x3-minmax-wasmsimd-x86.c",
+ "src/f32-dwconv/gen/up4x3-wasmsimd.c",
"src/f32-dwconv/gen/up4x4-minmax-wasmsimd-arm-acc2.c",
"src/f32-dwconv/gen/up4x4-minmax-wasmsimd-arm.c",
"src/f32-dwconv/gen/up4x4-minmax-wasmsimd-x86-acc2.c",
"src/f32-dwconv/gen/up4x4-minmax-wasmsimd-x86.c",
- "src/f32-dwconv/gen/up4x3-wasmsimd.c",
"src/f32-dwconv/gen/up4x4-wasmsimd.c",
"src/f32-dwconv/gen/up4x9-minmax-wasmsimd-arm-acc2.c",
"src/f32-dwconv/gen/up4x9-minmax-wasmsimd-arm.c",
@@ -1173,11 +1173,11 @@
"src/f32-dwconv/gen/up8x3-minmax-wasmsimd-arm.c",
"src/f32-dwconv/gen/up8x3-minmax-wasmsimd-x86-acc2.c",
"src/f32-dwconv/gen/up8x3-minmax-wasmsimd-x86.c",
+ "src/f32-dwconv/gen/up8x3-wasmsimd.c",
"src/f32-dwconv/gen/up8x4-minmax-wasmsimd-arm-acc2.c",
"src/f32-dwconv/gen/up8x4-minmax-wasmsimd-arm.c",
"src/f32-dwconv/gen/up8x4-minmax-wasmsimd-x86-acc2.c",
"src/f32-dwconv/gen/up8x4-minmax-wasmsimd-x86.c",
- "src/f32-dwconv/gen/up8x3-wasmsimd.c",
"src/f32-dwconv/gen/up8x4-wasmsimd.c",
"src/f32-dwconv/gen/up8x9-minmax-wasmsimd-arm-acc2.c",
"src/f32-dwconv/gen/up8x9-minmax-wasmsimd-arm.c",
@@ -2100,12 +2100,12 @@
"src/qs8-dwconv/gen/up16x25-minmax-rndnu-neon-mla8-ld64.c",
"src/qs8-gavgpool/gen/7p7x-minmax-neon-c8-acc2.c",
"src/qs8-gavgpool/gen/7x-minmax-neon-c8-acc2.c",
- "src/qs8-gemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
+ "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c",
- "src/qs8-gemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
- "src/qs8-igemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
+ "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c",
- "src/qs8-igemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
+ "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-vadd/gen/minmax-neon-ld64-x16.c",
"src/qs8-vadd/gen/minmax-neon-ld64-x32.c",
"src/qs8-vaddc/gen/minmax-neon-ld64-x16.c",
@@ -2523,8 +2523,8 @@
"src/qs8-gemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c",
"src/qs8-gemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-gemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c",
- "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-gemm/gen/1x16-minmax-fp32-neon-mlal-lane.c",
"src/qs8-gemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c",
"src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c",
@@ -2546,8 +2546,8 @@
"src/qs8-gemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c",
"src/qs8-gemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-gemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c",
- "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+ "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-gemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c",
"src/qs8-gemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c",
"src/qs8-gemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c",
@@ -2600,8 +2600,8 @@
"src/qs8-igemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c",
"src/qs8-igemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-igemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c",
- "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-igemm/gen/1x16-minmax-fp32-neon-mlal-lane.c",
"src/qs8-igemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c",
"src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c",
@@ -2623,8 +2623,8 @@
"src/qs8-igemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c",
"src/qs8-igemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c",
"src/qs8-igemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c",
- "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+ "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
"src/qs8-igemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c",
"src/qs8-igemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c",
"src/qs8-igemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 628e959..7d7ab0e8 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -261,8 +261,8 @@
src/f32-conv-hwc/3x3s2p1c3x4-scalar-1x1.c
src/f32-conv-hwc2chw/3x3s2p1c3x4-scalar-1x1.c
src/f32-dwconv/gen/up1x3-minmax-scalar-acc2.c
- src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c
src/f32-dwconv/gen/up1x3-scalar-acc2.c
+ src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c
src/f32-dwconv/gen/up1x4-scalar-acc2.c
src/f32-dwconv/gen/up1x9-minmax-scalar-acc2.c
src/f32-dwconv/gen/up1x9-scalar-acc2.c
@@ -1135,12 +1135,12 @@
src/qs8-dwconv/gen/up16x25-minmax-rndnu-neon-mla8-ld64.c
src/qs8-gavgpool/gen/7p7x-minmax-neon-c8-acc2.c
src/qs8-gavgpool/gen/7x-minmax-neon-c8-acc2.c
- src/qs8-gemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c
+ src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c
- src/qs8-gemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c
- src/qs8-igemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c
+ src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
+ src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c
- src/qs8-igemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c
+ src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
src/qs8-vadd/gen/minmax-neon-ld64-x16.c
src/qs8-vadd/gen/minmax-neon-ld64-x32.c
src/qs8-vaddc/gen/minmax-neon-ld64-x16.c
@@ -1557,8 +1557,8 @@
src/qs8-gemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c
src/qs8-gemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c
src/qs8-gemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c
- src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
+ src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-gemm/gen/1x16-minmax-fp32-neon-mlal-lane.c
src/qs8-gemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c
src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c
@@ -1580,8 +1580,8 @@
src/qs8-gemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c
src/qs8-gemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c
src/qs8-gemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c
- src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
+ src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-gemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c
src/qs8-gemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c
src/qs8-gemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c
@@ -1634,8 +1634,8 @@
src/qs8-igemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c
src/qs8-igemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c
src/qs8-igemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c
- src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
+ src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-igemm/gen/1x16-minmax-fp32-neon-mlal-lane.c
src/qs8-igemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c
src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c
@@ -1657,8 +1657,8 @@
src/qs8-igemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c
src/qs8-igemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c
src/qs8-igemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c
- src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
+ src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
src/qs8-igemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c
src/qs8-igemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c
src/qs8-igemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c
diff --git a/src/init.c b/src/init.c
index a99e318..8fd2538 100644
--- a/src/init.c
+++ b/src/init.c
@@ -148,14 +148,15 @@
xnn_params.qs8.gemm.nr = 8;
xnn_params.qs8.gemm.log2_kr = 2;
} else {
- xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
+ xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
xnn_params.qs8.gemm.mr = 2;
xnn_params.qs8.gemm.nr = 8;
xnn_params.qs8.gemm.log2_kr = 1;
+ xnn_params.qs8.gemm.log2_sr = 2;
}
xnn_params.qs8.dwconv[0].minmax.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_qs8_dwconv_minmax_rndnu_ukernel_up16x9__neon_mla8_ld64;
@@ -1289,14 +1290,15 @@
xnn_params.qs8.gemm.nr = 16;
xnn_params.qs8.gemm.log2_kr = 2;
} else {
- xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
+ xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
xnn_params.qs8.gemm.mr = 2;
xnn_params.qs8.gemm.nr = 8;
xnn_params.qs8.gemm.log2_kr = 1;
+ xnn_params.qs8.gemm.log2_sr = 2;
}
#endif // XNN_ENABLE_ASSEMBLY
#else // !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC
@@ -1410,10 +1412,10 @@
xnn_params.qs8.gemm.nr = 16;
xnn_params.qs8.gemm.log2_kr = 2;
} else {
- xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
- xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
+ xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
+ xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
xnn_params.qs8.gemm.mr = 2;
xnn_params.qs8.gemm.nr = 8;
diff --git a/src/packing.c b/src/packing.c
index 3aefbdd..63619dc 100644
--- a/src/packing.c
+++ b/src/packing.c
@@ -271,7 +271,9 @@
size_t extra_bytes,
const struct xnn_qs8_packing_params* params)
{
- assert(sr == 1);
+ const size_t skr = sr * kr;
+ const size_t skc = round_down_po2(kc, skr);
+ const size_t sr_mask = (sr - 1) * kr;
const int32_t izp = (int32_t) params->input_zero_point;
do {
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
@@ -290,7 +292,24 @@
} while (--n != 0);
}
packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
- for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+
+ for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ int32_t ksum = 0;
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ const int8_t kv = k[(nr_block_start + nr_block_offset) * kc + (round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset)];
+ ksum += (int32_t) kv;
+ *((int16_t*) packed_w) = (int16_t) kv;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int16_t));
+ }
+ packed_b[nr_block_offset] -= ksum * izp;
+ packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int16_t));
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int16_t));
+ }
+
+ for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
const size_t kr_block_size = min(kc - kr_block_start, kr);
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
int32_t ksum = 0;
@@ -467,7 +486,9 @@
void* packed_w,
const struct xnn_qs8_packing_params* params)
{
- assert(sr == 1);
+ const size_t skr = sr * kr;
+ const size_t skc = round_down_po2(kc, skr);
+ const size_t sr_mask = (sr - 1) * kr;
const int32_t izp = (int32_t) params->input_zero_point;
for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
const size_t nr_block_size = min(nc - nr_block_start, nr);
@@ -485,7 +506,24 @@
} while (--n != 0);
}
packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
- for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+
+ for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ int32_t ksum = 0;
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ const int8_t kv = k[(round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+ ksum += (int32_t) kv;
+ *((int8_t*) packed_w) = kv;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int8_t));
+ }
+ packed_b[nr_block_offset] -= ksum * izp;
+ packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int8_t));
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
+ }
+
+ for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
const size_t kr_block_size = min(kc - kr_block_start, kr);
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
int32_t ksum = 0;
@@ -1043,6 +1081,7 @@
packed_w += nr;
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
+
for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
@@ -1093,7 +1132,9 @@
struct subconvolution_params* subconv_params,
const struct xnn_qs8_packing_params* params)
{
- assert(sr == 1);
+ const size_t skr = sr * kr;
+ const size_t skc = round_down_po2(kc, skr);
+ const size_t sr_mask = (sr - 1) * kr;
const int32_t izp = (int32_t) params->input_zero_point;
for (size_t i = 0; i < g; i++) {
for (size_t oy = 0; oy < sh; oy++) {
@@ -1119,7 +1160,25 @@
packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
for (size_t ky = oy; ky < kh; ky += sh) {
for (size_t kx = ox; kx < kw; kx += sw) {
- for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+
+ for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+ const size_t kr_block_size = min(kc - kr_block_start, kr);
+ for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+ int32_t ksum = 0;
+ for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+ const int8_t kv =
+ k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset)];
+ ksum += (int32_t) kv;
+ *((int8_t*) packed_w) = kv;
+ packed_w = (void*) ((uintptr_t) packed_w + sizeof(int8_t));
+ }
+ packed_b[nr_block_offset] -= ksum * izp;
+ packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int8_t));
+ }
+ packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
+ }
+
+ for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
const size_t kr_block_size = min(kc - kr_block_start, kr);
for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
int32_t ksum = 0;
@@ -1135,6 +1194,7 @@
}
packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
}
+
}
}
}