AARCH32 4x8 LD64 stores simplified
PiperOrigin-RevId: 283910622
diff --git a/src/f32-gemm/4x8-aarch32-neon-ld64.S b/src/f32-gemm/4x8-aarch32-neon-ld64.S
index e2cd979..0ae2d1c 100644
--- a/src/f32-gemm/4x8-aarch32-neon-ld64.S
+++ b/src/f32-gemm/4x8-aarch32-neon-ld64.S
@@ -10,14 +10,14 @@
// void xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64(
// size_t mr, r0
// size_t nc, r1
-// size_t kc, r2 -> r7
+// size_t kc, r2 -> r5
// const uint8_t*restrict a, r3
-// size_t a_stride, sp + 96 -> (r5)
+// size_t a_stride, sp + 96 -> (r7)
// const void*restrict w, sp + 100 -> r9
-// uint8_t*restrict c, sp + 104 -> r14
+// uint8_t*restrict c, sp + 104 -> r11
// size_t cm_stride, sp + 108 -> (r6)
-// size_t cn_stride, sp + 112 -> r5
-// const union xnn_f32_output_params params[restrict static 1]) sp + 116 -> (r5)
+// size_t cn_stride, sp + 112 -> r7
+// const union xnn_f32_output_params params[restrict static 1]) sp + 116 -> (r7)
// inner loop registers
@@ -30,13 +30,12 @@
// B r9 d8, d9, d10, d11
// B d12, d13, d14, d15
-// C0 r14 d16-d17 q8 d18-d19 q9
+// C0 r11 d16-d17 q8 d18-d19 q9
// C1 r4 d20-d21 q10 d22-d23 q11
// C2 r8 d24-d25 q12 d26-d27 q13
// C3 r6 d28-d29 q14 d30-d31 q15
// Clamp (r5) d4 d5 d6 d7
-// Unused r11
BEGIN_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64
.arm
@@ -45,46 +44,42 @@
.fpu neon
#endif
// Push 96 bytes
- PUSH {r4, r5, r6, r7, r8, r9, r10, r14} // 32
+ PUSH {r4, r5, r6, r7, r8, r9, r10, r11} // 32
VPUSH {d8-d15} // +64 = 96
+ LDR r7, [sp, 96] // a_stride
+ LDR r9, [sp, 100] // w
+ LDR r11, [sp, 104] // c
+ LDR r6, [sp, 108] // cm_stride
+ LDR r5, [sp, 116] // clamping_params
+
// Clamp A and C pointers
-
- LDR r5, [sp, 96] // a_stride
- LDR r14, [sp, 104] // c
- LDR r6, [sp, 108] // cm_stride
-
CMP r0, 2 // if mr >= 2
- ADDHS r12, r3, r5 // a1 = a0 + a_stride
- ADDHS r4, r14, r6 // c1 = c0 + cm_stride
+ ADD r12, r3, r7 // a1 = a0 + a_stride
+ ADD r4, r11, r6 // c1 = c0 + cm_stride
MOVLO r12, r3 // a1
- MOVLO r4, r14 // c1
+ MOVLO r4, r11 // c1
// if mr > 2
- ADDHI r10, r12, r5 // a2 = a1 + a_stride
- ADDHI r8, r4, r6 // c2 = c1 + cm_stride
+ ADD r10, r12, r7 // a2 = a1 + a_stride
+ ADD r8, r4, r6 // c2 = c1 + cm_stride
MOVLS r10, r12 // a2
MOVLS r8, r4 // c2
CMP r0, 4 // if mr >=4
- ADDHS r0, r10, r5 // a3 = a2 + a_stride
- ADDHS r6, r8, r6 // c3 = c2 + cm_stride
+ ADD r0, r10, r7 // a3 = a2 + a_stride
+ ADD r6, r8, r6 // c3 = c2 + cm_stride
MOVLO r0, r10 // a3
MOVLO r6, r8 // c3
- // Load params pointer
- LDR r5, [sp, 116] // clamping_params
- LDR r9, [sp, 100] // W
-
// Load clamping_params values
VLD1.32 {d4[]-d5[]}, [r5]!
+ LDR r7, [sp, 112] // cn_stride
VLD1.32 {d6[]-d7[]}, [r5]
- LDR r5, [sp, 112] // cn_stride
-
1:
# Load initial bias from w into accumulators
VLDM r9!, {d16-d19} // Bias
- SUBS r7, r2, 8
+ SUBS r5, r2, 8
VMOV q10, q8
VMOV q11, q9
VMOV q12, q8
@@ -116,15 +111,15 @@
VMLA.F32 q11, q7, d1[1]
VMLA.F32 q12, q6, d2[1]
VMLA.F32 q13, q7, d2[1]
- SUBS r7, r7, 8
+ SUBS r5, r5, 8
VMLA.F32 q14, q6, d3[1]
VMLA.F32 q15, q7, d3[1]
BHS 2b
3:
// Is there a remainder?- 1 floats of A (4 bytes)
- TST r7, 4
- BNE 7f
+ TST r5, 4
+ BNE 8f
4:
// Clamp
@@ -146,26 +141,24 @@
VMAX.F32 q15, q15, q3
// Store full 4 x 8
-
- CMP r1, 8
- BLO 8f
-
- SUBS r1, r1, 8 // Loop counter
- VST1.32 {d16-d19}, [r14], r5
+ SUBS r1, r1, 8
+ BLO 10f
+ VST1.32 {d16-d19}, [r11], r7
SUB r0, r0, r2
- VST1.32 {d20-d23}, [r4], r5
+ VST1.32 {d20-d23}, [r4], r7
SUB r10, r10, r2
- VST1.32 {d24-d27}, [r8], r5
+ VST1.32 {d24-d27}, [r8], r7
SUB r12, r12, r2
- VST1.32 {d28-d31}, [r6], r5
+ VST1.32 {d28-d31}, [r6], r7
SUB r3, r3, r2
- BNE 1b
+ BHI 1b
6:
VPOP {d8-d15}
- POP {r4, r5, r6, r7, r8, r9, r10, pc}
+ POP {r4, r5, r6, r7, r8, r9, r10, r11}
+ BX lr
-7:
+8:
// Remainder- 1 floats of A (4 bytes)
VLDM r3!, {s0} // A0
VLDM r9!, {d8-d11} // B0
@@ -183,48 +176,42 @@
B 4b
// Store odd width
-
-9:
- VST1.32 {d16-d17}, [r14]!
- VST1.32 {d20-d21}, [r4]!
- VST1.32 {d24-d25}, [r8]!
- VST1.32 {d28-d29}, [r6]!
- TST r1, 2
- BNE 11f
-
10:
- VMOV d19, d18
- VMOV d23, d22
- VMOV d27, d26
- VMOV d31, d30
- TST r1, 1
- BEQ 6b
- B 12f
-
-8:
TST r1, 4
- BNE 9b
- VMOV q9, q8
- VMOV q11, q10
- VMOV q13, q12
- VMOV q15, q14
- TST r1, 2
- BEQ 10b
+ BEQ 11f
+ VST1.32 {d16-d17}, [r11]!
+ VMOV q8, q9
+ VST1.32 {d20-d21}, [r4]!
+ VMOV q10, q11
+ VST1.32 {d24-d25}, [r8]!
+ VMOV q12, q13
+ VST1.32 {d28-d29}, [r6]!
+ VMOV q14, q15
11:
- VST1.32 {d18}, [r14]!
- VST1.32 {d22}, [r4]!
- VST1.32 {d26}, [r8]!
- VST1.32 {d30}, [r6]!
- TST r1, 1
- BEQ 6b
+ TST r1, 2
+ BEQ 12f
+ VST1.32 {d16}, [r11]!
+ VMOV d16, d17
+ VST1.32 {d20}, [r4]!
+ VMOV d20, d21
+ VST1.32 {d24}, [r8]!
+ VMOV d24, d25
+ VST1.32 {d28}, [r6]!
+ VMOV d28, d29
12:
- VST1.32 {d19[0]}, [r14]
- VST1.32 {d23[0]}, [r4]
- VST1.32 {d27[0]}, [r8]
- VST1.32 {d31[0]}, [r6]
- B 6b
+ TST r1, 1
+ BEQ 13f
+ VST1.32 {d16[0]}, [r11]
+ VST1.32 {d20[0]}, [r4]
+ VST1.32 {d24[0]}, [r8]
+ VST1.32 {d28[0]}, [r6]
+
+13:
+ VPOP {d8-d15}
+ POP {r4, r5, r6, r7, r8, r9, r10, r11}
+ BX lr
END_FUNCTION xnn_f32_gemm_ukernel_4x8__aarch32_neon_ld64