blob: d6a9fbf8aa6048f0352cc96c0cae4bef6573fd22 [file] [log] [blame]
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <xnnpack/assembly.h>

# void xnn_qs8_igemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64(
#     size_t mr,                 x0
#     size_t nc,                 x1
#     size_t kc,                 x2 / x0
#     size_t ks,                 x3 / x9
#     const int8_t**restrict a,  x4
#     const int8_t* restrict w,  x5
#     int8_t* restrict c,        x6
#     size_t cm_stride,          x7
#     size_t cn_stride,          [sp] -> x10
#     size_t a_offset,           [sp + 8] -> x11
#     const int8_t* zero,        [sp + 16] -> x12
#     const union xnn_qs8_minmax_params params)  [sp + 24] -> x8

# d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.

# Register usage
# A0  x20  v0
# A1  x15  v1
# A2  x13  v2
# A3  x21  v3
# B    x5  v4  v5  v6  v7
# C0   x6 v16 v20 v24 v28
# C1  x16 v17 v21 v25 v29
# C2  x17 v18 v22 v26 v30
# C3   x7 v19 v23 v27 v31
# unused v8 v9 v10 v11 v12 v13 v14 v15
35
BEGIN_FUNCTION xnn_qs8_igemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64

        # Clamp C pointers: rows beyond mr alias the previous row so stores
        # for nonexistent rows are harmless duplicates.
        CMP x0, 2                // if mr < 2
        LDP x10, x11, [sp]       // Load cn_stride, a_offset
        ADD x16, x6, x7          // c1 = c0 + cm_stride
        CSEL x16, x6, x16, LO    //   c1 = c0
        ADD x2, x2, 3            // kc = (kc + 3) & ~3

        ADD x17, x16, x7         // c2 = c1 + cm_stride
        LDP x12, x8, [sp, 16]    // Load zero, params pointer
                                 // if mr <= 2
        CSEL x17, x16, x17, LS   //   c2 = c1
        BIC x2, x2, 3            // finish rounding kc up to a multiple of 4

        CMP x0, 4                // if mr < 4
        STP x20, x21, [sp, -16]! // Save callee-saved x20-x21 on stack
        ADD x7, x17, x7          // c3 = c2 + cm_stride
        CSEL x7, x17, x7, LO     //   c3 = c2

        .p2align 3
0:
        # Load initial bias from w into accumulators
        LDP q16, q20, [x5], 32
        MOV v17.16b, v16.16b
        MOV v18.16b, v16.16b
        LDP q24, q28, [x5], 32
        MOV v19.16b, v16.16b
        MOV v21.16b, v20.16b
        MOV v22.16b, v20.16b
        MOV v23.16b, v20.16b
        MOV v25.16b, v24.16b
        MOV v26.16b, v24.16b
        MOV v27.16b, v24.16b
        MOV v29.16b, v28.16b
        MOV v30.16b, v28.16b
        MOV v31.16b, v28.16b
        MOV x9, x3               // p = ks

        .p2align 3
1:
        # Load next 4 A pointers
        LDP x20, x15, [x4], 16
        LDP x13, x21, [x4], 16

        # A pointer equal to `zero` selects the zero buffer (no a_offset applied).
        CMP x20, x12             // if a0 == zero
        ADD x20, x20, x11        // a0 += a_offset
        CSEL x20, x12, x20, EQ   //   a0 = zero, else a0 += a_offset
        CMP x15, x12             // if a1 == zero
        ADD x15, x15, x11        // a1 += a_offset
        CSEL x15, x12, x15, EQ   //   a1 = zero, else a1 += a_offset
        CMP x13, x12             // if a2 == zero
        ADD x13, x13, x11        // a2 += a_offset
        CSEL x13, x12, x13, EQ   //   a2 = zero, else a2 += a_offset
        CMP x21, x12             // if a3 == zero
        ADD x21, x21, x11        // a3 += a_offset
        CSEL x21, x12, x21, EQ   //   a3 = zero, else a3 += a_offset

        # Is there at least 8 bytes for main loop?
        SUBS x0, x2, 8           // k = kc - 8
        B.LO 4f

        # Main loop - 8 bytes of A, 2 groups of 4 per row per iteration
        .p2align 3
