Switch from C2 to S4C2 for qs8 microkernels on 32-bit ARM

PiperOrigin-RevId: 407217899
diff --git a/BUILD.bazel b/BUILD.bazel
index 8ad5841..b4e12cd 100644
--- a/BUILD.bazel
+++ b/BUILD.bazel
@@ -133,8 +133,8 @@
     "src/f32-conv-hwc/3x3s2p1c3x4-scalar-1x1.c",
     "src/f32-conv-hwc2chw/3x3s2p1c3x4-scalar-1x1.c",
     "src/f32-dwconv/gen/up1x3-minmax-scalar-acc2.c",
-    "src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c",
     "src/f32-dwconv/gen/up1x3-scalar-acc2.c",
+    "src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c",
     "src/f32-dwconv/gen/up1x4-scalar-acc2.c",
     "src/f32-dwconv/gen/up1x9-minmax-scalar-acc2.c",
     "src/f32-dwconv/gen/up1x9-scalar-acc2.c",
@@ -1153,11 +1153,11 @@
     "src/f32-dwconv/gen/up4x3-minmax-wasmsimd-arm.c",
     "src/f32-dwconv/gen/up4x3-minmax-wasmsimd-x86-acc2.c",
     "src/f32-dwconv/gen/up4x3-minmax-wasmsimd-x86.c",
+    "src/f32-dwconv/gen/up4x3-wasmsimd.c",
     "src/f32-dwconv/gen/up4x4-minmax-wasmsimd-arm-acc2.c",
     "src/f32-dwconv/gen/up4x4-minmax-wasmsimd-arm.c",
     "src/f32-dwconv/gen/up4x4-minmax-wasmsimd-x86-acc2.c",
     "src/f32-dwconv/gen/up4x4-minmax-wasmsimd-x86.c",
-    "src/f32-dwconv/gen/up4x3-wasmsimd.c",
     "src/f32-dwconv/gen/up4x4-wasmsimd.c",
     "src/f32-dwconv/gen/up4x9-minmax-wasmsimd-arm-acc2.c",
     "src/f32-dwconv/gen/up4x9-minmax-wasmsimd-arm.c",
@@ -1173,11 +1173,11 @@
     "src/f32-dwconv/gen/up8x3-minmax-wasmsimd-arm.c",
     "src/f32-dwconv/gen/up8x3-minmax-wasmsimd-x86-acc2.c",
     "src/f32-dwconv/gen/up8x3-minmax-wasmsimd-x86.c",
+    "src/f32-dwconv/gen/up8x3-wasmsimd.c",
     "src/f32-dwconv/gen/up8x4-minmax-wasmsimd-arm-acc2.c",
     "src/f32-dwconv/gen/up8x4-minmax-wasmsimd-arm.c",
     "src/f32-dwconv/gen/up8x4-minmax-wasmsimd-x86-acc2.c",
     "src/f32-dwconv/gen/up8x4-minmax-wasmsimd-x86.c",
-    "src/f32-dwconv/gen/up8x3-wasmsimd.c",
     "src/f32-dwconv/gen/up8x4-wasmsimd.c",
     "src/f32-dwconv/gen/up8x9-minmax-wasmsimd-arm-acc2.c",
     "src/f32-dwconv/gen/up8x9-minmax-wasmsimd-arm.c",
@@ -2100,12 +2100,12 @@
     "src/qs8-dwconv/gen/up16x25-minmax-rndnu-neon-mla8-ld64.c",
     "src/qs8-gavgpool/gen/7p7x-minmax-neon-c8-acc2.c",
     "src/qs8-gavgpool/gen/7x-minmax-neon-c8-acc2.c",
-    "src/qs8-gemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
+    "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c",
-    "src/qs8-gemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
-    "src/qs8-igemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
+    "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+    "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c",
-    "src/qs8-igemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c",
+    "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-vadd/gen/minmax-neon-ld64-x16.c",
     "src/qs8-vadd/gen/minmax-neon-ld64-x32.c",
     "src/qs8-vaddc/gen/minmax-neon-ld64-x16.c",
@@ -2523,8 +2523,8 @@
     "src/qs8-gemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c",
     "src/qs8-gemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-gemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c",
-    "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+    "src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-gemm/gen/1x16-minmax-fp32-neon-mlal-lane.c",
     "src/qs8-gemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c",
     "src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c",
