Efficient support for any channel_dimension for quantized kernels on ARM32.

PiperOrigin-RevId: 320400753
diff --git a/ruy/create_trmul_params.cc b/ruy/create_trmul_params.cc
index 7410856..5ae910b 100644
--- a/ruy/create_trmul_params.cc
+++ b/ruy/create_trmul_params.cc
@@ -47,9 +47,7 @@
 #endif
 
 #if RUY_PLATFORM_NEON_32
-  if (src[Side::kLhs].data_type == Type::Create<float>()) {
-    return false;
-  }
+  return false;
 #endif
 
   // Ruy's optimized kernels currently only support the channel_dimension==kRow
diff --git a/ruy/kernel_arm32.cc b/ruy/kernel_arm32.cc
index 7a56be5..3136a7c 100644
--- a/ruy/kernel_arm32.cc
+++ b/ruy/kernel_arm32.cc
@@ -894,16 +894,24 @@
         "ldrb r4, [%[params], #" RUY_STR(RUY_OFFSET_FLAGS) "]\n"
         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_BIAS) "]\n"
 
-        // Offset these base pointers as needed given the current row, col.
-        "ldr r8, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
-        "add r5, r1, r8, lsl #2\n"
-
+        // Let r8 be stack offset of the row or column variable, whichever
+        // is the channel index.
+        "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "ite eq\n"
+        "moveq r8, #" RUY_STR(RUY_STACK_OFFSET_ROW) "\n"
+        "movne r8, #" RUY_STR(RUY_STACK_OFFSET_COL) "\n"
+        // Let r8 be the channel index.
+        "ldr r8, [sp, r8]\n"
+        // Compute the bias pointer, by conditionally using the channel index
+        // (r8) as offset into bias buffer (r1).
         "tst r4, #" RUY_STR(RUY_ASM_FLAG_HAS_BIAS) "\n"
         "it ne\n"
-        "movne r1, r5\n"
+        "addne r1, r1, r8, lsl #2\n"
 
-        // Load 4 bias values.
-        "vld1.32 {d24, d25}, [r1]\n"
+        // Load 2 bias values. When the channel dimension is rows, we will load
+        // another 2 bias values just before performing the bias addition below,
+        // as this kernel has a 4x2 rectangular shape.
+        "vld1.32 {d24}, [r1]!\n"
 
         // Now that we know what LHS and RHS data the next iteration of the
         // main loop will need to load, we start loading the first 32 bytes of
@@ -920,12 +928,29 @@
         // https://arxiv.org/pdf/1712.05877.pdf
         "ldr r3, [%[params], #" RUY_STR(RUY_OFFSET_PROD_ZP_DEPTH) "]\n"
         "vdup.32 q9, r3\n"
-        "vadd.i32 q12, q12, q9\n"
+        "vadd.i32 d24, d24, d18\n"
 
         // Perform the bias-addition (per the above, we have just folded into
         // the bias the (depth * lhs_zero_point * rhs_zero_point) term.)
+        // Jump based on channel dimension.
+        "tst r4, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
+        "bne 6f\n"
+        // Case where channels are rows.
+        // Load the remaining 2 bias values, since we're on the width-4 side
+        // of this 4x2 kernel.
+        "vld1.32 {d25}, [r1]\n"
+        "vadd.i32 d25, d25, d19\n"
         "vadd.i32 q14, q14, q12\n"
         "vadd.i32 q15, q15, q12\n"
+        "b 7f\n"
+
+        "6:\n"
+        // Case where channels are columns.
+        "vdup.32 q10, d24[0]\n"
+        "vdup.32 q11, d24[1]\n"
+        "vadd.i32 q14, q14, q10\n"
+        "vadd.i32 q15, q15, q11\n"
+        "7:\n"
 
         // LHS/RHS zero points
         // Has RHS sums
@@ -981,41 +1006,70 @@
         // multiplied by a multiplier that has a fixed-point component and an
         // exponent component.
 
-        //Load the exponent part of the multiplier.
+        // Compute the data pointers for the multiplier data
+        //   r1 = exponent part
+        //   r2 = fixedpoint part
         "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_EXPONENT) "]\n"
