// Auto-generated file. Do not edit!
// Template: src/qu8-gemm/c4-neondot.c.in
// Generator: tools/xngen
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <arm_neon.h>

#include <xnnpack/gemm.h>
#include <xnnpack/math.h>

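// 5x16c4 QU8 GEMM microkernel with rndnu requantization: each step of the K
// loop feeds 4-byte groups of activations to the ARMv8.2 dot-product (UDOT)
// instruction, producing a 5-row by 16-column output tile.
// Editorial note: the activation zero point is assumed to be pre-folded into
// the packed bias by the weight-packing routine; only the kernel (weight)
// zero point is corrected explicitly below.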
void xnn_qu8_gemm_minmax_rndnu_ukernel_5x16c4__neondot(
    size_t mr,
    size_t nc,
    size_t kc,
    const uint8_t* restrict a,
    size_t a_stride,
    const void* restrict w,
    uint8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    const union xnn_qu8_conv_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
{
  assert(mr != 0);
  assert(mr <= 5);
  assert(nc != 0);
  assert(kc != 0);
  assert(kc % sizeof(uint8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

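  // kc is rounded up so activations and packed weights can be consumed in
  // the 4-byte groups that each dot product instruction expects; the packing
  // routine is assumed to have zero-padded W along K to the same multiple.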
  kc = round_up_po2(kc, 4 * sizeof(uint8_t));
  const uint8_t* a0 = a;
  uint8_t* c0 = c;
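  // Per-row pointer setup. When mr < 5, pointers for the excess rows are
  // clamped to the previous row, so those rows redundantly recompute and
  // rewrite the same data instead of branching inside the hot loops.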
  const uint8_t* a1 = (const uint8_t*) ((uintptr_t) a0 + a_stride);
  uint8_t* c1 = (uint8_t*) ((uintptr_t) c0 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 2) {
    a1 = a0;
    c1 = c0;
  }
  const uint8_t* a2 = (const uint8_t*) ((uintptr_t) a1 + a_stride);
  uint8_t* c2 = (uint8_t*) ((uintptr_t) c1 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 2) {
    a2 = a1;
    c2 = c1;
  }
  const uint8_t* a3 = (const uint8_t*) ((uintptr_t) a2 + a_stride);
  uint8_t* c3 = (uint8_t*) ((uintptr_t) c2 + cm_stride);
  if XNN_UNPREDICTABLE(mr < 4) {
    a3 = a2;
    c3 = c2;
  }
  const uint8_t* a4 = (const uint8_t*) ((uintptr_t) a3 + a_stride);
  uint8_t* c4 = (uint8_t*) ((uintptr_t) c3 + cm_stride);
  if XNN_UNPREDICTABLE(mr <= 4) {
    a4 = a3;
    c4 = c3;
  }

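  // UDOT has no built-in zero-point handling, so two accumulator sets are
  // kept: vpacc sums a*b while vnacc sums a*kernel_zero_point, and their
  // difference sum(a*(b - zero_point)) is formed after the K loop.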
  const uint8x16_t vb_zero_point = vld1q_dup_u8(&params->rndnu_neon.kernel_zero_point[0]);

  // Loop over groups of 16 columns.
  do {
    // Initialize accumulators with bias. 16 bias values are loaded from the
    // weight matrix, at the start of the group of 16 columns.
    uint32x4_t vpacc0x0123 = vld1q_u32(w); w = (const void*) ((const uint32_t*) w + 4);
    uint32x4_t vpacc0x4567 = vld1q_u32(w); w = (const void*) ((const uint32_t*) w + 4);
    uint32x4_t vpacc0x89AB = vld1q_u32(w); w = (const void*) ((const uint32_t*) w + 4);
    uint32x4_t vpacc0xCDEF = vld1q_u32(w); w = (const void*) ((const uint32_t*) w + 4);
    uint32x4_t vpacc1x0123 = vpacc0x0123;
    uint32x4_t vpacc1x4567 = vpacc0x4567;
    uint32x4_t vpacc1x89AB = vpacc0x89AB;
    uint32x4_t vpacc1xCDEF = vpacc0xCDEF;
    uint32x4_t vpacc2x0123 = vpacc0x0123;
    uint32x4_t vpacc2x4567 = vpacc0x4567;
    uint32x4_t vpacc2x89AB = vpacc0x89AB;
    uint32x4_t vpacc2xCDEF = vpacc0xCDEF;
    uint32x4_t vpacc3x0123 = vpacc0x0123;
    uint32x4_t vpacc3x4567 = vpacc0x4567;
    uint32x4_t vpacc3x89AB = vpacc0x89AB;
    uint32x4_t vpacc3xCDEF = vpacc0xCDEF;
    uint32x4_t vpacc4x0123 = vpacc0x0123;
    uint32x4_t vpacc4x4567 = vpacc0x4567;
    uint32x4_t vpacc4x89AB = vpacc0x89AB;
    uint32x4_t vpacc4xCDEF = vpacc0xCDEF;
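    // Zero-point correction accumulators, one per positive accumulator.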
    uint32x4_t vnacc0x0123 = vmovq_n_u32(0);
    uint32x4_t vnacc0x4567 = vmovq_n_u32(0);
    uint32x4_t vnacc0x89AB = vmovq_n_u32(0);
    uint32x4_t vnacc0xCDEF = vmovq_n_u32(0);
    uint32x4_t vnacc1x0123 = vmovq_n_u32(0);
    uint32x4_t vnacc1x4567 = vmovq_n_u32(0);
    uint32x4_t vnacc1x89AB = vmovq_n_u32(0);
    uint32x4_t vnacc1xCDEF = vmovq_n_u32(0);
    uint32x4_t vnacc2x0123 = vmovq_n_u32(0);
    uint32x4_t vnacc2x4567 = vmovq_n_u32(0);
    uint32x4_t vnacc2x89AB = vmovq_n_u32(0);
    uint32x4_t vnacc2xCDEF = vmovq_n_u32(0);
    uint32x4_t vnacc3x0123 = vmovq_n_u32(0);
    uint32x4_t vnacc3x4567 = vmovq_n_u32(0);
    uint32x4_t vnacc3x89AB = vmovq_n_u32(0);
    uint32x4_t vnacc3xCDEF = vmovq_n_u32(0);
    uint32x4_t vnacc4x0123 = vmovq_n_u32(0);
    uint32x4_t vnacc4x4567 = vmovq_n_u32(0);
    uint32x4_t vnacc4x89AB = vmovq_n_u32(0);
    uint32x4_t vnacc4xCDEF = vmovq_n_u32(0);

    // Inner accumulation loop along the K dimension, accumulating into all 16 columns.
    size_t k = kc;
    // 2x partial unrolled loop to load 8 bytes at a time.
    while (k >= 8 * sizeof(uint8_t)) {
      // Load a 5x8 block of activations.
      const uint8x8_t va0x01234567 = vld1_u8(a0); a0 += 8;
      const uint8x8_t va1x01234567 = vld1_u8(a1); a1 += 8;
      const uint8x8_t va2x01234567 = vld1_u8(a2); a2 += 8;
      const uint8x8_t va3x01234567 = vld1_u8(a3); a3 += 8;
      const uint8x8_t va4x01234567 = vld1_u8(a4); a4 += 8;

      // Load an 8x16 block of weights.
      const uint8x16_t vb0123x0123 = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb0123x4567 = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb0123x89AB = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb0123xCDEF = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb4567x0123 = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb4567x4567 = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb4567x89AB = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb4567xCDEF = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);

      // Multiply-accumulate: 5x8 * 8x16 --> 5x16.
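      // Each vdotq_lane_u32 takes one 4-byte group of an activation row
      // (selected by the lane index), multiplies it against 16 packed weight
      // bytes, and accumulates four 32-bit dot products; the paired
      // vb_zero_point dot product tracks the zero-point correction.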
      vpacc0x0123 = vdotq_lane_u32(vpacc0x0123, vb0123x0123, va0x01234567, 0);
      vnacc0x0123 = vdotq_lane_u32(vnacc0x0123, vb_zero_point, va0x01234567, 0);
      vpacc0x4567 = vdotq_lane_u32(vpacc0x4567, vb0123x4567, va0x01234567, 0);
      vnacc0x4567 = vdotq_lane_u32(vnacc0x4567, vb_zero_point, va0x01234567, 0);
      vpacc0x89AB = vdotq_lane_u32(vpacc0x89AB, vb0123x89AB, va0x01234567, 0);
      vnacc0x89AB = vdotq_lane_u32(vnacc0x89AB, vb_zero_point, va0x01234567, 0);
      vpacc0xCDEF = vdotq_lane_u32(vpacc0xCDEF, vb0123xCDEF, va0x01234567, 0);
      vnacc0xCDEF = vdotq_lane_u32(vnacc0xCDEF, vb_zero_point, va0x01234567, 0);
      vpacc1x0123 = vdotq_lane_u32(vpacc1x0123, vb0123x0123, va1x01234567, 0);
      vnacc1x0123 = vdotq_lane_u32(vnacc1x0123, vb_zero_point, va1x01234567, 0);
      vpacc1x4567 = vdotq_lane_u32(vpacc1x4567, vb0123x4567, va1x01234567, 0);
      vnacc1x4567 = vdotq_lane_u32(vnacc1x4567, vb_zero_point, va1x01234567, 0);
      vpacc1x89AB = vdotq_lane_u32(vpacc1x89AB, vb0123x89AB, va1x01234567, 0);
      vnacc1x89AB = vdotq_lane_u32(vnacc1x89AB, vb_zero_point, va1x01234567, 0);
      vpacc1xCDEF = vdotq_lane_u32(vpacc1xCDEF, vb0123xCDEF, va1x01234567, 0);
      vnacc1xCDEF = vdotq_lane_u32(vnacc1xCDEF, vb_zero_point, va1x01234567, 0);
      vpacc2x0123 = vdotq_lane_u32(vpacc2x0123, vb0123x0123, va2x01234567, 0);
      vnacc2x0123 = vdotq_lane_u32(vnacc2x0123, vb_zero_point, va2x01234567, 0);
      vpacc2x4567 = vdotq_lane_u32(vpacc2x4567, vb0123x4567, va2x01234567, 0);
      vnacc2x4567 = vdotq_lane_u32(vnacc2x4567, vb_zero_point, va2x01234567, 0);
      vpacc2x89AB = vdotq_lane_u32(vpacc2x89AB, vb0123x89AB, va2x01234567, 0);
      vnacc2x89AB = vdotq_lane_u32(vnacc2x89AB, vb_zero_point, va2x01234567, 0);
      vpacc2xCDEF = vdotq_lane_u32(vpacc2xCDEF, vb0123xCDEF, va2x01234567, 0);
      vnacc2xCDEF = vdotq_lane_u32(vnacc2xCDEF, vb_zero_point, va2x01234567, 0);
      vpacc3x0123 = vdotq_lane_u32(vpacc3x0123, vb0123x0123, va3x01234567, 0);
      vnacc3x0123 = vdotq_lane_u32(vnacc3x0123, vb_zero_point, va3x01234567, 0);
      vpacc3x4567 = vdotq_lane_u32(vpacc3x4567, vb0123x4567, va3x01234567, 0);
      vnacc3x4567 = vdotq_lane_u32(vnacc3x4567, vb_zero_point, va3x01234567, 0);
      vpacc3x89AB = vdotq_lane_u32(vpacc3x89AB, vb0123x89AB, va3x01234567, 0);
      vnacc3x89AB = vdotq_lane_u32(vnacc3x89AB, vb_zero_point, va3x01234567, 0);
      vpacc3xCDEF = vdotq_lane_u32(vpacc3xCDEF, vb0123xCDEF, va3x01234567, 0);
      vnacc3xCDEF = vdotq_lane_u32(vnacc3xCDEF, vb_zero_point, va3x01234567, 0);
      vpacc4x0123 = vdotq_lane_u32(vpacc4x0123, vb0123x0123, va4x01234567, 0);
      vnacc4x0123 = vdotq_lane_u32(vnacc4x0123, vb_zero_point, va4x01234567, 0);
      vpacc4x4567 = vdotq_lane_u32(vpacc4x4567, vb0123x4567, va4x01234567, 0);
      vnacc4x4567 = vdotq_lane_u32(vnacc4x4567, vb_zero_point, va4x01234567, 0);
      vpacc4x89AB = vdotq_lane_u32(vpacc4x89AB, vb0123x89AB, va4x01234567, 0);
      vnacc4x89AB = vdotq_lane_u32(vnacc4x89AB, vb_zero_point, va4x01234567, 0);
      vpacc4xCDEF = vdotq_lane_u32(vpacc4xCDEF, vb0123xCDEF, va4x01234567, 0);
      vnacc4xCDEF = vdotq_lane_u32(vnacc4xCDEF, vb_zero_point, va4x01234567, 0);
      vpacc0x0123 = vdotq_lane_u32(vpacc0x0123, vb4567x0123, va0x01234567, 1);
      vnacc0x0123 = vdotq_lane_u32(vnacc0x0123, vb_zero_point, va0x01234567, 1);
      vpacc0x4567 = vdotq_lane_u32(vpacc0x4567, vb4567x4567, va0x01234567, 1);
      vnacc0x4567 = vdotq_lane_u32(vnacc0x4567, vb_zero_point, va0x01234567, 1);
      vpacc0x89AB = vdotq_lane_u32(vpacc0x89AB, vb4567x89AB, va0x01234567, 1);
      vnacc0x89AB = vdotq_lane_u32(vnacc0x89AB, vb_zero_point, va0x01234567, 1);
      vpacc0xCDEF = vdotq_lane_u32(vpacc0xCDEF, vb4567xCDEF, va0x01234567, 1);
      vnacc0xCDEF = vdotq_lane_u32(vnacc0xCDEF, vb_zero_point, va0x01234567, 1);
      vpacc1x0123 = vdotq_lane_u32(vpacc1x0123, vb4567x0123, va1x01234567, 1);
      vnacc1x0123 = vdotq_lane_u32(vnacc1x0123, vb_zero_point, va1x01234567, 1);
      vpacc1x4567 = vdotq_lane_u32(vpacc1x4567, vb4567x4567, va1x01234567, 1);
      vnacc1x4567 = vdotq_lane_u32(vnacc1x4567, vb_zero_point, va1x01234567, 1);
      vpacc1x89AB = vdotq_lane_u32(vpacc1x89AB, vb4567x89AB, va1x01234567, 1);
      vnacc1x89AB = vdotq_lane_u32(vnacc1x89AB, vb_zero_point, va1x01234567, 1);
      vpacc1xCDEF = vdotq_lane_u32(vpacc1xCDEF, vb4567xCDEF, va1x01234567, 1);
      vnacc1xCDEF = vdotq_lane_u32(vnacc1xCDEF, vb_zero_point, va1x01234567, 1);
      vpacc2x0123 = vdotq_lane_u32(vpacc2x0123, vb4567x0123, va2x01234567, 1);
      vnacc2x0123 = vdotq_lane_u32(vnacc2x0123, vb_zero_point, va2x01234567, 1);
      vpacc2x4567 = vdotq_lane_u32(vpacc2x4567, vb4567x4567, va2x01234567, 1);
      vnacc2x4567 = vdotq_lane_u32(vnacc2x4567, vb_zero_point, va2x01234567, 1);
      vpacc2x89AB = vdotq_lane_u32(vpacc2x89AB, vb4567x89AB, va2x01234567, 1);
      vnacc2x89AB = vdotq_lane_u32(vnacc2x89AB, vb_zero_point, va2x01234567, 1);
      vpacc2xCDEF = vdotq_lane_u32(vpacc2xCDEF, vb4567xCDEF, va2x01234567, 1);
      vnacc2xCDEF = vdotq_lane_u32(vnacc2xCDEF, vb_zero_point, va2x01234567, 1);
      vpacc3x0123 = vdotq_lane_u32(vpacc3x0123, vb4567x0123, va3x01234567, 1);
      vnacc3x0123 = vdotq_lane_u32(vnacc3x0123, vb_zero_point, va3x01234567, 1);
      vpacc3x4567 = vdotq_lane_u32(vpacc3x4567, vb4567x4567, va3x01234567, 1);
      vnacc3x4567 = vdotq_lane_u32(vnacc3x4567, vb_zero_point, va3x01234567, 1);
      vpacc3x89AB = vdotq_lane_u32(vpacc3x89AB, vb4567x89AB, va3x01234567, 1);
      vnacc3x89AB = vdotq_lane_u32(vnacc3x89AB, vb_zero_point, va3x01234567, 1);
      vpacc3xCDEF = vdotq_lane_u32(vpacc3xCDEF, vb4567xCDEF, va3x01234567, 1);
      vnacc3xCDEF = vdotq_lane_u32(vnacc3xCDEF, vb_zero_point, va3x01234567, 1);
      vpacc4x0123 = vdotq_lane_u32(vpacc4x0123, vb4567x0123, va4x01234567, 1);
      vnacc4x0123 = vdotq_lane_u32(vnacc4x0123, vb_zero_point, va4x01234567, 1);
      vpacc4x4567 = vdotq_lane_u32(vpacc4x4567, vb4567x4567, va4x01234567, 1);
      vnacc4x4567 = vdotq_lane_u32(vnacc4x4567, vb_zero_point, va4x01234567, 1);
      vpacc4x89AB = vdotq_lane_u32(vpacc4x89AB, vb4567x89AB, va4x01234567, 1);
      vnacc4x89AB = vdotq_lane_u32(vnacc4x89AB, vb_zero_point, va4x01234567, 1);
      vpacc4xCDEF = vdotq_lane_u32(vpacc4xCDEF, vb4567xCDEF, va4x01234567, 1);
      vnacc4xCDEF = vdotq_lane_u32(vnacc4xCDEF, vb_zero_point, va4x01234567, 1);

      k -= 8 * sizeof(uint8_t);
    }
    // Handle up to 4 final positions of `k`.
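    // Since kc was rounded up to a multiple of 4, the remainder is exactly 4
    // bytes per row. The 8-byte vld1_u8 loads below over-read by 4 bytes,
    // which is presumed safe under XNNPACK's buffer padding guarantees; only
    // lane 0 (the first 4 bytes) is consumed.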
    if XNN_UNLIKELY(k != 0) {
      // Load a 5x4 block of activations.
      const uint8x8_t va0x01234567 = vld1_u8(a0); a0 += 4;
      const uint8x8_t va1x01234567 = vld1_u8(a1); a1 += 4;
      const uint8x8_t va2x01234567 = vld1_u8(a2); a2 += 4;
      const uint8x8_t va3x01234567 = vld1_u8(a3); a3 += 4;
      const uint8x8_t va4x01234567 = vld1_u8(a4); a4 += 4;

      // Load a 4x16 block of weights.
      const uint8x16_t vb0123x0123 = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb0123x4567 = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb0123x89AB = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);
      const uint8x16_t vb0123xCDEF = vld1q_u8(w); w = (const void*) ((const uint8_t*) w + 16);

      // Multiply-accumulate: 5x4 * 4x16 --> 5x16.
      vpacc0x0123 = vdotq_lane_u32(vpacc0x0123, vb0123x0123, va0x01234567, 0);
      vnacc0x0123 = vdotq_lane_u32(vnacc0x0123, vb_zero_point, va0x01234567, 0);
      vpacc0x4567 = vdotq_lane_u32(vpacc0x4567, vb0123x4567, va0x01234567, 0);
      vnacc0x4567 = vdotq_lane_u32(vnacc0x4567, vb_zero_point, va0x01234567, 0);
      vpacc0x89AB = vdotq_lane_u32(vpacc0x89AB, vb0123x89AB, va0x01234567, 0);
      vnacc0x89AB = vdotq_lane_u32(vnacc0x89AB, vb_zero_point, va0x01234567, 0);
      vpacc0xCDEF = vdotq_lane_u32(vpacc0xCDEF, vb0123xCDEF, va0x01234567, 0);
      vnacc0xCDEF = vdotq_lane_u32(vnacc0xCDEF, vb_zero_point, va0x01234567, 0);
      vpacc1x0123 = vdotq_lane_u32(vpacc1x0123, vb0123x0123, va1x01234567, 0);
      vnacc1x0123 = vdotq_lane_u32(vnacc1x0123, vb_zero_point, va1x01234567, 0);
      vpacc1x4567 = vdotq_lane_u32(vpacc1x4567, vb0123x4567, va1x01234567, 0);
      vnacc1x4567 = vdotq_lane_u32(vnacc1x4567, vb_zero_point, va1x01234567, 0);
      vpacc1x89AB = vdotq_lane_u32(vpacc1x89AB, vb0123x89AB, va1x01234567, 0);
      vnacc1x89AB = vdotq_lane_u32(vnacc1x89AB, vb_zero_point, va1x01234567, 0);
      vpacc1xCDEF = vdotq_lane_u32(vpacc1xCDEF, vb0123xCDEF, va1x01234567, 0);
      vnacc1xCDEF = vdotq_lane_u32(vnacc1xCDEF, vb_zero_point, va1x01234567, 0);
      vpacc2x0123 = vdotq_lane_u32(vpacc2x0123, vb0123x0123, va2x01234567, 0);
      vnacc2x0123 = vdotq_lane_u32(vnacc2x0123, vb_zero_point, va2x01234567, 0);
      vpacc2x4567 = vdotq_lane_u32(vpacc2x4567, vb0123x4567, va2x01234567, 0);
      vnacc2x4567 = vdotq_lane_u32(vnacc2x4567, vb_zero_point, va2x01234567, 0);
      vpacc2x89AB = vdotq_lane_u32(vpacc2x89AB, vb0123x89AB, va2x01234567, 0);
      vnacc2x89AB = vdotq_lane_u32(vnacc2x89AB, vb_zero_point, va2x01234567, 0);
      vpacc2xCDEF = vdotq_lane_u32(vpacc2xCDEF, vb0123xCDEF, va2x01234567, 0);
      vnacc2xCDEF = vdotq_lane_u32(vnacc2xCDEF, vb_zero_point, va2x01234567, 0);
      vpacc3x0123 = vdotq_lane_u32(vpacc3x0123, vb0123x0123, va3x01234567, 0);
      vnacc3x0123 = vdotq_lane_u32(vnacc3x0123, vb_zero_point, va3x01234567, 0);
      vpacc3x4567 = vdotq_lane_u32(vpacc3x4567, vb0123x4567, va3x01234567, 0);
      vnacc3x4567 = vdotq_lane_u32(vnacc3x4567, vb_zero_point, va3x01234567, 0);
      vpacc3x89AB = vdotq_lane_u32(vpacc3x89AB, vb0123x89AB, va3x01234567, 0);
      vnacc3x89AB = vdotq_lane_u32(vnacc3x89AB, vb_zero_point, va3x01234567, 0);
      vpacc3xCDEF = vdotq_lane_u32(vpacc3xCDEF, vb0123xCDEF, va3x01234567, 0);
      vnacc3xCDEF = vdotq_lane_u32(vnacc3xCDEF, vb_zero_point, va3x01234567, 0);
      vpacc4x0123 = vdotq_lane_u32(vpacc4x0123, vb0123x0123, va4x01234567, 0);
      vnacc4x0123 = vdotq_lane_u32(vnacc4x0123, vb_zero_point, va4x01234567, 0);
      vpacc4x4567 = vdotq_lane_u32(vpacc4x4567, vb0123x4567, va4x01234567, 0);
      vnacc4x4567 = vdotq_lane_u32(vnacc4x4567, vb_zero_point, va4x01234567, 0);
      vpacc4x89AB = vdotq_lane_u32(vpacc4x89AB, vb0123x89AB, va4x01234567, 0);
      vnacc4x89AB = vdotq_lane_u32(vnacc4x89AB, vb_zero_point, va4x01234567, 0);
      vpacc4xCDEF = vdotq_lane_u32(vpacc4xCDEF, vb0123xCDEF, va4x01234567, 0);
      vnacc4xCDEF = vdotq_lane_u32(vnacc4xCDEF, vb_zero_point, va4x01234567, 0);
    }

    // Subtract the zero-point accumulators from the positive accumulators.
    int32x4_t vacc0x0123 = vreinterpretq_s32_u32(vsubq_u32(vpacc0x0123, vnacc0x0123));
    int32x4_t vacc0x4567 = vreinterpretq_s32_u32(vsubq_u32(vpacc0x4567, vnacc0x4567));
    int32x4_t vacc0x89AB = vreinterpretq_s32_u32(vsubq_u32(vpacc0x89AB, vnacc0x89AB));
    int32x4_t vacc0xCDEF = vreinterpretq_s32_u32(vsubq_u32(vpacc0xCDEF, vnacc0xCDEF));
    int32x4_t vacc1x0123 = vreinterpretq_s32_u32(vsubq_u32(vpacc1x0123, vnacc1x0123));
    int32x4_t vacc1x4567 = vreinterpretq_s32_u32(vsubq_u32(vpacc1x4567, vnacc1x4567));
    int32x4_t vacc1x89AB = vreinterpretq_s32_u32(vsubq_u32(vpacc1x89AB, vnacc1x89AB));
    int32x4_t vacc1xCDEF = vreinterpretq_s32_u32(vsubq_u32(vpacc1xCDEF, vnacc1xCDEF));
    int32x4_t vacc2x0123 = vreinterpretq_s32_u32(vsubq_u32(vpacc2x0123, vnacc2x0123));
    int32x4_t vacc2x4567 = vreinterpretq_s32_u32(vsubq_u32(vpacc2x4567, vnacc2x4567));
    int32x4_t vacc2x89AB = vreinterpretq_s32_u32(vsubq_u32(vpacc2x89AB, vnacc2x89AB));
    int32x4_t vacc2xCDEF = vreinterpretq_s32_u32(vsubq_u32(vpacc2xCDEF, vnacc2xCDEF));
    int32x4_t vacc3x0123 = vreinterpretq_s32_u32(vsubq_u32(vpacc3x0123, vnacc3x0123));
    int32x4_t vacc3x4567 = vreinterpretq_s32_u32(vsubq_u32(vpacc3x4567, vnacc3x4567));
    int32x4_t vacc3x89AB = vreinterpretq_s32_u32(vsubq_u32(vpacc3x89AB, vnacc3x89AB));
    int32x4_t vacc3xCDEF = vreinterpretq_s32_u32(vsubq_u32(vpacc3xCDEF, vnacc3xCDEF));
    int32x4_t vacc4x0123 = vreinterpretq_s32_u32(vsubq_u32(vpacc4x0123, vnacc4x0123));
    int32x4_t vacc4x4567 = vreinterpretq_s32_u32(vsubq_u32(vpacc4x4567, vnacc4x4567));
    int32x4_t vacc4x89AB = vreinterpretq_s32_u32(vsubq_u32(vpacc4x89AB, vnacc4x89AB));
    int32x4_t vacc4xCDEF = vreinterpretq_s32_u32(vsubq_u32(vpacc4xCDEF, vnacc4xCDEF));

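    // rndnu requantization: a non-rounding arithmetic right pre-shift, a
    // saturating doubling high multiply by a fixed-point multiplier, and a
    // rounding right post-shift. The shift parameters hold negative values,
    // so vshlq/vrshlq with them shift right.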
    const int32x4_t vright_pre_shift = vld1q_dup_s32(&params->rndnu_neon.right_pre_shift);
    const int32x4_t vmultiplier = vld1q_dup_s32(&params->rndnu_neon.multiplier);
    const int32x4_t vright_post_shift = vld1q_dup_s32(&params->rndnu_neon.right_post_shift);

    vacc0x0123 = vshlq_s32(vacc0x0123, vright_pre_shift);
    vacc0x4567 = vshlq_s32(vacc0x4567, vright_pre_shift);
    vacc0x89AB = vshlq_s32(vacc0x89AB, vright_pre_shift);
    vacc0xCDEF = vshlq_s32(vacc0xCDEF, vright_pre_shift);
    vacc1x0123 = vshlq_s32(vacc1x0123, vright_pre_shift);
    vacc1x4567 = vshlq_s32(vacc1x4567, vright_pre_shift);
    vacc1x89AB = vshlq_s32(vacc1x89AB, vright_pre_shift);
    vacc1xCDEF = vshlq_s32(vacc1xCDEF, vright_pre_shift);
    vacc2x0123 = vshlq_s32(vacc2x0123, vright_pre_shift);
    vacc2x4567 = vshlq_s32(vacc2x4567, vright_pre_shift);
    vacc2x89AB = vshlq_s32(vacc2x89AB, vright_pre_shift);
    vacc2xCDEF = vshlq_s32(vacc2xCDEF, vright_pre_shift);
    vacc3x0123 = vshlq_s32(vacc3x0123, vright_pre_shift);
    vacc3x4567 = vshlq_s32(vacc3x4567, vright_pre_shift);
    vacc3x89AB = vshlq_s32(vacc3x89AB, vright_pre_shift);
    vacc3xCDEF = vshlq_s32(vacc3xCDEF, vright_pre_shift);
    vacc4x0123 = vshlq_s32(vacc4x0123, vright_pre_shift);
    vacc4x4567 = vshlq_s32(vacc4x4567, vright_pre_shift);
    vacc4x89AB = vshlq_s32(vacc4x89AB, vright_pre_shift);
    vacc4xCDEF = vshlq_s32(vacc4xCDEF, vright_pre_shift);

    vacc0x0123 = vqdmulhq_s32(vacc0x0123, vmultiplier);
    vacc0x4567 = vqdmulhq_s32(vacc0x4567, vmultiplier);
    vacc0x89AB = vqdmulhq_s32(vacc0x89AB, vmultiplier);
    vacc0xCDEF = vqdmulhq_s32(vacc0xCDEF, vmultiplier);
    vacc1x0123 = vqdmulhq_s32(vacc1x0123, vmultiplier);
    vacc1x4567 = vqdmulhq_s32(vacc1x4567, vmultiplier);
    vacc1x89AB = vqdmulhq_s32(vacc1x89AB, vmultiplier);
    vacc1xCDEF = vqdmulhq_s32(vacc1xCDEF, vmultiplier);
    vacc2x0123 = vqdmulhq_s32(vacc2x0123, vmultiplier);
    vacc2x4567 = vqdmulhq_s32(vacc2x4567, vmultiplier);
    vacc2x89AB = vqdmulhq_s32(vacc2x89AB, vmultiplier);
    vacc2xCDEF = vqdmulhq_s32(vacc2xCDEF, vmultiplier);
    vacc3x0123 = vqdmulhq_s32(vacc3x0123, vmultiplier);
    vacc3x4567 = vqdmulhq_s32(vacc3x4567, vmultiplier);
    vacc3x89AB = vqdmulhq_s32(vacc3x89AB, vmultiplier);
    vacc3xCDEF = vqdmulhq_s32(vacc3xCDEF, vmultiplier);
    vacc4x0123 = vqdmulhq_s32(vacc4x0123, vmultiplier);
    vacc4x4567 = vqdmulhq_s32(vacc4x4567, vmultiplier);
    vacc4x89AB = vqdmulhq_s32(vacc4x89AB, vmultiplier);
    vacc4xCDEF = vqdmulhq_s32(vacc4xCDEF, vmultiplier);

    vacc0x0123 = vrshlq_s32(vacc0x0123, vright_post_shift);
    vacc0x4567 = vrshlq_s32(vacc0x4567, vright_post_shift);
    vacc0x89AB = vrshlq_s32(vacc0x89AB, vright_post_shift);
    vacc0xCDEF = vrshlq_s32(vacc0xCDEF, vright_post_shift);
    vacc1x0123 = vrshlq_s32(vacc1x0123, vright_post_shift);
    vacc1x4567 = vrshlq_s32(vacc1x4567, vright_post_shift);
    vacc1x89AB = vrshlq_s32(vacc1x89AB, vright_post_shift);
    vacc1xCDEF = vrshlq_s32(vacc1xCDEF, vright_post_shift);
    vacc2x0123 = vrshlq_s32(vacc2x0123, vright_post_shift);
    vacc2x4567 = vrshlq_s32(vacc2x4567, vright_post_shift);
    vacc2x89AB = vrshlq_s32(vacc2x89AB, vright_post_shift);
    vacc2xCDEF = vrshlq_s32(vacc2xCDEF, vright_post_shift);
    vacc3x0123 = vrshlq_s32(vacc3x0123, vright_post_shift);
    vacc3x4567 = vrshlq_s32(vacc3x4567, vright_post_shift);
    vacc3x89AB = vrshlq_s32(vacc3x89AB, vright_post_shift);
    vacc3xCDEF = vrshlq_s32(vacc3xCDEF, vright_post_shift);
    vacc4x0123 = vrshlq_s32(vacc4x0123, vright_post_shift);
    vacc4x4567 = vrshlq_s32(vacc4x4567, vright_post_shift);
    vacc4x89AB = vrshlq_s32(vacc4x89AB, vright_post_shift);
    vacc4xCDEF = vrshlq_s32(vacc4xCDEF, vright_post_shift);

    const int16x8_t voutput_zero_point = vld1q_dup_s16(&params->rndnu_neon.output_zero_point);
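    // Narrow the 32-bit accumulators to 16 bits, add the output zero point,
    // then narrow to unsigned 8 bits. AArch64 uses vqmovn_high_* to pack
    // directly into the upper half of a q register; ARMv7 combines two
    // d-register narrows instead.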
#if XNN_ARCH_ARM64
    const int16x8_t vacc0x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x0123), vacc0x4567), voutput_zero_point);
    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc0x89AB), vacc0xCDEF), voutput_zero_point);
    const int16x8_t vacc1x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x0123), vacc1x4567), voutput_zero_point);
    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc1x89AB), vacc1xCDEF), voutput_zero_point);
    const int16x8_t vacc2x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x0123), vacc2x4567), voutput_zero_point);
    const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc2x89AB), vacc2xCDEF), voutput_zero_point);
    const int16x8_t vacc3x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x0123), vacc3x4567), voutput_zero_point);
    const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc3x89AB), vacc3xCDEF), voutput_zero_point);
    const int16x8_t vacc4x01234567 = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc4x0123), vacc4x4567), voutput_zero_point);
    const int16x8_t vacc4x89ABCDEF = vqaddq_s16(vqmovn_high_s32(vqmovn_s32(vacc4x89AB), vacc4xCDEF), voutput_zero_point);

    uint8x16_t vout0x0123456789ABCDEF = vqmovun_high_s16(vqmovun_s16(vacc0x01234567), vacc0x89ABCDEF);
    uint8x16_t vout1x0123456789ABCDEF = vqmovun_high_s16(vqmovun_s16(vacc1x01234567), vacc1x89ABCDEF);
    uint8x16_t vout2x0123456789ABCDEF = vqmovun_high_s16(vqmovun_s16(vacc2x01234567), vacc2x89ABCDEF);
    uint8x16_t vout3x0123456789ABCDEF = vqmovun_high_s16(vqmovun_s16(vacc3x01234567), vacc3x89ABCDEF);
    uint8x16_t vout4x0123456789ABCDEF = vqmovun_high_s16(vqmovun_s16(vacc4x01234567), vacc4x89ABCDEF);
#else
    const int16x8_t vacc0x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x0123), vqmovn_s32(vacc0x4567)), voutput_zero_point);
    const int16x8_t vacc0x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc0x89AB), vqmovn_s32(vacc0xCDEF)), voutput_zero_point);
    const int16x8_t vacc1x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x0123), vqmovn_s32(vacc1x4567)), voutput_zero_point);
    const int16x8_t vacc1x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc1x89AB), vqmovn_s32(vacc1xCDEF)), voutput_zero_point);
    const int16x8_t vacc2x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x0123), vqmovn_s32(vacc2x4567)), voutput_zero_point);
    const int16x8_t vacc2x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc2x89AB), vqmovn_s32(vacc2xCDEF)), voutput_zero_point);
    const int16x8_t vacc3x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x0123), vqmovn_s32(vacc3x4567)), voutput_zero_point);
    const int16x8_t vacc3x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc3x89AB), vqmovn_s32(vacc3xCDEF)), voutput_zero_point);
    const int16x8_t vacc4x01234567 = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc4x0123), vqmovn_s32(vacc4x4567)), voutput_zero_point);
    const int16x8_t vacc4x89ABCDEF = vqaddq_s16(vcombine_s16(vqmovn_s32(vacc4x89AB), vqmovn_s32(vacc4xCDEF)), voutput_zero_point);

    uint8x16_t vout0x0123456789ABCDEF = vcombine_u8(vqmovun_s16(vacc0x01234567), vqmovun_s16(vacc0x89ABCDEF));
    uint8x16_t vout1x0123456789ABCDEF = vcombine_u8(vqmovun_s16(vacc1x01234567), vqmovun_s16(vacc1x89ABCDEF));
    uint8x16_t vout2x0123456789ABCDEF = vcombine_u8(vqmovun_s16(vacc2x01234567), vqmovun_s16(vacc2x89ABCDEF));
    uint8x16_t vout3x0123456789ABCDEF = vcombine_u8(vqmovun_s16(vacc3x01234567), vqmovun_s16(vacc3x89ABCDEF));
    uint8x16_t vout4x0123456789ABCDEF = vcombine_u8(vqmovun_s16(vacc4x01234567), vqmovun_s16(vacc4x89ABCDEF));
#endif
    const uint8x16_t voutput_min = vld1q_dup_u8(&params->rndnu_neon.output_min);
    const uint8x16_t voutput_max = vld1q_dup_u8(&params->rndnu_neon.output_max);

    vout0x0123456789ABCDEF = vmaxq_u8(vout0x0123456789ABCDEF, voutput_min);
    vout1x0123456789ABCDEF = vmaxq_u8(vout1x0123456789ABCDEF, voutput_min);
    vout2x0123456789ABCDEF = vmaxq_u8(vout2x0123456789ABCDEF, voutput_min);
    vout3x0123456789ABCDEF = vmaxq_u8(vout3x0123456789ABCDEF, voutput_min);
    vout4x0123456789ABCDEF = vmaxq_u8(vout4x0123456789ABCDEF, voutput_min);

    vout0x0123456789ABCDEF = vminq_u8(vout0x0123456789ABCDEF, voutput_max);
    vout1x0123456789ABCDEF = vminq_u8(vout1x0123456789ABCDEF, voutput_max);
    vout2x0123456789ABCDEF = vminq_u8(vout2x0123456789ABCDEF, voutput_max);
    vout3x0123456789ABCDEF = vminq_u8(vout3x0123456789ABCDEF, voutput_max);
    vout4x0123456789ABCDEF = vminq_u8(vout4x0123456789ABCDEF, voutput_max);

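    // Main store path: write the full 5x16 tile, step the output pointers to
    // the next 16-column group, and rewind the activation pointers by kc.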
    if (nc >= 16) {
      vst1q_u8(c0 + 0, vout0x0123456789ABCDEF);
      vst1q_u8(c1 + 0, vout1x0123456789ABCDEF);
      vst1q_u8(c2 + 0, vout2x0123456789ABCDEF);
      vst1q_u8(c3 + 0, vout3x0123456789ABCDEF);
      vst1q_u8(c4 + 0, vout4x0123456789ABCDEF);

      c0 = (uint8_t*) ((uintptr_t) c0 + cn_stride);
      c1 = (uint8_t*) ((uintptr_t) c1 + cn_stride);
      c2 = (uint8_t*) ((uintptr_t) c2 + cn_stride);
      c3 = (uint8_t*) ((uintptr_t) c3 + cn_stride);
      c4 = (uint8_t*) ((uintptr_t) c4 + cn_stride);

      a0 = (const uint8_t*) ((uintptr_t) a0 - kc);
      a1 = (const uint8_t*) ((uintptr_t) a1 - kc);
      a2 = (const uint8_t*) ((uintptr_t) a2 - kc);
      a3 = (const uint8_t*) ((uintptr_t) a3 - kc);
      a4 = (const uint8_t*) ((uintptr_t) a4 - kc);

      nc -= 16;
    } else {
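      // Remainder store path (nc < 16): rows are paired into q registers so
      // one half can be stored per row, then written in progressively
      // smaller pieces (8, 4, 2, 1 bytes) while vext rotates the next piece
      // into position.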
      uint8x16_t vout0x01234567_1x01234567 = vcombine_u8(vget_low_u8(vout0x0123456789ABCDEF), vget_low_u8(vout1x0123456789ABCDEF));
      uint8x16_t vout2x01234567_3x01234567 = vcombine_u8(vget_low_u8(vout2x0123456789ABCDEF), vget_low_u8(vout3x0123456789ABCDEF));
      uint8x8_t vout4x01234567 = vget_low_u8(vout4x0123456789ABCDEF);
      if (nc & 8) {
        vst1_u8(c0, vget_low_u8(vout0x01234567_1x01234567)); c0 += 8;
        vst1_u8(c1, vget_high_u8(vout0x01234567_1x01234567)); c1 += 8;
        vst1_u8(c2, vget_low_u8(vout2x01234567_3x01234567)); c2 += 8;
        vst1_u8(c3, vget_high_u8(vout2x01234567_3x01234567)); c3 += 8;
        vst1_u8(c4, vout4x01234567); c4 += 8;
        vout0x01234567_1x01234567 = vcombine_u8(vget_high_u8(vout0x0123456789ABCDEF), vget_high_u8(vout1x0123456789ABCDEF));
        vout2x01234567_3x01234567 = vcombine_u8(vget_high_u8(vout2x0123456789ABCDEF), vget_high_u8(vout3x0123456789ABCDEF));
        vout4x01234567 = vget_high_u8(vout4x0123456789ABCDEF);
      }
      if (nc & 4) {
        vst1q_lane_u32(__builtin_assume_aligned(c0, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 0); c0 += 4;
        vst1q_lane_u32(__builtin_assume_aligned(c1, 1), vreinterpretq_u32_u8(vout0x01234567_1x01234567), 2); c1 += 4;
        vst1q_lane_u32(__builtin_assume_aligned(c2, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 0); c2 += 4;
        vst1q_lane_u32(__builtin_assume_aligned(c3, 1), vreinterpretq_u32_u8(vout2x01234567_3x01234567), 2); c3 += 4;
        vst1_lane_u32(__builtin_assume_aligned(c4, 1), vreinterpret_u32_u8(vout4x01234567), 0); c4 += 4;
        vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 4);
        vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 4);
        vout4x01234567 = vext_u8(vout4x01234567, vout4x01234567, 4);
      }
      if (nc & 2) {
        vst1q_lane_u16(__builtin_assume_aligned(c0, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 0); c0 += 2;
        vst1q_lane_u16(__builtin_assume_aligned(c1, 1), vreinterpretq_u16_u8(vout0x01234567_1x01234567), 4); c1 += 2;
        vst1q_lane_u16(__builtin_assume_aligned(c2, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 0); c2 += 2;
        vst1q_lane_u16(__builtin_assume_aligned(c3, 1), vreinterpretq_u16_u8(vout2x01234567_3x01234567), 4); c3 += 2;
        vst1_lane_u16(__builtin_assume_aligned(c4, 1), vreinterpret_u16_u8(vout4x01234567), 0); c4 += 2;
        vout0x01234567_1x01234567 = vextq_u8(vout0x01234567_1x01234567, vout0x01234567_1x01234567, 2);
        vout2x01234567_3x01234567 = vextq_u8(vout2x01234567_3x01234567, vout2x01234567_3x01234567, 2);
        vout4x01234567 = vext_u8(vout4x01234567, vout4x01234567, 2);
      }
      if (nc & 1) {
        vst1q_lane_u8(c0, vout0x01234567_1x01234567, 0);
        vst1q_lane_u8(c1, vout0x01234567_1x01234567, 8);
        vst1q_lane_u8(c2, vout2x01234567_3x01234567, 0);
        vst1q_lane_u8(c3, vout2x01234567_3x01234567, 8);
        vst1_lane_u8(c4, vout4x01234567, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}