@@ -2546,8 +2546,8 @@
     "src/qs8-gemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c",
     "src/qs8-gemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-gemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c",
-    "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+    "src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-gemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c",
     "src/qs8-gemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c",
     "src/qs8-gemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c",
@@ -2600,8 +2600,8 @@
     "src/qs8-igemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c",
     "src/qs8-igemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-igemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c",
-    "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+    "src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-igemm/gen/1x16-minmax-fp32-neon-mlal-lane.c",
     "src/qs8-igemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c",
     "src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c",
@@ -2623,8 +2623,8 @@
     "src/qs8-igemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c",
     "src/qs8-igemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c",
     "src/qs8-igemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c",
-    "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c",
+    "src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c",
     "src/qs8-igemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c",
     "src/qs8-igemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c",
     "src/qs8-igemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c",
diff --git a/CMakeLists.txt b/CMakeLists.txt
index 628e959..7d7ab0e8 100755
--- a/CMakeLists.txt
+++ b/CMakeLists.txt
@@ -261,8 +261,8 @@
   src/f32-conv-hwc/3x3s2p1c3x4-scalar-1x1.c
   src/f32-conv-hwc2chw/3x3s2p1c3x4-scalar-1x1.c
   src/f32-dwconv/gen/up1x3-minmax-scalar-acc2.c
-  src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c
   src/f32-dwconv/gen/up1x3-scalar-acc2.c
+  src/f32-dwconv/gen/up1x4-minmax-scalar-acc2.c
   src/f32-dwconv/gen/up1x4-scalar-acc2.c
   src/f32-dwconv/gen/up1x9-minmax-scalar-acc2.c
   src/f32-dwconv/gen/up1x9-scalar-acc2.c
@@ -1135,12 +1135,12 @@
   src/qs8-dwconv/gen/up16x25-minmax-rndnu-neon-mla8-ld64.c
   src/qs8-gavgpool/gen/7p7x-minmax-neon-c8-acc2.c
   src/qs8-gavgpool/gen/7x-minmax-neon-c8-acc2.c
-  src/qs8-gemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c
+  src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
   src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c
-  src/qs8-gemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c
-  src/qs8-igemm/gen/1x8c2-minmax-rndnu-neon-mlal-padal-dup.c
+  src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
+  src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
   src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane.c
-  src/qs8-igemm/gen/2x8c2-minmax-rndnu-neon-mlal-padal-dup.c
+  src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
   src/qs8-vadd/gen/minmax-neon-ld64-x16.c
   src/qs8-vadd/gen/minmax-neon-ld64-x32.c
   src/qs8-vaddc/gen/minmax-neon-ld64-x16.c
@@ -1557,8 +1557,8 @@
   src/qs8-gemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c
   src/qs8-gemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c
   src/qs8-gemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c
-  src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
+  src/qs8-gemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-gemm/gen/1x16-minmax-fp32-neon-mlal-lane.c
   src/qs8-gemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c
   src/qs8-gemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c
@@ -1580,8 +1580,8 @@
   src/qs8-gemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c
   src/qs8-gemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c
   src/qs8-gemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c
-  src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
+  src/qs8-gemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-gemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c
   src/qs8-gemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c
   src/qs8-gemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c
@@ -1634,8 +1634,8 @@
   src/qs8-igemm/gen/1x8c8-minmax-gemmlowp-neon-mull-padal.c
   src/qs8-igemm/gen/1x8c8-minmax-rndnu-neon-mlal-padal.c
   src/qs8-igemm/gen/1x8c16-minmax-gemmlowp-neon-mlal-padal.c
-  src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mlal-padal.c
+  src/qs8-igemm/gen/1x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-igemm/gen/1x16-minmax-fp32-neon-mlal-lane.c
   src/qs8-igemm/gen/1x16-minmax-gemmlowp-neon-mlal-lane.c
   src/qs8-igemm/gen/1x16-minmax-rndnu-neon-mlal-lane-prfm.c
@@ -1657,8 +1657,8 @@
   src/qs8-igemm/gen/2x8c8-minmax-gemmlowp-neon-mull-padal.c
   src/qs8-igemm/gen/2x8c8-minmax-rndnu-neon-mlal-padal.c
   src/qs8-igemm/gen/2x8c16-minmax-gemmlowp-neon-mlal-padal.c