+        "ldr r2, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
+        // r6 has flags, r8 has channel index
         "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "ldr r4, [sp, #" RUY_STR(RUY_STACK_OFFSET_ROW) "]\n"
-        "add r5, r1, r4, lsl #2\n"
         "it ne\n"
-        "movne r1, r5\n"
+        "addne r1, r1, r8, lsl #2\n"
+        "it ne\n"
+        "addne r2, r2, r8, lsl #2\n"
 
-        "vld1.32 {q10}, [r1]\n"
+        // Load the first 2 values of multiplier exponent and fixedpoint data
+        // Since this kernel is rectangular 4x2, we will only conditionally load
+        // 2 more values below.
+        "vld1.32 {d20}, [r1]!\n"  // 2 values of multiplier_exponent
+        "vld1.32 {d12}, [r2]!\n"  // 2 values of multiplier_fixedpoint
 
+        "tst r6, #" RUY_STR(RUY_ASM_FLAG_CHANNEL_DIMENSION_IS_COL) "\n"
         RUY_MAKE_ZERO(q8)
-        "vmax.s32 q12, q10, q8\n"
+        "bne 8f\n"
+        // Case where channels are rows.
+        // Load the remaining 2 bias values, since we're on the width-4 side
+        // of this 4x2 kernel.
+        "vld1.32 {d21}, [r1]\n"  // 2 more values of multiplier_exponent
+        "vld1.32 {d13}, [r2]\n"  // 2 more values of multiplier_fixedpoint
+        "vmax.s32 q11, q10, q8\n"
+        "vmin.s32 q10, q10, q8\n"
 
         // Apply the positive exponent part of the multiplier.
-        "vshl.s32 q14, q14, q12\n"
-        "vshl.s32 q15, q15, q12\n"
-
-        "vmin.s32 q12, q10, q8\n"
-
-        // Load fixed point part of the multiplier
-        "ldr r1, [%[params], #" RUY_STR(RUY_OFFSET_MULTIPLIER_FIXEDPOINT) "]\n"
-        // r6 has flags, r4 has row
-        "add r5, r1, r4, lsl #2\n"
-        "tst r6, #" RUY_STR(RUY_ASM_FLAG_HAS_PERCHANNEL) "\n"
-        "it ne\n"
-        "movne r1, r5\n"
-        "vld1.32 {q10}, [r1]\n" // multiplier_fixedpoint
+        "vshl.s32 q14, q14, q11\n"
+        "vshl.s32 q15, q15, q11\n"
 
         // Apply the fixed-point part of the multiplier.
-        "vqrdmulh.s32 q14, q14, q10\n"
-        "vqrdmulh.s32 q15, q15, q10\n"
+        "vqrdmulh.s32 q14, q14, q6\n"
+        "vqrdmulh.s32 q15, q15, q6\n"
 
         // Apply the negative exponent part of the multiplier.
+        "vrshl.s32 q14, q14, q10\n"
+        "vrshl.s32 q15, q15, q10\n"
+        "b 9f\n"
+
+        "8:\n"
+        // Case where channels are columns.
+        "vmax.s32 d22, d20, d16\n"
+        "vmin.s32 d20, d20, d16\n"
+
+        // Apply the positive exponent part of the multiplier.
+        "vdup.32  q12, d22[0]\n"
+        "vdup.32  q13, d22[1]\n"
+        "vshl.s32 q14, q14, q12\n"
+        "vshl.s32 q15, q15, q13\n"
+
+        // Apply the fixed-point part of the multiplier.
+        "vqrdmulh.s32 q14, q14, d12[0]\n"
+        "vqrdmulh.s32 q15, q15, d12[1]\n"
+
+        // Apply the negative exponent part of the multiplier.
+        "vdup.32  q12, d20[0]\n"
+        "vdup.32  q13, d20[1]\n"
         "vrshl.s32 q14, q14, q12\n"
-        "vrshl.s32 q15, q15, q12\n"
+        "vrshl.s32 q15, q15, q13\n"
+
+        "9:\n"
 
         "ldrb r10, [%[params], #" RUY_STR(RUY_OFFSET_DST_TYPE_ID) "]\n"
         "cmp r10, #" RUY_STR(RUY_ASM_TYPE_ID_INT16) "\n"