2:
        LDR d0, [x20], 8
        LDR q4, [x5], 16
        LDR d1, [x15], 8
        LDR d2, [x13], 8
        LDR d3, [x21], 8
        LDR q5, [x5], 16
        SDOT v16.4s, v4.16b, v0.4b[0]
        SDOT v17.4s, v4.16b, v1.4b[0]
        LDP q6, q7, [x5], 32
        SDOT v18.4s, v4.16b, v2.4b[0]
        SDOT v19.4s, v4.16b, v3.4b[0]
        SDOT v20.4s, v5.16b, v0.4b[0]
        SDOT v21.4s, v5.16b, v1.4b[0]
        SDOT v22.4s, v5.16b, v2.4b[0]
        SDOT v23.4s, v5.16b, v3.4b[0]
        SDOT v24.4s, v6.16b, v0.4b[0]
        SDOT v25.4s, v6.16b, v1.4b[0]
        LDP q4, q5, [x5], 32
        SDOT v26.4s, v6.16b, v2.4b[0]
        SDOT v27.4s, v6.16b, v3.4b[0]
        SDOT v28.4s, v7.16b, v0.4b[0]
        SDOT v29.4s, v7.16b, v1.4b[0]
        SDOT v30.4s, v7.16b, v2.4b[0]
        SDOT v31.4s, v7.16b, v3.4b[0]
        SDOT v16.4s, v4.16b, v0.4b[1]
        SDOT v17.4s, v4.16b, v1.4b[1]
        LDP q6, q7, [x5], 32
        SDOT v18.4s, v4.16b, v2.4b[1]
        SDOT v19.4s, v4.16b, v3.4b[1]
        SDOT v20.4s, v5.16b, v0.4b[1]
        SDOT v21.4s, v5.16b, v1.4b[1]
        SDOT v22.4s, v5.16b, v2.4b[1]
        SDOT v23.4s, v5.16b, v3.4b[1]
        SDOT v24.4s, v6.16b, v0.4b[1]
        SDOT v25.4s, v6.16b, v1.4b[1]
        SDOT v26.4s, v6.16b, v2.4b[1]
        SDOT v27.4s, v6.16b, v3.4b[1]
        SDOT v28.4s, v7.16b, v0.4b[1]
        SDOT v29.4s, v7.16b, v1.4b[1]
        SDOT v30.4s, v7.16b, v2.4b[1]
        SUBS x0, x0, 8
        SDOT v31.4s, v7.16b, v3.4b[1]
        B.HS 2b

        # Is there a remainder?- 4 bytes of A
        TBNZ x0, 2, 4f

