// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
$assert NR % 8 == 0
$assert 8 <= NR <= 16
#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/common.h>
#include <xnnpack/gemm.h>


void xnn_qs8_igemm_minmax_ukernel_${MR}x${NR}__neon_mull_addw_dup(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const int8_t** restrict a,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const int8_t* zero,
    const union xnn_qs8_gemm_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (${MR} * sizeof(void*)) == 0);
  assert(a_offset % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

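  // Set up one output row pointer per row of the ${MR}x${NR} micro-tile. When fewer
  // than ${MR} rows remain (mr < ${MR}), the unused pointers alias the previous row,
  // so their stores land in valid memory and are overwritten with the correct values.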
  int8_t* c0 = c;
  $for M in range(1, MR):
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        c${M} = c${M-1};
      }

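  // Process the output in tiles of ${MR} rows x ${NR} columns. Row 0 accumulators are
  // seeded from the per-column bias values stored at the start of the packed weights w;
  // the remaining rows start as copies of row 0.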
  do {
    $for N in range(0, NR, 4):
      int32x4_t vacc0x${ABC[N:N+4]} = vld1q_s32(w); w = (const void*) ((uintptr_t) w + 4 * sizeof(int32_t));
    $for M in range(1, MR):
      $for N in range(0, NR, 4):
        int32x4_t vacc${M}x${ABC[N:N+4]} = vacc0x${ABC[N:N+4]};

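    // Walk the indirection buffer: each iteration consumes ${MR} input row pointers
    // from a. Pointers equal to the zero buffer are left untouched, so padding rows
    // read zeros instead of having a_offset applied.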
    size_t p = ks;
    do {
      $for M in range(MR):
        const int8_t* restrict a${M} = a[${M}];
        if XNN_UNPREDICTABLE(a${M} != zero) {
          a${M} = (const int8_t*) ((uintptr_t) a${M} + a_offset);
        }
      a += ${MR};

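      // Main K loop ("mull_addw_dup"): for each group of 8 input elements per row,
      // broadcast one lane with vdup_lane_s8, widen-multiply against the weights with
      // vmull_s8, and widen-accumulate into the 32-bit lanes with vaddw_s16.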
      size_t k = kc;
      while (k >= 8 * sizeof(int8_t)) {
        $for M in range(MR):
          const int8x8_t va${M} = vld1_s8(a${M}); a${M} += 8;

        $for K in range(8):
          $for N in range(0, NR, 8):
            const int8x8_t vb${ABC[N:N+8]}c${K} = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

            $for M in range(MR):
              const int16x8_t vprod${M}x${ABC[N:N+8]}c${K} = vmull_s8(vb${ABC[N:N+8]}c${K}, vdup_lane_s8(va${M}, ${K}));
              vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c${K}));
              vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c${K}));

        k -= 8 * sizeof(int8_t);
      }
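      // Remainder: 1-7 leftover K elements are handled with the same multiply/accumulate
      // sequence, one K position (c0..c6) at a time, guarded by nested k comparisons.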
      if XNN_UNLIKELY(k != 0) {
        $for M in range(MR):
          const int8x8_t va${M} = vld1_s8(a${M}); a${M} = (const int8_t*) ((uintptr_t) a${M} + k);

        $for N in range(0, NR, 8):
          const int8x8_t vb${ABC[N:N+8]}c0 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

        $for M in range(MR):
          $for N in range(0, NR, 8):
            const int16x8_t vprod${M}x${ABC[N:N+8]}c0 = vmull_s8(vb${ABC[N:N+8]}c0, vdup_lane_s8(va${M}, 0));
            vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c0));
            vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c0));

        if (k >= 2 * sizeof(int8_t)) {
          $for N in range(0, NR, 8):
            const int8x8_t vb${ABC[N:N+8]}c1 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

          $for M in range(MR):
            $for N in range(0, NR, 8):
              const int16x8_t vprod${M}x${ABC[N:N+8]}c1 = vmull_s8(vb${ABC[N:N+8]}c1, vdup_lane_s8(va${M}, 1));
              vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c1));
              vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c1));

          if (k > 2 * sizeof(int8_t)) {
            $for N in range(0, NR, 8):
              const int8x8_t vb${ABC[N:N+8]}c2 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

            $for M in range(MR):
              $for N in range(0, NR, 8):
                const int16x8_t vprod${M}x${ABC[N:N+8]}c2 = vmull_s8(vb${ABC[N:N+8]}c2, vdup_lane_s8(va${M}, 2));
                vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c2));
                vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c2));

            if (k >= 4 * sizeof(int8_t)) {
              $for N in range(0, NR, 8):
                const int8x8_t vb${ABC[N:N+8]}c3 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

              $for M in range(MR):
                $for N in range(0, NR, 8):
                  const int16x8_t vprod${M}x${ABC[N:N+8]}c3 = vmull_s8(vb${ABC[N:N+8]}c3, vdup_lane_s8(va${M}, 3));
                  vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c3));
                  vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c3));

              if (k > 4 * sizeof(int8_t)) {
                $for N in range(0, NR, 8):
                  const int8x8_t vb${ABC[N:N+8]}c4 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

                $for M in range(MR):
                  $for N in range(0, NR, 8):
                    const int16x8_t vprod${M}x${ABC[N:N+8]}c4 = vmull_s8(vb${ABC[N:N+8]}c4, vdup_lane_s8(va${M}, 4));
                    vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c4));
                    vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c4));

                if (k >= 6 * sizeof(int8_t)) {
                  $for N in range(0, NR, 8):
                    const int8x8_t vb${ABC[N:N+8]}c5 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

                  $for M in range(MR):
                    $for N in range(0, NR, 8):
                      const int16x8_t vprod${M}x${ABC[N:N+8]}c5 = vmull_s8(vb${ABC[N:N+8]}c5, vdup_lane_s8(va${M}, 5));
                      vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c5));
                      vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c5));

                  if (k > 6 * sizeof(int8_t)) {
                    $for N in range(0, NR, 8):
                      const int8x8_t vb${ABC[N:N+8]}c6 = vld1_s8(w); w = (const void*) ((uintptr_t) w + 8 * sizeof(int8_t));

                    $for M in range(MR):
                      $for N in range(0, NR, 8):
                        const int16x8_t vprod${M}x${ABC[N:N+8]}c6 = vmull_s8(vb${ABC[N:N+8]}c6, vdup_lane_s8(va${M}, 6));
                        vacc${M}x${ABC[N:N+4]} = vaddw_s16(vacc${M}x${ABC[N:N+4]}, vget_low_s16(vprod${M}x${ABC[N:N+8]}c6));
                        vacc${M}x${ABC[N+4:N+8]} = vaddw_s16(vacc${M}x${ABC[N+4:N+8]}, vget_high_s16(vprod${M}x${ABC[N:N+8]}c6));
                  }
                }
              }
            }
          }
        }
      }
      p -= ${MR} * sizeof(void*);
    } while (p != 0);

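    // Requantization: scale the 32-bit accumulators by the Q31 fixed-point multiplier
    // (vqrdmulhq_s32), then apply a rounding right shift. The vbicq/vsraq pair subtracts
    // 1 from negative accumulators (only when the shift is nonzero) to adjust the
    // rounding of the subsequent vrshlq_s32.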
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->neon.multiplier);
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vqrdmulhq_s32(vacc${M}x${ABC[N:N+4]}, vmultiplier);

    const int32x4_t vright_shift = vld1q_dup_s32(&params->neon.right_shift);
    const int32x4_t vzero_shift_mask = vreinterpretq_s32_u32(vceqq_s32(vright_shift, vmovq_n_s32(0)));
    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vsraq_n_s32(vacc${M}x${ABC[N:N+4]}, vbicq_s32(vacc${M}x${ABC[N:N+4]}, vzero_shift_mask), 31);

    $for M in range(MR):
      $for N in range(0, NR, 4):
        vacc${M}x${ABC[N:N+4]} = vrshlq_s32(vacc${M}x${ABC[N:N+4]}, vright_shift);

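    // Convert to 8-bit: saturating-narrow to 16 bits, add the output zero point, then
    // saturating-narrow and pack to int8 (using the AArch64 *_high forms where available).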
    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->neon.output_zero_point);
#if XNN_ARCH_ARM64
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vacc${M}x${ABC[N+4:N+8]}), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vqmovn_high_s16(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vacc${M}x${ABC[N+8:N+16]});
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vqmovn_high_s16(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vacc${M}x${ABC[N:N+8]});
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#else
    $for M in range(MR):
      $for N in range(0, NR, 8):
        const int16x8_t vacc${M}x${ABC[N:N+8]} = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc${M}x${ABC[N:N+4]}), vqmovn_s32(vacc${M}x${ABC[N+4:N+8]})), voutput_zero_point);

    $for M in range(MR):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          int8x16_t vout${M}x${ABC[N:N+16]} = vcombine_s8(vqmovn_s16(vacc${M}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N+8:N+16]}));
        $elif M % 2 == 1:
          int8x16_t vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vcombine_s8(vqmovn_s16(vacc${M-1}x${ABC[N:N+8]}), vqmovn_s16(vacc${M}x${ABC[N:N+8]}));
        $elif M + 1 == MR:
          int8x8_t vout${M}x${ABC[N:N+8]} = vqmovn_s16(vacc${M}x${ABC[N:N+8]});
#endif
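    // Clamp the 8-bit results to the requested [output_min, output_max] range.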
    $if NR == 8 and MR == 1:
      const int8x8_t voutput_min = vld1_dup_s8(&params->neon.output_min);
      const int8x8_t voutput_max = vld1_dup_s8(&params->neon.output_max);
    $else:
      const int8x16_t voutput_min = vld1q_dup_s8(&params->neon.output_min);
      const int8x16_t voutput_max = vld1q_dup_s8(&params->neon.output_max);

    $for M in reversed(range(MR)):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vmaxq_s8(vout${M}x${ABC[N:N+16]}, voutput_min);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vmaxq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_min);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, voutput_min);
          $else:
            vout${M}x${ABC[N:N+8]} = vmax_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_min));

    $for M in reversed(range(MR)):
      $for N in range(0, NR, 16):
        $if N + 8 < NR:
          vout${M}x${ABC[N:N+16]} = vminq_s8(vout${M}x${ABC[N:N+16]}, voutput_max);
        $elif M % 2 == 1:
          vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]} = vminq_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}, voutput_max);
        $elif M + 1 == MR:
          $if NR == 8 and MR == 1:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, voutput_max);
          $else:
            vout${M}x${ABC[N:N+8]} = vmin_s8(vout${M}x${ABC[N:N+8]}, vget_low_s8(voutput_max));

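    // Store the tile: the fast path writes all ${NR} columns, advances the output
    // pointers by cn_stride, and rewinds the indirection pointer a by ks; the tail
    // path below writes the remaining columns in pieces of 8, 4, 2, and 1.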
    if (nc >= ${NR}) {
      $for M in reversed(range(MR)):
        $for N in range(0, NR, 16):
          $if N + 8 < NR:
            vst1q_s8(c${M} + ${N}, vout${M}x${ABC[N:N+16]});
          $elif M % 2 == 1:
            vst1_s8(c${M} + ${N}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
            vst1_s8(c${M-1} + ${N}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]}));
          $elif M + 1 == MR:
            vst1_s8(c${M} + ${N}, vout${M}x${ABC[N:N+8]});

      $for M in reversed(range(MR)):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      nc -= ${NR};
    } else {
      $if NR == 16:
        $for M in range(MR):
          $if M % 2 == 1:
            int8x16_t vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_low_s8(vout${M-1}x0123456789ABCDEF), vget_low_s8(vout${M}x0123456789ABCDEF));
          $elif M + 1 == MR:
            int8x8_t vout${M}x01234567 = vget_low_s8(vout${M}x0123456789ABCDEF);
        if (nc & 8) {
          $for M in reversed(range(MR)):
            $if M % 2 == 1:
              vst1_s8(c${M}, vget_high_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); c${M} += 8;
              vst1_s8(c${M-1}, vget_low_s8(vout${M-1}x${ABC[N:N+8]}_${M}x${ABC[N:N+8]})); c${M-1} += 8;
            $elif M + 1 == MR:
              vst1_s8(c${M}, vout${M}x${ABC[N:N+8]}); c${M} += 8;
          $for M in reversed(range(MR)):
            $if M % 2 == 1:
              vout${M-1}x01234567_${M}x01234567 = vcombine_s8(vget_high_s8(vout${M-1}x0123456789ABCDEF), vget_high_s8(vout${M}x0123456789ABCDEF));
            $elif M + 1 == MR:
              vout${M}x01234567 = vget_high_s8(vout${M}x0123456789ABCDEF);
        }
      if (nc & 4) {
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vst1q_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 2); c${M} += 4;
            vst1q_lane_u32(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u32_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 4;
          $elif M + 1 == MR:
            vst1_lane_u32(__builtin_assume_aligned(c${M}, 1), vreinterpret_u32_s8(vout${M}x01234567), 0); c${M} += 4;
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 4);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 4);
      }
      if (nc & 2) {
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vst1q_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 4); c${M} += 2;
            vst1q_lane_u16(__builtin_assume_aligned(c${M-1}, 1), vreinterpretq_u16_s8(vout${M-1}x01234567_${M}x01234567), 0); c${M-1} += 2;
          $elif M + 1 == MR:
            vst1_lane_u16(__builtin_assume_aligned(c${M}, 1), vreinterpret_u16_s8(vout${M}x01234567), 0); c${M} += 2;
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vout${M-1}x01234567_${M}x01234567 = vextq_s8(vout${M-1}x01234567_${M}x01234567, vout${M-1}x01234567_${M}x01234567, 2);
          $elif M + 1 == MR:
            vout${M}x01234567 = vext_s8(vout${M}x01234567, vout${M}x01234567, 2);
      }
      if (nc & 1) {
        $for M in reversed(range(MR)):
          $if M % 2 == 1:
            vst1q_lane_s8(c${M}, vout${M-1}x01234567_${M}x01234567, 8);
            vst1q_lane_s8(c${M-1}, vout${M-1}x01234567_${M}x01234567, 0);
          $elif M + 1 == MR:
            vst1_lane_s8(c${M}, vout${M}x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}