-  src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mlal-padal.c
+  src/qs8-igemm/gen/2x8s4c2-minmax-rndnu-neon-mull-padal.c
   src/qs8-igemm/gen/2x16-minmax-gemmlowp-neon-mlal-lane.c
   src/qs8-igemm/gen/2x16-minmax-rndnu-neon-mull-addw-dup.c
   src/qs8-igemm/gen/2x16c2-minmax-rndnu-neon-mlal-padal-dup.c
diff --git a/src/init.c b/src/init.c
index a99e318..8fd2538 100644
--- a/src/init.c
+++ b/src/init.c
@@ -148,14 +148,15 @@
         xnn_params.qs8.gemm.nr = 8;
         xnn_params.qs8.gemm.log2_kr = 2;
       } else {
-        xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
-        xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
-        xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
-        xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
+        xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+        xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+        xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
+        xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
         xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
         xnn_params.qs8.gemm.mr = 2;
         xnn_params.qs8.gemm.nr = 8;
         xnn_params.qs8.gemm.log2_kr = 1;
+        xnn_params.qs8.gemm.log2_sr = 2;
       }
 
       xnn_params.qs8.dwconv[0].minmax.unipass = (xnn_dwconv_unipass_ukernel_function) xnn_qs8_dwconv_minmax_rndnu_ukernel_up16x9__neon_mla8_ld64;
@@ -1289,14 +1290,15 @@
           xnn_params.qs8.gemm.nr = 16;
           xnn_params.qs8.gemm.log2_kr = 2;
         } else {
-          xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
-          xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
-          xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
-          xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
+          xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+          xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+          xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
+          xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
           xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
           xnn_params.qs8.gemm.mr = 2;
           xnn_params.qs8.gemm.nr = 8;
           xnn_params.qs8.gemm.log2_kr = 1;
+          xnn_params.qs8.gemm.log2_sr = 2;
         }
       #endif  // XNN_ENABLE_ASSEMBLY
     #else  // !XNN_PLATFORM_IOS && !XNN_PLATFORM_MAC
@@ -1410,10 +1412,10 @@
           xnn_params.qs8.gemm.nr = 16;
           xnn_params.qs8.gemm.log2_kr = 2;
         } else {
-          xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
-          xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8c2__neon_mlal_padal_dup);
-          xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
-          xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8c2__neon_mlal_padal_dup);
+          xnn_params.qs8.gemm.minmax.gemm = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+          xnn_params.qs8.gemm.minmax.igemm = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_2x8s4c2__neon_mlal_padal);
+          xnn_params.qs8.gemm.minmax.gemm1 = xnn_init_hmp_gemm_ukernel((xnn_gemm_ukernel_function) xnn_qs8_gemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
+          xnn_params.qs8.gemm.minmax.igemm1 = xnn_init_hmp_igemm_ukernel((xnn_igemm_ukernel_function) xnn_qs8_igemm_minmax_rndnu_ukernel_1x8s4c2__neon_mlal_padal);
           xnn_params.qs8.gemm.init.qs8 = xnn_init_qs8_conv_minmax_rndnu_neon_params;
           xnn_params.qs8.gemm.mr = 2;
           xnn_params.qs8.gemm.nr = 8;
diff --git a/src/packing.c b/src/packing.c
index 3aefbdd..63619dc 100644
--- a/src/packing.c
+++ b/src/packing.c
@@ -271,7 +271,9 @@
   size_t extra_bytes,
   const struct xnn_qs8_packing_params* params)
 {
-  assert(sr == 1);
+  const size_t skr = sr * kr;
+  const size_t skc = round_down_po2(kc, skr);
+  const size_t sr_mask = (sr - 1) * kr;
   const int32_t izp = (int32_t) params->input_zero_point;
   do {
     for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
@@ -290,7 +292,24 @@
         } while (--n != 0);
       }
       packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
-      for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+
+      for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+        const size_t kr_block_size = min(kc - kr_block_start, kr);
+        for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+          int32_t ksum = 0;
+          for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+            const int8_t kv = k[(nr_block_start + nr_block_offset) * kc + (round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset)];
+            ksum += (int32_t) kv;
+            *((int16_t*) packed_w) = (int16_t) kv;
+            packed_w = (void*) ((uintptr_t) packed_w + sizeof(int16_t));
+          }
+          packed_b[nr_block_offset] -= ksum * izp;
+          packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int16_t));
+        }
+        packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int16_t));
+      }
+
+      for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
         const size_t kr_block_size = min(kc - kr_block_start, kr);
         for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
           int32_t ksum = 0;