3:
        # ks loop
        SUBS x9, x9, 32          // ks -= MR * sizeof(int8_t*)
        B.HI 1b

        # Apply params - scale, shift, bias and clamp.
        # v0 = multiplier (per-tensor), v1 = left shift; shift == 0 selects
        # the SSRA rounding-correction path (CMEQ mask in v2).
        LD2R {v0.4s, v1.4s}, [x8], 8
        CMEQ v2.4s, v1.4s, 0

        # BIC keeps the pre-multiply sign bits only where shift == 0; SSRA by
        # 31 then adds the rounding correction for SQRDMULH's round-to-nearest.
        BIC v4.16b, v16.16b, v2.16b
        BIC v5.16b, v17.16b, v2.16b
        BIC v6.16b, v18.16b, v2.16b
        BIC v7.16b, v19.16b, v2.16b

        SQRDMULH v16.4s, v16.4s, v0.4s
        SQRDMULH v17.4s, v17.4s, v0.4s
        SQRDMULH v18.4s, v18.4s, v0.4s
        SQRDMULH v19.4s, v19.4s, v0.4s

        SSRA v16.4s, v4.4s, 31  // signed shift right accumulate
        SSRA v17.4s, v5.4s, 31
        SSRA v18.4s, v6.4s, 31
        SSRA v19.4s, v7.4s, 31

        BIC v4.16b, v20.16b, v2.16b
        BIC v5.16b, v21.16b, v2.16b
        BIC v6.16b, v22.16b, v2.16b
        BIC v7.16b, v23.16b, v2.16b

        SQRDMULH v20.4s, v20.4s, v0.4s
        SQRDMULH v21.4s, v21.4s, v0.4s
        SQRDMULH v22.4s, v22.4s, v0.4s
        SQRDMULH v23.4s, v23.4s, v0.4s

        SSRA v20.4s, v4.4s, 31
        SSRA v21.4s, v5.4s, 31
        SSRA v22.4s, v6.4s, 31
        SSRA v23.4s, v7.4s, 31

        BIC v4.16b, v24.16b, v2.16b
        BIC v5.16b, v25.16b, v2.16b
        BIC v6.16b, v26.16b, v2.16b
        BIC v7.16b, v27.16b, v2.16b

        SQRDMULH v24.4s, v24.4s, v0.4s
        SQRDMULH v25.4s, v25.4s, v0.4s
        SQRDMULH v26.4s, v26.4s, v0.4s
        SQRDMULH v27.4s, v27.4s, v0.4s

        SSRA v24.4s, v4.4s, 31
        SSRA v25.4s, v5.4s, 31
        SSRA v26.4s, v6.4s, 31
        SSRA v27.4s, v7.4s, 31

        BIC v4.16b, v28.16b, v2.16b
        BIC v5.16b, v29.16b, v2.16b
        BIC v6.16b, v30.16b, v2.16b
        BIC v7.16b, v31.16b, v2.16b

        SQRDMULH v28.4s, v28.4s, v0.4s
        SQRDMULH v29.4s, v29.4s, v0.4s
        SQRDMULH v30.4s, v30.4s, v0.4s
        SQRDMULH v31.4s, v31.4s, v0.4s

        SSRA v28.4s, v4.4s, 31
        SSRA v29.4s, v5.4s, 31
        SSRA v30.4s, v6.4s, 31
        SSRA v31.4s, v7.4s, 31

        SRSHL v16.4s, v16.4s, v1.4s  // signed rounding shift left
        SRSHL v17.4s, v17.4s, v1.4s
        SRSHL v18.4s, v18.4s, v1.4s
        SRSHL v19.4s, v19.4s, v1.4s
        SRSHL v20.4s, v20.4s, v1.4s
        SRSHL v21.4s, v21.4s, v1.4s
        SRSHL v22.4s, v22.4s, v1.4s
        SRSHL v23.4s, v23.4s, v1.4s
        SRSHL v24.4s, v24.4s, v1.4s
        SRSHL v25.4s, v25.4s, v1.4s
        SRSHL v26.4s, v26.4s, v1.4s
        SRSHL v27.4s, v27.4s, v1.4s
        SRSHL v28.4s, v28.4s, v1.4s
        SRSHL v29.4s, v29.4s, v1.4s
        SRSHL v30.4s, v30.4s, v1.4s
        SRSHL v31.4s, v31.4s, v1.4s

        SQXTN v16.4h, v16.4s
        SQXTN v17.4h, v17.4s
        SQXTN v18.4h, v18.4s
        SQXTN v19.4h, v19.4s
        SQXTN v24.4h, v24.4s
        SQXTN v25.4h, v25.4s
        SQXTN v26.4h, v26.4s
        SQXTN v27.4h, v27.4s
        LD1R {v2.8h}, [x8], 2    // load output zero-point bias (int16)

        SQXTN2 v16.8h, v20.4s
        SQXTN2 v17.8h, v21.4s
        SQXTN2 v18.8h, v22.4s
        SQXTN2 v19.8h, v23.4s
        SQXTN2 v24.8h, v28.4s
        SQXTN2 v25.8h, v29.4s
        SQXTN2 v26.8h, v30.4s
        SQXTN2 v27.8h, v31.4s

        SQADD v16.8h, v16.8h, v2.8h
        SQADD v17.8h, v17.8h, v2.8h
        SQADD v18.8h, v18.8h, v2.8h
        SQADD v19.8h, v19.8h, v2.8h
        SQADD v24.8h, v24.8h, v2.8h
        SQADD v25.8h, v25.8h, v2.8h
        SQADD v26.8h, v26.8h, v2.8h
        SQADD v27.8h, v27.8h, v2.8h
        LD1R {v0.16b}, [x8], 1   // clamp min value

        SQXTN v4.8b, v16.8h
        SQXTN v5.8b, v17.8h
        SQXTN v6.8b, v18.8h
        SQXTN v7.8b, v19.8h
        LD1R {v1.16b}, [x8]      // clamp max value
        SQXTN2 v4.16b, v24.8h
        SQXTN2 v5.16b, v25.8h
        SQXTN2 v6.16b, v26.8h
        SQXTN2 v7.16b, v27.8h
        SUB x8, x8, 11           // rewind params pointer (8 + 2 + 1 bytes consumed)

        SMAX v4.16b, v4.16b, v0.16b
        SMAX v5.16b, v5.16b, v0.16b
        SMAX v6.16b, v6.16b, v0.16b
        SMAX v7.16b, v7.16b, v0.16b
        SUBS x1, x1, 16          // nc -= 16; flags also drive B.LO / B.HI below
        SMIN v4.16b, v4.16b, v1.16b
        SMIN v5.16b, v5.16b, v1.16b
        SMIN v6.16b, v6.16b, v1.16b
        SMIN v7.16b, v7.16b, v1.16b
        B.LO 5f

        # Store full 4 x 16
        ST1 {v7.16b}, [x7], x10
        ST1 {v6.16b}, [x17], x10
        ST1 {v5.16b}, [x16], x10
        ST1 {v4.16b}, [x6], x10

        SUB x4, x4, x3           // a -= ks

        # nc loop
        B.HI 0b

        # Restore x20-x21 from stack
        LDP x20, x21, [sp], 16
        RET

        # Remainder- 4 bytes of A
        .p2align 3
