blob: f28ec100f3b9cfee6c3ea05894c285af6bc51c28 [file] [log] [blame]
Zhi An Ngc2e2da82022-01-25 16:51:58 -08001// Copyright 2022 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
Zhi An Ng5ebe6862022-02-02 09:54:27 -08006#include <cassert>
7#include <cstddef>
8
Zhi An Ngc2e2da82022-01-25 16:51:58 -08009#include <xnnpack/aarch64-assembler.h>
10#include <xnnpack/allocator.h>
11#include <xnnpack/gemm.h>
Zhi An Ng5ebe6862022-02-02 09:54:27 -080012#include <xnnpack/params.h>
Zhi An Ngc2e2da82022-01-25 16:51:58 -080013
14namespace xnnpack {
15namespace aarch64 {
16namespace {
17class Generator : public Assembler {
18 using Assembler::Assembler;
19 public:
Zhi An Ng5ebe6862022-02-02 09:54:27 -080020 void generate(bool prefetch, size_t nc, size_t kc, float min, float max);
Zhi An Ngc2e2da82022-01-25 16:51:58 -080021};
22
23// void xnn_f32_gemm_minmax_ukernel_6x8__aarch64_neonfma_prfm_cortex_a75(
24// size_t mr, x0
25// size_t nc, x1
26// size_t kc, x2 / x0
27// const uint8_t*restrict a, x3
28// size_t a_stride, x4
29// const void*restrict w, x5
30// uint8_t*restrict c, x6
31// size_t cm_stride, x7
32// size_t cn_stride, [sp] -> (x0)
33// const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) [sp + 8] -> x8
34
35// d8-d15, x19-x30 need to be preserved if used. x18 is reserved by the OS.
36
37// A pointers
38// x3 a0
39// x9 a1
40// x10 a2
41// x11 a3
42// x12 a4
43// x4 a5
44
45// C pointers
46// x6 c0
47// x16 c1
48// x17 c2
49// x14 c3
50// x13 c4
51// x7 c5
52
53// Vector register usage
54// A0 v0 v6
55// A1 v1 v7
56// A2 v2 v8
57// A3 v3 v9
58// A4 v4 v10
59// A5 v5 v11
60// B v12 v13 v14 v15
61// B v16 v17 v18 v19
62// C v20 v21
63// C v22 v23
64// C v24 v25
65// C v26 v27
66// C v28 v29
67// C v30 v31
68// Clamp v6 v7
69
70// Converted from: src/f32-gemm/gen/6x8-minmax-aarch64-neonfma-prfm-cortex-a75.S
Zhi An Ng5ebe6862022-02-02 09:54:27 -080071void Generator::generate(bool prefetch, size_t nc, size_t kc, float min, float max) {
Zhi An Ngc2e2da82022-01-25 16:51:58 -080072 Label l0, l1, l2, l3, l4, l5, l6, l7, l8, l9, l10;
Zhi An Ngc92034d2022-02-03 16:14:50 -080073 const bool clamp_min = min != -std::numeric_limits<float>::infinity();
74 const bool clamp_max = max != +std::numeric_limits<float>::infinity();
Zhi An Ngc2e2da82022-01-25 16:51:58 -080075
76 // Load params pointer
77 ldr(x8, mem[sp, 8]);
78
79 // Clamp A and C pointers / Save d8-d15 on stack
80 stp(d8, d9, mem[sp, -64]++);
81 cmp(x0, 2); // if mr < 2
82 add(x9, x3, x4); // a1 = a0 + a_stride
83 add(x16, x6, x7); // c1 = c0 + cm_stride
84 csel(x9, x3, x9, kLO); // a1 = a0
85 csel(x16, x6, x16, kLO); // c1 = c0
86
87 stp(d10, d11, mem[sp, 16]);
88 add(x10, x9, x4); // a2 = a1 + a_stride
89 add(x17, x16, x7); // c2 = c1 + cm_stride
90 // if mr <= 2
91 csel(x10, x9, x10, kLS); // a2 = a1
92 csel(x17, x16, x17, kLS); // c2 = c1
93
94 stp(d12, d13, mem[sp, 32]);
95 cmp(x0, 4); // if mr < 4
96 add(x11, x10, x4); // a3 = a2 + a_stride
97 add(x14, x17, x7); // c3 = c2 + cm_stride
98 csel(x11, x10, x11, kLO); // a3 = a2
99 csel(x14, x17, x14, kLO); // c3 = c2
100
101 stp(d14, d15, mem[sp, 48]);
102 add(x12, x11, x4); // a4 = a3 + a_stride
103 add(x13, x14, x7); // c4 = c3 + cm_stride
104 // if mr <= 4
105 csel(x12, x11, x12, kLS); // a4 = a3
106 csel(x13, x14, x13, kLS); // c4 = c3
107
108 cmp(x0, 6); // if mr < 6
109 add(x4, x12, x4); // a5 = a4 + a_stride
110 add(x7, x13, x7); // c5 = c4 + cm_stride
111 csel(x4, x12, x4, kLO); // a5 = a4
112 csel(x7, x13, x7, kLO); // c5 = c4
113
114 bind(l0);
115 // Load initial bias from w into accumulators
116 ldp(q20, q21, mem[x5], 32);
117 mov(v22.v16b(), v20.v16b());
118 if (prefetch) {
119 prfm(kPLDL1KEEP, mem[x5, 0]); // Prefetch B
120 }
121 mov(v23.v16b(), v21.v16b());
122 if (prefetch) {
123 prfm(kPLDL1KEEP, mem[x5, 64]);
124 }
125 mov(v24.v16b(), v20.v16b());
126 if (prefetch) {
127 prfm(kPLDL1KEEP, mem[x5, 128]);
128 }
129 mov(v25.v16b(), v21.v16b());
130 if (prefetch) {
131 prfm(kPLDL1KEEP, mem[x5, 192]);
132 }
133 mov(v26.v16b(), v20.v16b());
134 if (prefetch) {
135 prfm(kPLDL1KEEP, mem[x3]); // Prefetch A
136 }
137 mov(v27.v16b(), v21.v16b());
138 if (prefetch) {
139 prfm(kPLDL1KEEP, mem[x9]);
140 }
141 mov(v28.v16b(), v20.v16b());
142 if (prefetch) {
143 prfm(kPLDL1KEEP, mem[x10]);
144 }
145 mov(v29.v16b(), v21.v16b());
146 if (prefetch) {
147 prfm(kPLDL1KEEP, mem[x11]);
148 }
149 mov(v30.v16b(), v20.v16b());
150 if (prefetch) {
151 prfm(kPLDL1KEEP, mem[x12]);
152 }
153 mov(v31.v16b(), v21.v16b());
154 if (prefetch) {
155 prfm(kPLDL1KEEP, mem[x4]);
156 }
157
158 // Is there at least 8 floats (32 bytes) for prologue + epilogue?
159 subs(x0, x2, 32); // k = kc - 32
160 b_lo(l4);
161
162 // Prologue - loads for main loop of 96 FMA
163 ldr(q0, mem[x3], 16);
164 ldr(q1, mem[x9], 16);
165 ldr(q2, mem[x10], 16);
166 ldr(q3, mem[x11], 16);
167 ldr(q4, mem[x12], 16);
168 ldr(q5, mem[x4], 16);
169 ldp(q12, q13, mem[x5], 32); // Fetch 3 B (4th deferred)
170 ldp(q14, q15, mem[x5], 32);
171 ldp(q16, q17, mem[x5], 32);
172
173 // Is there at least 8 floats (32 bytes) for main loop?
174 subs(x0, x0, 32);
175 b_lo(l2);
176
177 // Main loop - 8 floats of A (32 bytes)
178 // 96 FMA + 6 LDP A + 8 LDP B
179 bind(l1);
180 // First group of 4 A. 48 FMA.
181 fmla(v20.v4s(), v12.v4s(), v0.s()[0]);
182 ldp(q18, q19, mem[x5], 32); // Load last B
183 fmla(v22.v4s(), v12.v4s(), v1.s()[0]);
184 fmla(v24.v4s(), v12.v4s(), v2.s()[0]);
185 fmla(v26.v4s(), v12.v4s(), v3.s()[0]);
186 fmla(v28.v4s(), v12.v4s(), v4.s()[0]);
187 fmla(v30.v4s(), v12.v4s(), v5.s()[0]);
188 fmla(v21.v4s(), v13.v4s(), v0.s()[0]);
189 fmla(v23.v4s(), v13.v4s(), v1.s()[0]);
190 fmla(v25.v4s(), v13.v4s(), v2.s()[0]);
191 fmla(v27.v4s(), v13.v4s(), v3.s()[0]);
192 fmla(v29.v4s(), v13.v4s(), v4.s()[0]);
193
194 fmla(v31.v4s(), v13.v4s(), v5.s()[0]);
195 fmla(v20.v4s(), v14.v4s(), v0.s()[1]);
196 if (prefetch) {
197 prfm(kPLDL1KEEP, mem[x5, 128]); // Prefetch B
198 }
199 fmla(v22.v4s(), v14.v4s(), v1.s()[1]);
200 fmla(v24.v4s(), v14.v4s(), v2.s()[1]);
201 fmla(v26.v4s(), v14.v4s(), v3.s()[1]);
202 fmla(v28.v4s(), v14.v4s(), v4.s()[1]);
203 if (prefetch) {
204 prfm(kPLDL1KEEP, mem[x5, 256]);
205 }
206 fmla(v30.v4s(), v14.v4s(), v5.s()[1]);
207 fmla(v21.v4s(), v15.v4s(), v0.s()[1]);
208 fmla(v23.v4s(), v15.v4s(), v1.s()[1]);
209 fmla(v25.v4s(), v15.v4s(), v2.s()[1]);
210 ldr(q6, mem[x3], 16); // Load next 6 A
211 fmla(v27.v4s(), v15.v4s(), v3.s()[1]);
212 fmla(v29.v4s(), v15.v4s(), v4.s()[1]);
213 fmla(v31.v4s(), v15.v4s(), v5.s()[1]);
214 ldr(q7, mem[x9], 16);
215
216 fmla(v20.v4s(), v16.v4s(), v0.s()[2]);
217 fmla(v22.v4s(), v16.v4s(), v1.s()[2]);
218 fmla(v24.v4s(), v16.v4s(), v2.s()[2]);
219 ldr(q8, mem[x10], 16);
220 fmla(v26.v4s(), v16.v4s(), v3.s()[2]);
221 fmla(v28.v4s(), v16.v4s(), v4.s()[2]);
222 fmla(v30.v4s(), v16.v4s(), v5.s()[2]);
223 ldr(q9, mem[x11], 16);
224 fmla(v21.v4s(), v17.v4s(), v0.s()[2]);
225 fmla(v23.v4s(), v17.v4s(), v1.s()[2]);
226 fmla(v25.v4s(), v17.v4s(), v2.s()[2]);
227 ldr(q10, mem[x12], 16);
228 fmla(v27.v4s(), v17.v4s(), v3.s()[2]);
229 fmla(v29.v4s(), v17.v4s(), v4.s()[2]);
230 fmla(v31.v4s(), v17.v4s(), v5.s()[2]);
231 ldr(q11, mem[x4], 16);
232
233 fmla(v20.v4s(), v18.v4s(), v0.s()[3]);
234 fmla(v22.v4s(), v18.v4s(), v1.s()[3]);
235 fmla(v24.v4s(), v18.v4s(), v2.s()[3]);
236 ldp(q12, q13, mem[x5], 32); // Load 4 B
237 fmla(v26.v4s(), v18.v4s(), v3.s()[3]);
238 fmla(v28.v4s(), v18.v4s(), v4.s()[3]);
239 fmla(v30.v4s(), v18.v4s(), v5.s()[3]);
240 ldp(q14, q15, mem[x5], 32);
241 fmla(v21.v4s(), v19.v4s(), v0.s()[3]);
242 fmla(v23.v4s(), v19.v4s(), v1.s()[3]);
243 fmla(v25.v4s(), v19.v4s(), v2.s()[3]);
244 ldp(q16, q17, mem[x5], 32);
245 fmla(v27.v4s(), v19.v4s(), v3.s()[3]);
246 fmla(v29.v4s(), v19.v4s(), v4.s()[3]);
247 fmla(v31.v4s(), v19.v4s(), v5.s()[3]);
248 ldp(q18, q19, mem[x5], 32);
249
250 // Second group of 4 A. 48 FMA.
251 fmla(v20.v4s(), v12.v4s(), v6.s()[0]);
252 fmla(v22.v4s(), v12.v4s(), v7.s()[0]);
253 fmla(v24.v4s(), v12.v4s(), v8.s()[0]);
254 ldr(q0, mem[x3], 16); // Load next 6 A
255 fmla(v26.v4s(), v12.v4s(), v9.s()[0]);
256 fmla(v28.v4s(), v12.v4s(), v10.s()[0]);
257 fmla(v30.v4s(), v12.v4s(), v11.s()[0]);
258 ldr(q1, mem[x9], 16);
259 fmla(v21.v4s(), v13.v4s(), v6.s()[0]);
260 fmla(v23.v4s(), v13.v4s(), v7.s()[0]);
261 fmla(v25.v4s(), v13.v4s(), v8.s()[0]);
262 ldr(q2, mem[x10], 16);
263 fmla(v27.v4s(), v13.v4s(), v9.s()[0]);
264 fmla(v29.v4s(), v13.v4s(), v10.s()[0]);
265 fmla(v31.v4s(), v13.v4s(), v11.s()[0]);
266 ldr(q3, mem[x11], 16);
267
268 fmla(v20.v4s(), v14.v4s(), v6.s()[1]);
269 fmla(v22.v4s(), v14.v4s(), v7.s()[1]);
270 fmla(v24.v4s(), v14.v4s(), v8.s()[1]);
271 ldr(q4, mem[x12], 16);
272 fmla(v26.v4s(), v14.v4s(), v9.s()[1]);
273 fmla(v28.v4s(), v14.v4s(), v10.s()[1]);
274 fmla(v30.v4s(), v14.v4s(), v11.s()[1]);
275 ldr(q5, mem[x4], 16);
276 fmla(v21.v4s(), v15.v4s(), v6.s()[1]);
277 fmla(v23.v4s(), v15.v4s(), v7.s()[1]);
278 fmla(v25.v4s(), v15.v4s(), v8.s()[1]);
279 ldp(q12, q13, mem[x5], 32); // Load next 3 B (not last)
280 fmla(v27.v4s(), v15.v4s(), v9.s()[1]);
281 fmla(v29.v4s(), v15.v4s(), v10.s()[1]);
282 fmla(v31.v4s(), v15.v4s(), v11.s()[1]);
283 ldp(q14, q15, mem[x5], 32);
284
285 fmla(v20.v4s(), v16.v4s(), v6.s()[2]);
286 fmla(v22.v4s(), v16.v4s(), v7.s()[2]);
287 fmla(v24.v4s(), v16.v4s(), v8.s()[2]);
288 fmla(v26.v4s(), v16.v4s(), v9.s()[2]);
289 fmla(v28.v4s(), v16.v4s(), v10.s()[2]);
290 fmla(v30.v4s(), v16.v4s(), v11.s()[2]);
291 fmla(v21.v4s(), v17.v4s(), v6.s()[2]);
292 fmla(v23.v4s(), v17.v4s(), v7.s()[2]);
293 fmla(v25.v4s(), v17.v4s(), v8.s()[2]);
294 fmla(v27.v4s(), v17.v4s(), v9.s()[2]);
295 fmla(v29.v4s(), v17.v4s(), v10.s()[2]);
296 fmla(v31.v4s(), v17.v4s(), v11.s()[2]);
297 ldp(q16, q17, mem[x5], 32);
298
299 fmla(v20.v4s(), v18.v4s(), v6.s()[3]);
300 fmla(v22.v4s(), v18.v4s(), v7.s()[3]);
301 subs(x0, x0, 32);
302 fmla(v24.v4s(), v18.v4s(), v8.s()[3]);
303 fmla(v26.v4s(), v18.v4s(), v9.s()[3]);
304 fmla(v28.v4s(), v18.v4s(), v10.s()[3]);
305 fmla(v30.v4s(), v18.v4s(), v11.s()[3]);
306 fmla(v21.v4s(), v19.v4s(), v6.s()[3]);
307 fmla(v23.v4s(), v19.v4s(), v7.s()[3]);
308 fmla(v25.v4s(), v19.v4s(), v8.s()[3]);
309 fmla(v27.v4s(), v19.v4s(), v9.s()[3]);
310 fmla(v29.v4s(), v19.v4s(), v10.s()[3]);
311 fmla(v31.v4s(), v19.v4s(), v11.s()[3]);
312 b_hs(l1);
313
314 // Epilogue - 8 floats of A (32 bytes)
315 // 96 FMA + 6 LDP A + 8 LDP B
316 // First block same as main loop. Second block has no preloads.
317 bind(l2);
318 // First group of 4 A. 48 FMA.
319 fmla(v20.v4s(), v12.v4s(), v0.s()[0]);
320 ldp(q18, q19, mem[x5], 32); // Load last B
321 fmla(v22.v4s(), v12.v4s(), v1.s()[0]);
322 fmla(v24.v4s(), v12.v4s(), v2.s()[0]);
323 fmla(v26.v4s(), v12.v4s(), v3.s()[0]);
324 fmla(v28.v4s(), v12.v4s(), v4.s()[0]);
325 fmla(v30.v4s(), v12.v4s(), v5.s()[0]);
326 fmla(v21.v4s(), v13.v4s(), v0.s()[0]);
327 fmla(v23.v4s(), v13.v4s(), v1.s()[0]);
328 fmla(v25.v4s(), v13.v4s(), v2.s()[0]);
329 fmla(v27.v4s(), v13.v4s(), v3.s()[0]);
330 fmla(v29.v4s(), v13.v4s(), v4.s()[0]);
331
332 fmla(v31.v4s(), v13.v4s(), v5.s()[0]);
333 fmla(v20.v4s(), v14.v4s(), v0.s()[1]);
334 if (prefetch) {
335 prfm(kPLDL1KEEP, mem[x5, 128]); // Prefetch B
336 }
337 fmla(v22.v4s(), v14.v4s(), v1.s()[1]);
338 fmla(v24.v4s(), v14.v4s(), v2.s()[1]);
339 fmla(v26.v4s(), v14.v4s(), v3.s()[1]);
340 fmla(v28.v4s(), v14.v4s(), v4.s()[1]);
341 if (prefetch) {
342 prfm(kPLDL1KEEP, mem[x5, 256]);
343 }
344 fmla(v30.v4s(), v14.v4s(), v5.s()[1]);
345 fmla(v21.v4s(), v15.v4s(), v0.s()[1]);
346 fmla(v23.v4s(), v15.v4s(), v1.s()[1]);
347 fmla(v25.v4s(), v15.v4s(), v2.s()[1]);
348 ldr(q6, mem[x3], 16); // Load next 6 A
349 fmla(v27.v4s(), v15.v4s(), v3.s()[1]);
350 fmla(v29.v4s(), v15.v4s(), v4.s()[1]);
351 fmla(v31.v4s(), v15.v4s(), v5.s()[1]);
352 ldr(q7, mem[x9], 16);
353
354 fmla(v20.v4s(), v16.v4s(), v0.s()[2]);
355 fmla(v22.v4s(), v16.v4s(), v1.s()[2]);
356 fmla(v24.v4s(), v16.v4s(), v2.s()[2]);
357 ldr(q8, mem[x10], 16);
358 fmla(v26.v4s(), v16.v4s(), v3.s()[2]);
359 fmla(v28.v4s(), v16.v4s(), v4.s()[2]);
360 fmla(v30.v4s(), v16.v4s(), v5.s()[2]);
361 ldr(q9, mem[x11], 16);
362 fmla(v21.v4s(), v17.v4s(), v0.s()[2]);
363 fmla(v23.v4s(), v17.v4s(), v1.s()[2]);
364 fmla(v25.v4s(), v17.v4s(), v2.s()[2]);
365 ldr(q10, mem[x12], 16);
366 fmla(v27.v4s(), v17.v4s(), v3.s()[2]);
367 fmla(v29.v4s(), v17.v4s(), v4.s()[2]);
368 fmla(v31.v4s(), v17.v4s(), v5.s()[2]);
369 ldr(q11, mem[x4], 16);
370
371 fmla(v20.v4s(), v18.v4s(), v0.s()[3]);
372 fmla(v22.v4s(), v18.v4s(), v1.s()[3]);
373 fmla(v24.v4s(), v18.v4s(), v2.s()[3]);
374 ldp(q12, q13, mem[x5], 32); // Load 4 B
375 fmla(v26.v4s(), v18.v4s(), v3.s()[3]);
376 fmla(v28.v4s(), v18.v4s(), v4.s()[3]);
377 fmla(v30.v4s(), v18.v4s(), v5.s()[3]);
378 ldp(q14, q15, mem[x5], 32);
379 fmla(v21.v4s(), v19.v4s(), v0.s()[3]);
380 fmla(v23.v4s(), v19.v4s(), v1.s()[3]);
381 fmla(v25.v4s(), v19.v4s(), v2.s()[3]);
382 ldp(q16, q17, mem[x5], 32);
383 fmla(v27.v4s(), v19.v4s(), v3.s()[3]);
384 fmla(v29.v4s(), v19.v4s(), v4.s()[3]);
385 fmla(v31.v4s(), v19.v4s(), v5.s()[3]);
386 ldp(q18, q19, mem[x5], 32);
387
388 // Second group of 4 A. 48 FMA.
389 fmla(v20.v4s(), v12.v4s(), v6.s()[0]);
390 fmla(v22.v4s(), v12.v4s(), v7.s()[0]);
391 fmla(v24.v4s(), v12.v4s(), v8.s()[0]);
392 fmla(v26.v4s(), v12.v4s(), v9.s()[0]);
393 fmla(v28.v4s(), v12.v4s(), v10.s()[0]);
394 fmla(v30.v4s(), v12.v4s(), v11.s()[0]);
395 fmla(v21.v4s(), v13.v4s(), v6.s()[0]);
396 fmla(v23.v4s(), v13.v4s(), v7.s()[0]);
397 fmla(v25.v4s(), v13.v4s(), v8.s()[0]);
398 fmla(v27.v4s(), v13.v4s(), v9.s()[0]);
399 fmla(v29.v4s(), v13.v4s(), v10.s()[0]);
400 fmla(v31.v4s(), v13.v4s(), v11.s()[0]);
401
402 fmla(v20.v4s(), v14.v4s(), v6.s()[1]);
403 fmla(v22.v4s(), v14.v4s(), v7.s()[1]);
404 fmla(v24.v4s(), v14.v4s(), v8.s()[1]);
405 fmla(v26.v4s(), v14.v4s(), v9.s()[1]);
406 fmla(v28.v4s(), v14.v4s(), v10.s()[1]);
407 fmla(v30.v4s(), v14.v4s(), v11.s()[1]);
408 fmla(v21.v4s(), v15.v4s(), v6.s()[1]);
409 fmla(v23.v4s(), v15.v4s(), v7.s()[1]);
410 fmla(v25.v4s(), v15.v4s(), v8.s()[1]);
411 fmla(v27.v4s(), v15.v4s(), v9.s()[1]);
412 fmla(v29.v4s(), v15.v4s(), v10.s()[1]);
413 fmla(v31.v4s(), v15.v4s(), v11.s()[1]);
414
415 fmla(v20.v4s(), v16.v4s(), v6.s()[2]);
416 fmla(v22.v4s(), v16.v4s(), v7.s()[2]);
417 fmla(v24.v4s(), v16.v4s(), v8.s()[2]);
418 fmla(v26.v4s(), v16.v4s(), v9.s()[2]);
419 fmla(v28.v4s(), v16.v4s(), v10.s()[2]);
420 fmla(v30.v4s(), v16.v4s(), v11.s()[2]);
421 fmla(v21.v4s(), v17.v4s(), v6.s()[2]);
422 fmla(v23.v4s(), v17.v4s(), v7.s()[2]);
423 fmla(v25.v4s(), v17.v4s(), v8.s()[2]);
424 fmla(v27.v4s(), v17.v4s(), v9.s()[2]);
425 fmla(v29.v4s(), v17.v4s(), v10.s()[2]);
426 fmla(v31.v4s(), v17.v4s(), v11.s()[2]);
427
428 fmla(v20.v4s(), v18.v4s(), v6.s()[3]);
429 fmla(v22.v4s(), v18.v4s(), v7.s()[3]);
430 fmla(v24.v4s(), v18.v4s(), v8.s()[3]);
431 fmla(v26.v4s(), v18.v4s(), v9.s()[3]);
432 fmla(v28.v4s(), v18.v4s(), v10.s()[3]);
433 fmla(v30.v4s(), v18.v4s(), v11.s()[3]);
434 fmla(v21.v4s(), v19.v4s(), v6.s()[3]);
435 fmla(v23.v4s(), v19.v4s(), v7.s()[3]);
436
437 // Load min/max values
Zhi An Ngc92034d2022-02-03 16:14:50 -0800438 if (clamp_min || clamp_max) {
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800439 ld2r({v6.v4s(), v7.v4s()}, mem[x8]);
440 }
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800441
442 fmla(v25.v4s(), v19.v4s(), v8.s()[3]);
443 fmla(v27.v4s(), v19.v4s(), v9.s()[3]);
444 // Is there a remainder?- 4 floats of A (16 bytes) or less
445 tst(x0, 31);
446 fmla(v29.v4s(), v19.v4s(), v10.s()[3]);
447 fmla(v31.v4s(), v19.v4s(), v11.s()[3]);
448 b_ne(l4);
449
450 // Clamp
451 bind(l3);
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800452 // Load cn_stride
453 ldr(x0, mem[sp, 64]);
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800454 subs(x1, x1, 8);
Zhi An Ngc92034d2022-02-03 16:14:50 -0800455 if (clamp_min) {
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800456 fmax(v20.v4s(), v20.v4s(), v6.v4s());
457 fmax(v21.v4s(), v21.v4s(), v6.v4s());
458 fmax(v22.v4s(), v22.v4s(), v6.v4s());
459 fmax(v23.v4s(), v23.v4s(), v6.v4s());
460 fmax(v24.v4s(), v24.v4s(), v6.v4s());
461 fmax(v25.v4s(), v25.v4s(), v6.v4s());
462 fmax(v26.v4s(), v26.v4s(), v6.v4s());
463 fmax(v27.v4s(), v27.v4s(), v6.v4s());
464 fmax(v28.v4s(), v28.v4s(), v6.v4s());
465 fmax(v29.v4s(), v29.v4s(), v6.v4s());
466 fmax(v30.v4s(), v30.v4s(), v6.v4s());
467 fmax(v31.v4s(), v31.v4s(), v6.v4s());
468 }
Zhi An Ngc92034d2022-02-03 16:14:50 -0800469 if (clamp_max) {
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800470 fmin(v20.v4s(), v20.v4s(), v7.v4s());
471 fmin(v21.v4s(), v21.v4s(), v7.v4s());
472 fmin(v22.v4s(), v22.v4s(), v7.v4s());
473 fmin(v23.v4s(), v23.v4s(), v7.v4s());
474 fmin(v24.v4s(), v24.v4s(), v7.v4s());
475 fmin(v25.v4s(), v25.v4s(), v7.v4s());
476 fmin(v26.v4s(), v26.v4s(), v7.v4s());
477 fmin(v27.v4s(), v27.v4s(), v7.v4s());
478 fmin(v28.v4s(), v28.v4s(), v7.v4s());
479 fmin(v29.v4s(), v29.v4s(), v7.v4s());
480 fmin(v30.v4s(), v30.v4s(), v7.v4s());
481 fmin(v31.v4s(), v31.v4s(), v7.v4s());
482 }
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800483
484 // Store full 6 x 8
485 b_lo(l7);
486
487 stp(q20, q21, mem[x6]);
488 add(x6, x6, x0);
489 sub(x3, x3, x2); // a0 -= kc
490 stp(q22, q23, mem[x16]);
491 add(x16, x16, x0);
492 sub(x9, x9, x2); // a1 -= kc
493 stp(q24, q25, mem[x17]);
494 add(x17, x17, x0);
495 sub(x10, x10, x2); // a2 -= kc
496 stp(q26, q27, mem[x14]);
497 add(x14, x14, x0);
498 sub(x11, x11, x2); // a3 -= kc
499 stp(q28, q29, mem[x13]);
500 add(x13, x13, x0);
501 sub(x12, x12, x2); // a4 -= kc
502 stp(q30, q31, mem[x7]);
503 add(x7, x7, x0);
504 sub(x4, x4, x2); // a5 -= kc
505
506 b_hi(l0);
507
508 // Restore d8-d15 from stack
509 ldp(d14, d15, mem[sp, 48]);
510 ldp(d12, d13, mem[sp, 32]);
511 ldp(d10, d11, mem[sp, 16]);
512 ldp(d8, d9, mem[sp], 64);
513 ret();
514
515 bind(l4);
516 // Load min/max values
517 ld2r({v6.v4s(), v7.v4s()}, mem[x8]);
518
519 // Is there a remainder?- 4 floats of A (16 bytes)
520 tbz(x0, 4, l5);
521
522 // Remainder- 4 floats of A (16 bytes)
523 // Load A
524 ldr(q0, mem[x3], 16);
525 ldr(q1, mem[x9], 16);
526 ldr(q2, mem[x10], 16);
527 ldr(q3, mem[x11], 16);
528 ldr(q4, mem[x12], 16);
529 ldr(q5, mem[x4], 16);
530 // Load B
531 ldp(q12, q13, mem[x5], 32);
532 ldp(q14, q15, mem[x5], 32);
533 ldp(q16, q17, mem[x5], 32);
534 ldp(q18, q19, mem[x5], 32);
535
536 fmla(v20.v4s(), v12.v4s(), v0.s()[0]);
537 fmla(v22.v4s(), v12.v4s(), v1.s()[0]);
538 fmla(v24.v4s(), v12.v4s(), v2.s()[0]);
539 fmla(v26.v4s(), v12.v4s(), v3.s()[0]);
540 fmla(v28.v4s(), v12.v4s(), v4.s()[0]);
541 fmla(v30.v4s(), v12.v4s(), v5.s()[0]);
542 fmla(v21.v4s(), v13.v4s(), v0.s()[0]);
543 fmla(v23.v4s(), v13.v4s(), v1.s()[0]);
544 fmla(v25.v4s(), v13.v4s(), v2.s()[0]);
545 fmla(v27.v4s(), v13.v4s(), v3.s()[0]);
546 fmla(v29.v4s(), v13.v4s(), v4.s()[0]);
547 fmla(v31.v4s(), v13.v4s(), v5.s()[0]);
548
549 fmla(v20.v4s(), v14.v4s(), v0.s()[1]);
550 fmla(v22.v4s(), v14.v4s(), v1.s()[1]);
551 fmla(v24.v4s(), v14.v4s(), v2.s()[1]);
552 fmla(v26.v4s(), v14.v4s(), v3.s()[1]);
553 fmla(v28.v4s(), v14.v4s(), v4.s()[1]);
554 fmla(v30.v4s(), v14.v4s(), v5.s()[1]);
555 fmla(v21.v4s(), v15.v4s(), v0.s()[1]);
556 fmla(v23.v4s(), v15.v4s(), v1.s()[1]);
557 fmla(v25.v4s(), v15.v4s(), v2.s()[1]);
558 fmla(v27.v4s(), v15.v4s(), v3.s()[1]);
559 fmla(v29.v4s(), v15.v4s(), v4.s()[1]);
560 fmla(v31.v4s(), v15.v4s(), v5.s()[1]);
561
562 fmla(v20.v4s(), v16.v4s(), v0.s()[2]);
563 fmla(v22.v4s(), v16.v4s(), v1.s()[2]);
564 fmla(v24.v4s(), v16.v4s(), v2.s()[2]);
565 fmla(v26.v4s(), v16.v4s(), v3.s()[2]);
566 fmla(v28.v4s(), v16.v4s(), v4.s()[2]);
567 fmla(v30.v4s(), v16.v4s(), v5.s()[2]);
568 fmla(v21.v4s(), v17.v4s(), v0.s()[2]);
569 fmla(v23.v4s(), v17.v4s(), v1.s()[2]);
570 fmla(v25.v4s(), v17.v4s(), v2.s()[2]);
571 fmla(v27.v4s(), v17.v4s(), v3.s()[2]);
572 fmla(v29.v4s(), v17.v4s(), v4.s()[2]);
573 fmla(v31.v4s(), v17.v4s(), v5.s()[2]);
574
575 fmla(v20.v4s(), v18.v4s(), v0.s()[3]);
576 fmla(v22.v4s(), v18.v4s(), v1.s()[3]);
577 fmla(v24.v4s(), v18.v4s(), v2.s()[3]);
578 fmla(v26.v4s(), v18.v4s(), v3.s()[3]);
579 fmla(v28.v4s(), v18.v4s(), v4.s()[3]);
580 fmla(v30.v4s(), v18.v4s(), v5.s()[3]);
581 fmla(v21.v4s(), v19.v4s(), v0.s()[3]);
582 fmla(v23.v4s(), v19.v4s(), v1.s()[3]);
583 fmla(v25.v4s(), v19.v4s(), v2.s()[3]);
584 fmla(v27.v4s(), v19.v4s(), v3.s()[3]);
585 fmla(v29.v4s(), v19.v4s(), v4.s()[3]);
586 fmla(v31.v4s(), v19.v4s(), v5.s()[3]);
587
588 // Is there a remainder?- 2 floats of A (8 bytes)
589 bind(l5);
590 tbz(x0, 3, l6);
591
592 // Remainder- 2 floats of A (8 bytes)
593 // Load A
594 ldr(d0, mem[x3], 8);
595 ldr(d1, mem[x9], 8);
596 ldr(d2, mem[x10], 8);
597 ldr(d3, mem[x11], 8);
598 ldr(d4, mem[x12], 8);
599 ldr(d5, mem[x4], 8);
600 // Load B
601 ldp(q12, q13, mem[x5], 32);
602 ldp(q14, q15, mem[x5], 32);
603
604 fmla(v20.v4s(), v12.v4s(), v0.s()[0]);
605 fmla(v22.v4s(), v12.v4s(), v1.s()[0]);
606 fmla(v24.v4s(), v12.v4s(), v2.s()[0]);
607 fmla(v26.v4s(), v12.v4s(), v3.s()[0]);
608 fmla(v28.v4s(), v12.v4s(), v4.s()[0]);
609 fmla(v30.v4s(), v12.v4s(), v5.s()[0]);
610 fmla(v21.v4s(), v13.v4s(), v0.s()[0]);
611 fmla(v23.v4s(), v13.v4s(), v1.s()[0]);
612 fmla(v25.v4s(), v13.v4s(), v2.s()[0]);
613 fmla(v27.v4s(), v13.v4s(), v3.s()[0]);
614 fmla(v29.v4s(), v13.v4s(), v4.s()[0]);
615 fmla(v31.v4s(), v13.v4s(), v5.s()[0]);
616
617 fmla(v20.v4s(), v14.v4s(), v0.s()[1]);
618 fmla(v22.v4s(), v14.v4s(), v1.s()[1]);
619 fmla(v24.v4s(), v14.v4s(), v2.s()[1]);
620 fmla(v26.v4s(), v14.v4s(), v3.s()[1]);
621 fmla(v28.v4s(), v14.v4s(), v4.s()[1]);
622 fmla(v30.v4s(), v14.v4s(), v5.s()[1]);
623 fmla(v21.v4s(), v15.v4s(), v0.s()[1]);
624 fmla(v23.v4s(), v15.v4s(), v1.s()[1]);
625 fmla(v25.v4s(), v15.v4s(), v2.s()[1]);
626 fmla(v27.v4s(), v15.v4s(), v3.s()[1]);
627 fmla(v29.v4s(), v15.v4s(), v4.s()[1]);
628 fmla(v31.v4s(), v15.v4s(), v5.s()[1]);
629
630 // Is there a remainder?- 1 float of A (4 bytes)
631 bind(l6);
632 tbz(x0, 2, l3);
633
634 // Remainder- 1 float of A (4 bytes)
635 // Load A
636 ldr(s0, mem[x3], 4);
637 ldr(s1, mem[x9], 4);
638 ldr(s2, mem[x10], 4);
639 ldr(s3, mem[x11], 4);
640 ldr(s4, mem[x12], 4);
641 ldr(s5, mem[x4], 4);
642 // Load B
643 ldp(q12, q13, mem[x5], 32);
644
645 fmla(v20.v4s(), v12.v4s(), v0.s()[0]);
646 fmla(v22.v4s(), v12.v4s(), v1.s()[0]);
647 fmla(v24.v4s(), v12.v4s(), v2.s()[0]);
648 fmla(v26.v4s(), v12.v4s(), v3.s()[0]);
649 fmla(v28.v4s(), v12.v4s(), v4.s()[0]);
650 fmla(v30.v4s(), v12.v4s(), v5.s()[0]);
651 fmla(v21.v4s(), v13.v4s(), v0.s()[0]);
652 fmla(v23.v4s(), v13.v4s(), v1.s()[0]);
653 fmla(v25.v4s(), v13.v4s(), v2.s()[0]);
654 fmla(v27.v4s(), v13.v4s(), v3.s()[0]);
655 fmla(v29.v4s(), v13.v4s(), v4.s()[0]);
656 fmla(v31.v4s(), v13.v4s(), v5.s()[0]);
657 b(l3);
658
659 // Store odd width
660 bind(l7);
661 tbz(x1, 2, l8);
662 str(q20, mem[x6], 16);
663 mov(v20.v16b(), v21.v16b());
664 str(q22, mem[x16], 16);
665 mov(v22.v16b(), v23.v16b());
666 str(q24, mem[x17], 16);
667 mov(v24.v16b(), v25.v16b());
668 str(q26, mem[x14], 16);
669 mov(v26.v16b(), v27.v16b());
670 str(q28, mem[x13], 16);
671 mov(v28.v16b(), v29.v16b());
672 str(q30, mem[x7], 16);
673 mov(v30.v16b(), v31.v16b());
674 bind(l8);
675 tbz(x1, 1, l9);
676 str(d20, mem[x6], 8);
677 str(d22, mem[x16], 8);
678 dup(d20, v20.d()[1]);
679 dup(d22, v22.d()[1]);
680 str(d24, mem[x17], 8);
681 str(d26, mem[x14], 8);
682 dup(d24, v24.d()[1]);
683 dup(d26, v26.d()[1]);
684 str(d28, mem[x13], 8);
685 str(d30, mem[x7], 8);
686 dup(d28, v28.d()[1]);
687 dup(d30, v30.d()[1]);
688
689 bind(l9);
690 tbz(x1, 0, l10);
691 str(s20, mem[x6]);
692 str(s22, mem[x16]);
693 str(s24, mem[x17]);
694 str(s26, mem[x14]);
695 str(s28, mem[x13]);
696 str(s30, mem[x7]);
697 bind(l10);
698 // Restore d8-d15 from stack
699 ldp(d14, d15, mem[sp, 48]);
700 ldp(d12, d13, mem[sp, 32]);
701 ldp(d10, d11, mem[sp, 16]);
702 ldp(d8, d9, mem[sp], 64);
703 ret();
704
705
706}
707} // namespace
708} // aarch64
709} // xnnpack
710
Zhi An Ng3e3124e2022-02-02 12:46:34 -0800711xnn_status xnn_generate_f32_gemm_ukernel_6x8__aarch64_neonfma_cortex_a75(xnn_code_buffer* code, size_t nc, size_t kc, const void* params) {
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800712 using namespace xnnpack::aarch64;
713 Generator g(code);
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800714 assert(params != nullptr);
Zhi An Ng3e3124e2022-02-02 12:46:34 -0800715 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800716 g.generate(false, nc, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800717 g.finalize();
718 if (g.error() != xnnpack::Error::kNoError) {
719 return xnn_status_invalid_state;
720 }
721 return xnn_status_success;
722}
723
Zhi An Ng3e3124e2022-02-02 12:46:34 -0800724xnn_status xnn_generate_f32_gemm_ukernel_6x8__aarch64_neonfma_prfm_cortex_a75(xnn_code_buffer* code, size_t nc, size_t kc, const void* params) {
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800725 using namespace xnnpack::aarch64;
726 Generator g(code);
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800727 assert(params != nullptr);
Zhi An Ng3e3124e2022-02-02 12:46:34 -0800728 const jit_gemm_params* gemm_params = static_cast<const jit_gemm_params*>(params);
Zhi An Ng5ebe6862022-02-02 09:54:27 -0800729 g.generate(true, nc, kc, gemm_params->f32_minmax.min, gemm_params->f32_minmax.max);
Zhi An Ngc2e2da82022-01-25 16:51:58 -0800730 g.finalize();
731 if (g.error() != xnnpack::Error::kNoError) {
732 return xnn_status_invalid_state;
733 }
734 return xnn_status_success;
735}