@@ -467,7 +486,9 @@
   void* packed_w,
   const struct xnn_qs8_packing_params* params)
 {
-  assert(sr == 1);
+  const size_t skr = sr * kr;
+  const size_t skc = round_down_po2(kc, skr);
+  const size_t sr_mask = (sr - 1) * kr;
   const int32_t izp = (int32_t) params->input_zero_point;
   for (size_t nr_block_start = 0; nr_block_start < nc; nr_block_start += nr) {
     const size_t nr_block_size = min(nc - nr_block_start, nr);
@@ -485,7 +506,24 @@
       } while (--n != 0);
     }
     packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
-    for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+
+    for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+      const size_t kr_block_size = min(kc - kr_block_start, kr);
+      for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+        int32_t ksum = 0;
+        for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+          const int8_t kv = k[(round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset) * nc + (nr_block_start + nr_block_offset)];
+          ksum += (int32_t) kv;
+          *((int8_t*) packed_w) = kv;
+          packed_w = (void*) ((uintptr_t) packed_w + sizeof(int8_t));
+        }
+        packed_b[nr_block_offset] -= ksum * izp;
+        packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int8_t));
+      }
+      packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
+    }
+
+    for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
       const size_t kr_block_size = min(kc - kr_block_start, kr);
       for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
         int32_t ksum = 0;
@@ -1043,6 +1081,7 @@
           packed_w += nr;
           for (size_t ky = oy; ky < kh; ky += sh) {
             for (size_t kx = ox; kx < kw; kx += sw) {
+
               for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
                 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                   for (size_t kr_block_offset = 0; kr_block_offset < kr; kr_block_offset++) {
@@ -1093,7 +1132,9 @@
   struct subconvolution_params* subconv_params,
   const struct xnn_qs8_packing_params* params)
 {
-  assert(sr == 1);
+  const size_t skr = sr * kr;
+  const size_t skc = round_down_po2(kc, skr);
+  const size_t sr_mask = (sr - 1) * kr;
   const int32_t izp = (int32_t) params->input_zero_point;
   for (size_t i = 0; i < g; i++) {
     for (size_t oy = 0; oy < sh; oy++) {
@@ -1119,7 +1160,25 @@
           packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * sizeof(int32_t));
           for (size_t ky = oy; ky < kh; ky += sh) {
             for (size_t kx = ox; kx < kw; kx += sw) {
-              for (size_t kr_block_start = 0; kr_block_start < kc; kr_block_start += kr) {
+
+              for (size_t kr_block_start = 0; kr_block_start < skc; kr_block_start += kr) {
+                const size_t kr_block_size = min(kc - kr_block_start, kr);
+                for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
+                  int32_t ksum = 0;
+                  for (size_t kr_block_offset = 0; kr_block_offset < kr_block_size; kr_block_offset++) {
+                    const int8_t kv =
+                      k[(((nr_block_start + nr_block_offset) * kh + ky) * kw + kx) * kc + (round_down_po2(kr_block_start, skr) + ((kr_block_start + nr_block_offset * kr) & sr_mask) + kr_block_offset)];
+                    ksum += (int32_t) kv;
+                    *((int8_t*) packed_w) = kv;
+                    packed_w = (void*) ((uintptr_t) packed_w + sizeof(int8_t));
+                  }
+                  packed_b[nr_block_offset] -= ksum * izp;
+                  packed_w = (void*) ((uintptr_t) packed_w + (kr - kr_block_size) * sizeof(int8_t));
+                }
+                packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
+              }
+
+              for (size_t kr_block_start = skc; kr_block_start < kc; kr_block_start += kr) {
                 const size_t kr_block_size = min(kc - kr_block_start, kr);
                 for (size_t nr_block_offset = 0; nr_block_offset < nr_block_size; nr_block_offset++) {
                   int32_t ksum = 0;
@@ -1135,6 +1194,7 @@
                 }
                 packed_w = (void*) ((uintptr_t) packed_w + (nr - nr_block_size) * kr * sizeof(int8_t));
               }
+
             }
           }
         }