4:
        LDR s0, [x20], 4
        LDR q4, [x5], 16
        LDR s1, [x15], 4
        LDR s2, [x13], 4
        LDR s3, [x21], 4
        LDR q5, [x5], 16
        SDOT v16.4s, v4.16b, v0.4b[0]
        SDOT v17.4s, v4.16b, v1.4b[0]
        LDP q6, q7, [x5], 32
        SDOT v18.4s, v4.16b, v2.4b[0]
        SDOT v19.4s, v4.16b, v3.4b[0]
        SDOT v20.4s, v5.16b, v0.4b[0]
        SDOT v21.4s, v5.16b, v1.4b[0]
        SDOT v22.4s, v5.16b, v2.4b[0]
        SDOT v23.4s, v5.16b, v3.4b[0]
        SDOT v24.4s, v6.16b, v0.4b[0]
        SDOT v25.4s, v6.16b, v1.4b[0]
        SDOT v26.4s, v6.16b, v2.4b[0]
        SDOT v27.4s, v6.16b, v3.4b[0]
        SDOT v28.4s, v7.16b, v0.4b[0]
        SDOT v29.4s, v7.16b, v1.4b[0]
        SDOT v30.4s, v7.16b, v2.4b[0]
        SDOT v31.4s, v7.16b, v3.4b[0]
        B 3b

        # Store odd width: peel off 8/4/2/1 columns per nc bit, shifting the
        # remaining lanes down after each partial store.
        .p2align 3
5:
        TBZ x1, 3, 6f
        STR d7, [x7], 8
        DUP d7, v7.d[1]
        STR d6, [x17], 8
        DUP d6, v6.d[1]
        STR d5, [x16], 8
        DUP d5, v5.d[1]
        STR d4, [x6], 8
        DUP d4, v4.d[1]
6:
        TBZ x1, 2, 7f
        STR s7, [x7], 4
        DUP s7, v7.s[1]
        STR s6, [x17], 4
        DUP s6, v6.s[1]
        STR s5, [x16], 4
        DUP s5, v5.s[1]
        STR s4, [x6], 4
        DUP s4, v4.s[1]
7:
        TBZ x1, 1, 8f
        ST1 {v7.h}[0], [x7], 2
        DUP h7, v7.h[1]
        ST1 {v6.h}[0], [x17], 2
        DUP h6, v6.h[1]
        ST1 {v5.h}[0], [x16], 2
        DUP h5, v5.h[1]
        ST1 {v4.h}[0], [x6], 2
        DUP h4, v4.h[1]
8:
        TBZ x1, 0, 9f
        ST1 {v7.b}[0], [x7]
        ST1 {v6.b}[0], [x17]
        ST1 {v5.b}[0], [x16]
        ST1 {v4.b}[0], [x6]
9:
        # Restore x20-x21 from stack
        LDP x20, x21, [sp], 16
        RET

END_FUNCTION xnn_qs8_igemm_minmax_ukernel_4x16c4__aarch64_neondot_ld64

// Mark the stack non-executable on ELF targets.
#ifdef __ELF__
.section ".note.GNU-stack","",%progbits
#endif