1// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#include <assert.h>
7
8#include <immintrin.h>
9
10#include <xnnpack/common.h>
11#include <xnnpack/dwconv.h>
12#include <xnnpack/gemm.h>
13#include <xnnpack/igemm.h>
14#include <xnnpack/math.h>
15#include <xnnpack/vmulcaddc.h>
16#include <xnnpack/vunary.h>
17
18
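// Depthwise convolution, f16, 4-tap kernel, minmax activation.
// Processes up to 16 channels per loop iteration; inputs, weights, and outputs
// are IEEE half-precision values stored as uint16_t, widened to single
// precision with F16C conversions so the multiply-accumulate can use FMA3.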
19void xnn_f16_dwconv_minmax_ukernel_up16x4__fma3(
20 size_t channels,
21 size_t output_width,
22 const void** input,
23 const void* weights,
24 void* output,
25 size_t input_stride,
26 size_t output_increment,
27 size_t input_offset,
28 const void* zero,
29 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
30{
31 assert(channels != 0);
32 assert(output_width != 0);
33
34 const __m256 vmax = _mm256_load_ps(params->avx.max);
35 const __m256 vmin = _mm256_load_ps(params->avx.min);
36
37 uint16_t* o = (uint16_t*) output;
38 do {
39 const uint16_t* i0 = input[0];
40 assert(i0 != NULL);
41 if XNN_UNPREDICTABLE(i0 != zero) {
42 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
43 }
44 const uint16_t* i1 = input[1];
45 assert(i1 != NULL);
46 if XNN_UNPREDICTABLE(i1 != zero) {
47 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
48 }
49 const uint16_t* i2 = input[2];
50 assert(i2 != NULL);
51 if XNN_UNPREDICTABLE(i2 != zero) {
52 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
53 }
54 const uint16_t* i3 = input[3];
55 assert(i3 != NULL);
56 if XNN_UNPREDICTABLE(i3 != zero) {
57 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
58 }
59 input = (const void**) ((uintptr_t) input + input_stride);
60
61 size_t c = channels;
62 const uint16_t* w = weights;
63 for (; c >= 16; c -= 16) {
64 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
65 __m256 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
66
67
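      // Each tap: widen the half-precision inputs and weights to f32, fuse the
      // multiply-add, then round the accumulator back to half precision and
      // widen it again so intermediate results keep f16 precision.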
68 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
69 const __m256 vi0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
70 i0 += 16;
71
72 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
73 const __m256 vk0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
74 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
75 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
76
77 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
78 const __m256 vi1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
79 i1 += 16;
80
81 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
82 const __m256 vk1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
83 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
84 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
85
86 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
87 const __m256 vi2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + 8)));
88 i2 += 16;
89
90 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
91 const __m256 vk2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
92 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
93 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
94
95 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
96 const __m256 vi3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i3 + 8)));
97 i3 += 16;
98
99 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 64)));
100 const __m256 vk3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 72)));
101 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
102 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
103
104 w += 80;
105
106
107 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
108 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
109 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
110 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
111
112 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
113 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vacc89ABCDEF, _MM_FROUND_NO_EXC));
114 o += 16;
115 }
116 for (; c >= 8; c -= 8) {
117 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
118
119 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
120 i0 += 8;
121
122 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
123 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
124
125 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
126 i1 += 8;
127
128 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
129 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
130
131 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
132 i2 += 8;
133
134 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
135 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
136
137 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
138 i3 += 8;
139
140 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
141 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
142
143 w += 8;
144
145
146 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
147 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
148
149 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
150 o += 8;
151 }
152 if XNN_UNLIKELY(c != 0) {
153 assert(c >= 1);
154 assert(c <= 7);
155
156 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
157
158 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
159
160 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
161 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
162
163 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
164
165 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
166 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
167
168 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
169
170 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
171 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
172
173 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
174
175 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
176 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
177
178
179 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
180 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
181
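      // Only c (1-7) channels are valid here; the full-width loads above rely on
      // the kernel's XNN_OOB_READS allowance. Convert the clamped results back to
      // half precision and store them in 4-, 2-, and 1-element pieces.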
182 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
183 if (c & 4) {
184 _mm_storel_epi64((__m128i*) o, vh01234567);
185 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
186 o += 4;
187 }
188 if (c & 2) {
189 *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vh01234567);
190 vh01234567 = _mm_srli_epi64(vh01234567, 32);
191 o += 2;
192 }
193 if (c & 1) {
194 *((uint16_t*) o) = (uint16_t) _mm_extract_epi16(vh01234567, 0);
195 o += 1;
196 }
197 }
198
199 o = (uint16_t*) ((uintptr_t) o + output_increment);
200 } while (--output_width != 0);
201}
202
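// Depthwise convolution, f16, 9-tap kernel (e.g. a 3x3 window), minmax
// activation; same structure as the 4-tap variant above, up to 16 channels
// per loop iteration.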
203void xnn_f16_dwconv_minmax_ukernel_up16x9__fma3(
204 size_t channels,
205 size_t output_width,
206 const void** input,
207 const void* weights,
208 void* output,
209 size_t input_stride,
210 size_t output_increment,
211 size_t input_offset,
212 const void* zero,
213 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
214{
215 assert(channels != 0);
216 assert(output_width != 0);
217
218 const __m256 vmax = _mm256_load_ps(params->avx.max);
219 const __m256 vmin = _mm256_load_ps(params->avx.min);
220
221 uint16_t* o = (uint16_t*) output;
222 do {
223 const uint16_t* i0 = input[0];
224 assert(i0 != NULL);
225 if XNN_UNPREDICTABLE(i0 != zero) {
226 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
227 }
228 const uint16_t* i1 = input[1];
229 assert(i1 != NULL);
230 if XNN_UNPREDICTABLE(i1 != zero) {
231 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
232 }
233 const uint16_t* i2 = input[2];
234 assert(i2 != NULL);
235 if XNN_UNPREDICTABLE(i2 != zero) {
236 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
237 }
238 const uint16_t* i3 = input[3];
239 assert(i3 != NULL);
240 if XNN_UNPREDICTABLE(i3 != zero) {
241 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
242 }
243 const uint16_t* i4 = input[4];
244 assert(i4 != NULL);
245 if XNN_UNPREDICTABLE(i4 != zero) {
246 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
247 }
248 const uint16_t* i5 = input[5];
249 assert(i5 != NULL);
250 if XNN_UNPREDICTABLE(i5 != zero) {
251 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
252 }
253 const uint16_t* i6 = input[6];
254 assert(i6 != NULL);
255 if XNN_UNPREDICTABLE(i6 != zero) {
256 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
257 }
258 const uint16_t* i7 = input[7];
259 assert(i7 != NULL);
260 if XNN_UNPREDICTABLE(i7 != zero) {
261 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
262 }
263 const uint16_t* i8 = input[8];
264 assert(i8 != NULL);
265 if XNN_UNPREDICTABLE(i8 != zero) {
266 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
267 }
268 input = (const void**) ((uintptr_t) input + input_stride);
269
270 size_t c = channels;
271 const uint16_t* w = weights;
272 for (; c >= 16; c -= 16) {
273 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
274 __m256 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
275
276
277 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
278 const __m256 vi0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i0 + 8)));
279 i0 += 16;
280
281 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
282 const __m256 vk0x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
283 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
284 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
285
286 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
287 const __m256 vi1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i1 + 8)));
288 i1 += 16;
289
290 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
291 const __m256 vk1x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
292 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
293 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
294
295 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
296 const __m256 vi2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i2 + 8)));
297 i2 += 16;
298
299 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
300 const __m256 vk2x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
301 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
302 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
303
304 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
305 const __m256 vi3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i3 + 8)));
306 i3 += 16;
307
308 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 64)));
309 const __m256 vk3x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 72)));
310 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
311 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
312
313 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
314 const __m256 vi4x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i4 + 8)));
315 i4 += 16;
316
317 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 80)));
318 const __m256 vk4x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 88)));
319 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
320 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x89ABCDEF, vk4x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
321
322 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
323 const __m256 vi5x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i5 + 8)));
324 i5 += 16;
325
326 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 96)));
327 const __m256 vk5x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 104)));
328 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
329 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x89ABCDEF, vk5x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
330
331 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
332 const __m256 vi6x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i6 + 8)));
333 i6 += 16;
334
335 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 112)));
336 const __m256 vk6x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 120)));
337 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
338 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x89ABCDEF, vk6x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
339
340 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
341 const __m256 vi7x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i7 + 8)));
342 i7 += 16;
343
344 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 128)));
345 const __m256 vk7x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 136)));
346 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
347 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x89ABCDEF, vk7x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
348
349 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
350 const __m256 vi8x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i8 + 8)));
351 i8 += 16;
352
353 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 144)));
354 const __m256 vk8x89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 152)));
355 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
356 vacc89ABCDEFp0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x89ABCDEF, vk8x89ABCDEF, vacc89ABCDEFp0), _MM_FROUND_NO_EXC));
357
358 w += 160;
359
360
361 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
362 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
363 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
364 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
365
366 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
367 _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vacc89ABCDEF, _MM_FROUND_NO_EXC));
368 o += 16;
369 }
370 for (; c >= 8; c -= 8) {
371 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
372
373 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
374 i0 += 8;
375
376 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
377 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
378
379 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
380 i1 += 8;
381
382 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
383 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
384
385 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
386 i2 += 8;
387
388 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
389 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
390
391 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
392 i3 += 8;
393
394 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
395 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
396
397 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
398 i4 += 8;
399
400 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 80)));
401 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
402
403 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
404 i5 += 8;
405
406 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 96)));
407 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
408
409 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
410 i6 += 8;
411
412 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 112)));
413 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
414
415 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
416 i7 += 8;
417
418 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 128)));
419 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
420
421 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
422 i8 += 8;
423
424 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 144)));
425 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
426
427 w += 8;
428
429
430 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
431 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
432
433 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
434 o += 8;
435 }
436 if XNN_UNLIKELY(c != 0) {
437 assert(c >= 1);
438 assert(c <= 7);
439
440 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
441
442 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
443
444 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
445 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
446
447 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
448
449 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
450 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
451
452 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
453
454 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
455 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
456
457 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
458
459 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
460 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
461
462 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
463
464 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 80)));
465 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
466
467 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
468
469 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 96)));
470 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
471
472 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
473
474 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 112)));
475 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
476
477 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
478
479 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 128)));
480 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
481
482 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
483
484 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 144)));
485 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
486
487
488 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
489 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
490
491 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
492 if (c & 4) {
493 _mm_storel_epi64((__m128i*) o, vh01234567);
494 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
495 o += 4;
496 }
497 if (c & 2) {
498 *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vh01234567);
499 vh01234567 = _mm_srli_epi64(vh01234567, 32);
500 o += 2;
501 }
502 if (c & 1) {
503 *((uint16_t*) o) = (uint16_t) _mm_extract_epi16(vh01234567, 0);
504 o += 1;
505 }
506 }
507
508 o = (uint16_t*) ((uintptr_t) o + output_increment);
509 } while (--output_width != 0);
510}
511
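// Depthwise convolution, f16, 25-tap kernel (e.g. a 5x5 window), minmax
// activation, 8 channels per loop iteration. Uses two accumulators (acc2)
// that alternate between taps and are summed at the end, which shortens the
// dependency chain of the fused multiply-adds.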
512void xnn_f16_dwconv_minmax_ukernel_up8x25__fma3_acc2(
513 size_t channels,
514 size_t output_width,
515 const void** input,
516 const void* weights,
517 void* output,
518 size_t input_stride,
519 size_t output_increment,
520 size_t input_offset,
521 const void* zero,
522 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
523{
524 assert(channels != 0);
525 assert(output_width != 0);
526
527 const __m256 vmax = _mm256_load_ps(params->avx.max);
528 const __m256 vmin = _mm256_load_ps(params->avx.min);
529
530 uint16_t* o = (uint16_t*) output;
531 do {
532 const uint16_t* i0 = input[0];
533 assert(i0 != NULL);
534 if XNN_UNPREDICTABLE(i0 != zero) {
535 i0 = (const uint16_t*) ((uintptr_t) i0 + input_offset);
536 }
537 const uint16_t* i1 = input[1];
538 assert(i1 != NULL);
539 if XNN_UNPREDICTABLE(i1 != zero) {
540 i1 = (const uint16_t*) ((uintptr_t) i1 + input_offset);
541 }
542 const uint16_t* i2 = input[2];
543 assert(i2 != NULL);
544 if XNN_UNPREDICTABLE(i2 != zero) {
545 i2 = (const uint16_t*) ((uintptr_t) i2 + input_offset);
546 }
547 const uint16_t* i3 = input[3];
548 assert(i3 != NULL);
549 if XNN_UNPREDICTABLE(i3 != zero) {
550 i3 = (const uint16_t*) ((uintptr_t) i3 + input_offset);
551 }
552 const uint16_t* i4 = input[4];
553 assert(i4 != NULL);
554 if XNN_UNPREDICTABLE(i4 != zero) {
555 i4 = (const uint16_t*) ((uintptr_t) i4 + input_offset);
556 }
557 const uint16_t* i5 = input[5];
558 assert(i5 != NULL);
559 if XNN_UNPREDICTABLE(i5 != zero) {
560 i5 = (const uint16_t*) ((uintptr_t) i5 + input_offset);
561 }
562 const uint16_t* i6 = input[6];
563 assert(i6 != NULL);
564 if XNN_UNPREDICTABLE(i6 != zero) {
565 i6 = (const uint16_t*) ((uintptr_t) i6 + input_offset);
566 }
567 const uint16_t* i7 = input[7];
568 assert(i7 != NULL);
569 if XNN_UNPREDICTABLE(i7 != zero) {
570 i7 = (const uint16_t*) ((uintptr_t) i7 + input_offset);
571 }
572 const uint16_t* i8 = input[8];
573 assert(i8 != NULL);
574 if XNN_UNPREDICTABLE(i8 != zero) {
575 i8 = (const uint16_t*) ((uintptr_t) i8 + input_offset);
576 }
577 const uint16_t* i9 = input[9];
578 assert(i9 != NULL);
579 if XNN_UNPREDICTABLE(i9 != zero) {
580 i9 = (const uint16_t*) ((uintptr_t) i9 + input_offset);
581 }
582 const uint16_t* i10 = input[10];
583 assert(i10 != NULL);
584 if XNN_UNPREDICTABLE(i10 != zero) {
585 i10 = (const uint16_t*) ((uintptr_t) i10 + input_offset);
586 }
587 const uint16_t* i11 = input[11];
588 assert(i11 != NULL);
589 if XNN_UNPREDICTABLE(i11 != zero) {
590 i11 = (const uint16_t*) ((uintptr_t) i11 + input_offset);
591 }
592 const uint16_t* i12 = input[12];
593 assert(i12 != NULL);
594 if XNN_UNPREDICTABLE(i12 != zero) {
595 i12 = (const uint16_t*) ((uintptr_t) i12 + input_offset);
596 }
597 const uint16_t* i13 = input[13];
598 assert(i13 != NULL);
599 if XNN_UNPREDICTABLE(i13 != zero) {
600 i13 = (const uint16_t*) ((uintptr_t) i13 + input_offset);
601 }
602 const uint16_t* i14 = input[14];
603 assert(i14 != NULL);
604 if XNN_UNPREDICTABLE(i14 != zero) {
605 i14 = (const uint16_t*) ((uintptr_t) i14 + input_offset);
606 }
607 const uint16_t* i15 = input[15];
608 assert(i15 != NULL);
609 if XNN_UNPREDICTABLE(i15 != zero) {
610 i15 = (const uint16_t*) ((uintptr_t) i15 + input_offset);
611 }
612 const uint16_t* i16 = input[16];
613 assert(i16 != NULL);
614 if XNN_UNPREDICTABLE(i16 != zero) {
615 i16 = (const uint16_t*) ((uintptr_t) i16 + input_offset);
616 }
617 const uint16_t* i17 = input[17];
618 assert(i17 != NULL);
619 if XNN_UNPREDICTABLE(i17 != zero) {
620 i17 = (const uint16_t*) ((uintptr_t) i17 + input_offset);
621 }
622 const uint16_t* i18 = input[18];
623 assert(i18 != NULL);
624 if XNN_UNPREDICTABLE(i18 != zero) {
625 i18 = (const uint16_t*) ((uintptr_t) i18 + input_offset);
626 }
627 const uint16_t* i19 = input[19];
628 assert(i19 != NULL);
629 if XNN_UNPREDICTABLE(i19 != zero) {
630 i19 = (const uint16_t*) ((uintptr_t) i19 + input_offset);
631 }
632 const uint16_t* i20 = input[20];
633 assert(i20 != NULL);
634 if XNN_UNPREDICTABLE(i20 != zero) {
635 i20 = (const uint16_t*) ((uintptr_t) i20 + input_offset);
636 }
637 const uint16_t* i21 = input[21];
638 assert(i21 != NULL);
639 if XNN_UNPREDICTABLE(i21 != zero) {
640 i21 = (const uint16_t*) ((uintptr_t) i21 + input_offset);
641 }
642 const uint16_t* i22 = input[22];
643 assert(i22 != NULL);
644 if XNN_UNPREDICTABLE(i22 != zero) {
645 i22 = (const uint16_t*) ((uintptr_t) i22 + input_offset);
646 }
647 const uint16_t* i23 = input[23];
648 assert(i23 != NULL);
649 if XNN_UNPREDICTABLE(i23 != zero) {
650 i23 = (const uint16_t*) ((uintptr_t) i23 + input_offset);
651 }
652 const uint16_t* i24 = input[24];
653 assert(i24 != NULL);
654 if XNN_UNPREDICTABLE(i24 != zero) {
655 i24 = (const uint16_t*) ((uintptr_t) i24 + input_offset);
656 }
657 input = (const void**) ((uintptr_t) input + input_stride);
658
659 size_t c = channels;
660 const uint16_t* w = weights;
661 for (; c >= 8; c -= 8) {
662 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
663
664
665 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
666 i0 += 8;
667
668 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
669 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
670
671 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
672 i1 += 8;
673
674 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 16)));
675 __m256 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi1x01234567, vk1x01234567), _MM_FROUND_NO_EXC));
676
677 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
678 i2 += 8;
679
680 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 24)));
681 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
682
683 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
684 i3 += 8;
685
686 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 32)));
687 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
688
689 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
690 i4 += 8;
691
692 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 40)));
693 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
694
695 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
696 i5 += 8;
697
698 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 48)));
699 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
700
701 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
702 i6 += 8;
703
704 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 56)));
705 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
706
707 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
708 i7 += 8;
709
710 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 64)));
711 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
712
713 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
714 i8 += 8;
715
716 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 72)));
717 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
718
719 const __m256 vi9x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i9));
720 i9 += 8;
721
722 const __m256 vk9x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 80)));
723 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
724
725 const __m256 vi10x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i10));
726 i10 += 8;
727
728 const __m256 vk10x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 88)));
729 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
730
731 const __m256 vi11x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i11));
732 i11 += 8;
733
734 const __m256 vk11x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 96)));
735 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
736
737 const __m256 vi12x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i12));
738 i12 += 8;
739
740 const __m256 vk12x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 104)));
741 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
742
743 const __m256 vi13x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i13));
744 i13 += 8;
745
746 const __m256 vk13x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 112)));
747 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
748
749 const __m256 vi14x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i14));
750 i14 += 8;
751
752 const __m256 vk14x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 120)));
753 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
754
755 const __m256 vi15x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i15));
756 i15 += 8;
757
758 const __m256 vk15x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 128)));
759 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
760
761 const __m256 vi16x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i16));
762 i16 += 8;
763
764 const __m256 vk16x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 136)));
765 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
766
767 const __m256 vi17x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i17));
768 i17 += 8;
769
770 const __m256 vk17x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 144)));
771 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
772
773 const __m256 vi18x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i18));
774 i18 += 8;
775
776 const __m256 vk18x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 152)));
777 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
778
779 const __m256 vi19x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i19));
780 i19 += 8;
781
782 const __m256 vk19x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 160)));
783 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
784
785 const __m256 vi20x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i20));
786 i20 += 8;
787
788 const __m256 vk20x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 168)));
789 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
790
791 const __m256 vi21x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i21));
792 i21 += 8;
793
794 const __m256 vk21x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 176)));
795 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
796
797 const __m256 vi22x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i22));
798 i22 += 8;
799
800 const __m256 vk22x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 184)));
801 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
802
803 const __m256 vi23x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i23));
804 i23 += 8;
805
806 const __m256 vk23x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 192)));
807 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
808
809 const __m256 vi24x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i24));
810 i24 += 8;
811
812 const __m256 vk24x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 200)));
813 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
814
815 w += 208;
816
817 // Add up all accumulators to vacc01234567p0
818 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p0, vacc01234567p1), _MM_FROUND_NO_EXC));
819
820 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
821 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
822
823 _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC));
824 o += 8;
825 }
826 if XNN_UNLIKELY(c != 0) {
827 assert(c >= 1);
828 assert(c <= 7);
829
830 __m256 vacc01234567p0 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) w));
831
832 const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
833
834 const __m256 vk0x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 8)));
835 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
836
837 const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
838
839 const __m256 vk1x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 16)));
840 __m256 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vi1x01234567, vk1x01234567), _MM_FROUND_NO_EXC));
841
842 const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
843
844 const __m256 vk2x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 24)));
845 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
846
847 const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
848
849 const __m256 vk3x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 32)));
850 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
851
852 const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
853
854 const __m256 vk4x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 40)));
855 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
856
857 const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
858
859 const __m256 vk5x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 48)));
860 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
861
862 const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
863
864 const __m256 vk6x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 56)));
865 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
866
867 const __m256 vi7x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i7));
868
869 const __m256 vk7x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 64)));
870 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
871
872 const __m256 vi8x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i8));
873
874 const __m256 vk8x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 72)));
875 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
876
877 const __m256 vi9x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i9));
878
879 const __m256 vk9x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 80)));
880 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
881
882 const __m256 vi10x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i10));
883
884 const __m256 vk10x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 88)));
885 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
886
887 const __m256 vi11x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i11));
888
889 const __m256 vk11x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 96)));
890 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
891
892 const __m256 vi12x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i12));
893
894 const __m256 vk12x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 104)));
895 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
896
897 const __m256 vi13x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i13));
898
899 const __m256 vk13x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 112)));
900 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
901
902 const __m256 vi14x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i14));
903
904 const __m256 vk14x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 120)));
905 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
906
907 const __m256 vi15x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i15));
908
909 const __m256 vk15x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 128)));
910 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
911
912 const __m256 vi16x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i16));
913
914 const __m256 vk16x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 136)));
915 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
916
917 const __m256 vi17x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i17));
918
919 const __m256 vk17x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 144)));
920 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
921
922 const __m256 vi18x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i18));
923
924 const __m256 vk18x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 152)));
925 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
926
927 const __m256 vi19x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i19));
928
929 const __m256 vk19x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 160)));
930 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
931
932 const __m256 vi20x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i20));
933
934 const __m256 vk20x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 168)));
935 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
936
937 const __m256 vi21x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i21));
938
939 const __m256 vk21x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 176)));
940 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
941
942 const __m256 vi22x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i22));
943
944 const __m256 vk22x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 184)));
945 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
946
947 const __m256 vi23x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i23));
948
949 const __m256 vk23x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 192)));
950 vacc01234567p1 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p1), _MM_FROUND_NO_EXC));
951
952 const __m256 vi24x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i24));
953
954 const __m256 vk24x01234567 = _mm256_cvtph_ps(_mm_load_si128((const __m128i*) (w + 200)));
955 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0), _MM_FROUND_NO_EXC));
956
957 // Add up all accumulators to vacc01234567p0
958 vacc01234567p0 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(vacc01234567p0, vacc01234567p1), _MM_FROUND_NO_EXC));
959
960 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
961 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
962
963 __m128i vh01234567 = _mm256_cvtps_ph(vacc01234567, _MM_FROUND_NO_EXC);
964 if (c & 4) {
965 _mm_storel_epi64((__m128i*) o, vh01234567);
966 vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
967 o += 4;
968 }
969 if (c & 2) {
970 *((uint32_t*) o) = (uint32_t) _mm_cvtsi128_si32(vh01234567);
971 vh01234567 = _mm_srli_epi64(vh01234567, 32);
972 o += 2;
973 }
974 if (c & 1) {
975 *((uint16_t*) o) = (uint16_t) _mm_extract_epi16(vh01234567, 0);
976 o += 1;
977 }
978 }
979
980 o = (uint16_t*) ((uintptr_t) o + output_increment);
981 } while (--output_width != 0);
982}
983
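// Fused per-channel multiply-add (y = x * scale + bias) with minmax clamping,
// f16, 8 channels per loop iteration, two rows at a time.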
984void xnn_f16_vmulcaddc_minmax_ukernel_c8__fma3_2x(
985 size_t rows,
986 size_t channels,
987 const void*restrict input,
988 size_t input_stride,
989 const void*restrict weights,
990 void*restrict output,
991 size_t output_stride,
992 const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
993{
994 assert(rows != 0);
995 assert(channels != 0);
996 assert(channels % sizeof(uint16_t) == 0);
997
998 const uint16_t* i0 = (const uint16_t*) input;
999 uint16_t* o0 = (uint16_t*) output;
1000 const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride);
1001 uint16_t* o1 = (uint16_t*) ((uintptr_t) o0 + output_stride);
1002
1003 const size_t input_increment = input_stride * 2 - channels;
1004 const size_t output_increment = output_stride * 2 - channels;
1005
1006 const __m256 vmin = _mm256_load_ps(params->avx.min);
1007 const __m256 vmax = _mm256_load_ps(params->avx.max);
1008 do {
1009 if XNN_UNPREDICTABLE(rows < 2) {
1010 i1 = i0;
1011 o1 = o0;
1012 }
1013
1014 const uint16_t* w = (const uint16_t*) weights;
1015 size_t c = channels;
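    // Weights are packed as 8 half-precision scales followed by 8 biases per
    // 8-channel group.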
1016 for (; c >= 8 * sizeof(uint16_t); c -= 8 * sizeof(uint16_t)) {
1017 const __m256 vscale = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
1018
1019 __m256 vacc0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1020 i0 += 8;
1021 __m256 vacc1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1022 i1 += 8;
1023
1024 const __m256 vbias = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
1025 w += 16;
1026
1027 vacc0 = _mm256_fmadd_ps(vacc0, vscale, vbias);
1028 vacc1 = _mm256_fmadd_ps(vacc1, vscale, vbias);
1029
1030 vacc0 = _mm256_max_ps(vacc0, vmin);
1031 vacc1 = _mm256_max_ps(vacc1, vmin);
1032
1033 vacc0 = _mm256_min_ps(vacc0, vmax);
1034 vacc1 = _mm256_min_ps(vacc1, vmax);
1035
1036 _mm_storeu_si128((__m128i*) o0, _mm256_cvtps_ph(vacc0, _MM_FROUND_NO_EXC));
1037 o0 += 8;
1038 _mm_storeu_si128((__m128i*) o1, _mm256_cvtps_ph(vacc1, _MM_FROUND_NO_EXC));
1039 o1 += 8;
1040 }
1041 if XNN_UNLIKELY(c != 0) {
1042 const __m256 vscale = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) w));
1043
1044 __m256 vacc0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
1045 i0 = (const uint16_t*) ((uintptr_t) i0 + c);
1046 __m256 vacc1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
1047 i1 = (const uint16_t*) ((uintptr_t) i1 + c);
1048
1049 const __m256 vbias = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (w + 8)));
1050
1051 vacc0 = _mm256_fmadd_ps(vacc0, vscale, vbias);
1052 vacc1 = _mm256_fmadd_ps(vacc1, vscale, vbias);
1053
1054 vacc0 = _mm256_max_ps(vacc0, vmin);
1055 vacc1 = _mm256_max_ps(vacc1, vmin);
1056
1057 vacc0 = _mm256_min_ps(vacc0, vmax);
1058 vacc1 = _mm256_min_ps(vacc1, vmax);
1059
1060 __m128i vh0 = _mm256_cvtps_ph(vacc0, _MM_FROUND_NO_EXC);
1061 __m128i vh1 = _mm256_cvtps_ph(vacc1, _MM_FROUND_NO_EXC);
1062
1063 if (c & (4 * sizeof(uint16_t))) {
1064 _mm_storel_epi64((__m128i*) o0, vh0);
1065 _mm_storel_epi64((__m128i*) o1, vh1);
1066
1067 vh0 = _mm_unpackhi_epi64(vh0, vh0);
1068 vh1 = _mm_unpackhi_epi64(vh1, vh1);
1069
1070 o0 += 4;
1071 o1 += 4;
1072 }
1073 if (c & (2 * sizeof(uint16_t))) {
1074 *((uint32_t*) o0) = (uint32_t) _mm_cvtsi128_si32(vh0);
1075 *((uint32_t*) o1) = (uint32_t) _mm_cvtsi128_si32(vh1);
1076
1077 vh0 = _mm_srli_epi64(vh0, 32);
1078 vh1 = _mm_srli_epi64(vh1, 32);
1079
1080 o0 += 2;
1081 o1 += 2;
1082 }
1083 if (c & (1 * sizeof(uint16_t))) {
1084 *o0 = (uint16_t) _mm_extract_epi16(vh0, 0);
1085 *o1 = (uint16_t) _mm_extract_epi16(vh1, 0);
1086
1087 o0 += 1;
1088 o1 += 1;
1089 }
1090 }
1091 i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment);
1092 o0 = (uint16_t*) ((uintptr_t) o0 + output_increment);
1093 i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment);
1094 o1 = (uint16_t*) ((uintptr_t) o1 + output_increment);
1095 rows = doz(rows, 2);
1096 } while (rows != 0);
1097}
1098
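// Depthwise convolution, f32, 3-tap kernel, minmax activation, up to
// 16 channels per loop iteration; plain single-precision FMA3, no f16
// conversions.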
1099void xnn_f32_dwconv_minmax_ukernel_up16x3__fma3(
1100 size_t channels,
1101 size_t output_width,
1102 const float** input,
1103 const float* weights,
1104 float* output,
1105 size_t input_stride,
1106 size_t output_increment,
1107 size_t input_offset,
1108 const float* zero,
1109 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1110{
1111 assert(channels != 0);
1112 assert(output_width != 0);
1113
1114 const __m256 vmax = _mm256_load_ps(params->avx.max);
1115 const __m256 vmin = _mm256_load_ps(params->avx.min);
1116 do {
1117 const float* i0 = input[0];
1118 assert(i0 != NULL);
1119 if XNN_UNPREDICTABLE(i0 != zero) {
1120 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1121 }
1122 const float* i1 = input[1];
1123 assert(i1 != NULL);
1124 if XNN_UNPREDICTABLE(i1 != zero) {
1125 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1126 }
1127 const float* i2 = input[2];
1128 assert(i2 != NULL);
1129 if XNN_UNPREDICTABLE(i2 != zero) {
1130 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1131 }
1132 input = (const float**) ((uintptr_t) input + input_stride);
1133
1134 size_t c = channels;
1135 const float* w = weights;
1136 for (; c >= 16; c -= 16) {
1137 __m256 vacc01234567p0 = _mm256_load_ps(w);
1138 __m256 vacc89ABCDEFp0 = _mm256_load_ps(w + 8);
1139
1140
1141 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1142 const __m256 vi0x89ABCDEF = _mm256_loadu_ps(i0 + 8);
1143 i0 += 16;
1144
1145 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1146 const __m256 vk0x89ABCDEF = _mm256_load_ps(w + 24);
1147 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1148 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0);
1149
1150 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1151 const __m256 vi1x89ABCDEF = _mm256_loadu_ps(i1 + 8);
1152 i1 += 16;
1153
1154 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1155 const __m256 vk1x89ABCDEF = _mm256_load_ps(w + 40);
1156 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1157 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0);
1158
1159 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1160 const __m256 vi2x89ABCDEF = _mm256_loadu_ps(i2 + 8);
1161 i2 += 16;
1162
1163 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1164 const __m256 vk2x89ABCDEF = _mm256_load_ps(w + 56);
1165 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1166 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0);
1167
1168 w += 64;
1169
1170
1171 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1172 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
1173 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1174 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
1175
1176 _mm256_storeu_ps(output, vacc01234567);
1177 _mm256_storeu_ps(output + 8, vacc89ABCDEF);
1178 output += 16;
1179 }
1180 for (; c >= 8; c -= 8) {
1181 __m256 vacc01234567p0 = _mm256_load_ps(w);
1182
1183 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1184 i0 += 8;
1185
1186 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1187 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1188
1189 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1190 i1 += 8;
1191
1192 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1193 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1194
1195 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1196 i2 += 8;
1197
1198 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1199 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1200
1201 w += 8;
1202
1203
1204 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1205 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1206
1207 _mm256_storeu_ps(output, vacc01234567);
1208 output += 8;
1209 }
1210 if XNN_UNLIKELY(c != 0) {
1211 assert(c >= 1);
1212 assert(c <= 7);
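      // Build a lane mask with the first c of 8 lanes enabled so the masked
      // loads below read only the remaining valid channels.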
1213 const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
1214
1215 __m256 vacc01234567p0 = _mm256_load_ps(w);
1216
1217 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
1218 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1219 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1220
1221 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
1222 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1223 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1224
1225 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
1226 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1227 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1228
1229
1230 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1231 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1232
1233 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
1234 if (c & 4) {
1235 _mm_storeu_ps(output, vacc0123);
1236 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
1237 output += 4;
1238 }
1239 if (c & 2) {
1240 _mm_storel_pi((__m64*) output, vacc0123);
1241 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
1242 output += 2;
1243 }
1244 if (c & 1) {
1245 _mm_store_ss(output, vacc0123);
1246 output += 1;
1247 }
1248 }
1249
1250 output = (float*) ((uintptr_t) output + output_increment);
1251 } while (--output_width != 0);
1252}
1253
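// 4-tap depthwise convolution, FMA3: channels are handled 16 at a time, then 8 at a
// time, with a masked tail for the last 1-7 channels; results are clamped to the
// [min, max] range supplied in params before being stored.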
1254void xnn_f32_dwconv_minmax_ukernel_up16x4__fma3(
1255 size_t channels,
1256 size_t output_width,
1257 const float** input,
1258 const float* weights,
1259 float* output,
1260 size_t input_stride,
1261 size_t output_increment,
1262 size_t input_offset,
1263 const float* zero,
1264 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1265{
1266 assert(channels != 0);
1267 assert(output_width != 0);
1268
1269 const __m256 vmax = _mm256_load_ps(params->avx.max);
1270 const __m256 vmin = _mm256_load_ps(params->avx.min);
1271 do {
1272 const float* i0 = input[0];
1273 assert(i0 != NULL);
1274 if XNN_UNPREDICTABLE(i0 != zero) {
1275 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1276 }
1277 const float* i1 = input[1];
1278 assert(i1 != NULL);
1279 if XNN_UNPREDICTABLE(i1 != zero) {
1280 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1281 }
1282 const float* i2 = input[2];
1283 assert(i2 != NULL);
1284 if XNN_UNPREDICTABLE(i2 != zero) {
1285 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1286 }
1287 const float* i3 = input[3];
1288 assert(i3 != NULL);
1289 if XNN_UNPREDICTABLE(i3 != zero) {
1290 i3 = (const float*) ((uintptr_t) i3 + input_offset);
1291 }
1292 input = (const float**) ((uintptr_t) input + input_stride);
1293
1294 size_t c = channels;
1295 const float* w = weights;
1296 for (; c >= 16; c -= 16) {
1297 __m256 vacc01234567p0 = _mm256_load_ps(w);
1298 __m256 vacc89ABCDEFp0 = _mm256_load_ps(w + 8);
1299
1300
1301 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1302 const __m256 vi0x89ABCDEF = _mm256_loadu_ps(i0 + 8);
1303 i0 += 16;
1304
1305 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1306 const __m256 vk0x89ABCDEF = _mm256_load_ps(w + 24);
1307 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1308 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0);
1309
1310 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1311 const __m256 vi1x89ABCDEF = _mm256_loadu_ps(i1 + 8);
1312 i1 += 16;
1313
1314 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1315 const __m256 vk1x89ABCDEF = _mm256_load_ps(w + 40);
1316 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1317 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0);
1318
1319 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1320 const __m256 vi2x89ABCDEF = _mm256_loadu_ps(i2 + 8);
1321 i2 += 16;
1322
1323 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1324 const __m256 vk2x89ABCDEF = _mm256_load_ps(w + 56);
1325 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1326 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0);
1327
1328 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1329 const __m256 vi3x89ABCDEF = _mm256_loadu_ps(i3 + 8);
1330 i3 += 16;
1331
1332 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1333 const __m256 vk3x89ABCDEF = _mm256_load_ps(w + 72);
1334 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1335 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0);
1336
1337 w += 80;
1338
1339
1340 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1341 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
1342 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1343 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
1344
1345 _mm256_storeu_ps(output, vacc01234567);
1346 _mm256_storeu_ps(output + 8, vacc89ABCDEF);
1347 output += 16;
1348 }
1349 for (; c >= 8; c -= 8) {
1350 __m256 vacc01234567p0 = _mm256_load_ps(w);
1351
1352 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1353 i0 += 8;
1354
1355 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1356 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1357
1358 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1359 i1 += 8;
1360
1361 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1362 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1363
1364 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1365 i2 += 8;
1366
1367 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1368 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1369
1370 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1371 i3 += 8;
1372
1373 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1374 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1375
1376 w += 8;
1377
1378
1379 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1380 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1381
1382 _mm256_storeu_ps(output, vacc01234567);
1383 output += 8;
1384 }
1385 if XNN_UNLIKELY(c != 0) {
1386 assert(c >= 1);
1387 assert(c <= 7);
1388 const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
1389
1390 __m256 vacc01234567p0 = _mm256_load_ps(w);
1391
1392 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
1393 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1394 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1395
1396 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
1397 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1398 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1399
1400 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
1401 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1402 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1403
1404 const __m256 vi3x01234567 = _mm256_maskload_ps(i3, vmask);
1405 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1406 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1407
1408
1409 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1410 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1411
1412 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
1413 if (c & 4) {
1414 _mm_storeu_ps(output, vacc0123);
1415 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
1416 output += 4;
1417 }
1418 if (c & 2) {
1419 _mm_storel_pi((__m64*) output, vacc0123);
1420 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
1421 output += 2;
1422 }
1423 if (c & 1) {
1424 _mm_store_ss(output, vacc0123);
1425 output += 1;
1426 }
1427 }
1428
1429 output = (float*) ((uintptr_t) output + output_increment);
1430 } while (--output_width != 0);
1431}
1432
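// 9-tap variant of the depthwise convolution above: same 16/8/masked-tail channel
// structure, with nine input rows (i0-i8) accumulated per output pixel.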
1433void xnn_f32_dwconv_minmax_ukernel_up16x9__fma3(
1434 size_t channels,
1435 size_t output_width,
1436 const float** input,
1437 const float* weights,
1438 float* output,
1439 size_t input_stride,
1440 size_t output_increment,
1441 size_t input_offset,
1442 const float* zero,
1443 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1444{
1445 assert(channels != 0);
1446 assert(output_width != 0);
1447
1448 const __m256 vmax = _mm256_load_ps(params->avx.max);
1449 const __m256 vmin = _mm256_load_ps(params->avx.min);
1450 do {
1451 const float* i0 = input[0];
1452 assert(i0 != NULL);
1453 if XNN_UNPREDICTABLE(i0 != zero) {
1454 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1455 }
1456 const float* i1 = input[1];
1457 assert(i1 != NULL);
1458 if XNN_UNPREDICTABLE(i1 != zero) {
1459 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1460 }
1461 const float* i2 = input[2];
1462 assert(i2 != NULL);
1463 if XNN_UNPREDICTABLE(i2 != zero) {
1464 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1465 }
1466 const float* i3 = input[3];
1467 assert(i3 != NULL);
1468 if XNN_UNPREDICTABLE(i3 != zero) {
1469 i3 = (const float*) ((uintptr_t) i3 + input_offset);
1470 }
1471 const float* i4 = input[4];
1472 assert(i4 != NULL);
1473 if XNN_UNPREDICTABLE(i4 != zero) {
1474 i4 = (const float*) ((uintptr_t) i4 + input_offset);
1475 }
1476 const float* i5 = input[5];
1477 assert(i5 != NULL);
1478 if XNN_UNPREDICTABLE(i5 != zero) {
1479 i5 = (const float*) ((uintptr_t) i5 + input_offset);
1480 }
1481 const float* i6 = input[6];
1482 assert(i6 != NULL);
1483 if XNN_UNPREDICTABLE(i6 != zero) {
1484 i6 = (const float*) ((uintptr_t) i6 + input_offset);
1485 }
1486 const float* i7 = input[7];
1487 assert(i7 != NULL);
1488 if XNN_UNPREDICTABLE(i7 != zero) {
1489 i7 = (const float*) ((uintptr_t) i7 + input_offset);
1490 }
1491 const float* i8 = input[8];
1492 assert(i8 != NULL);
1493 if XNN_UNPREDICTABLE(i8 != zero) {
1494 i8 = (const float*) ((uintptr_t) i8 + input_offset);
1495 }
1496 input = (const float**) ((uintptr_t) input + input_stride);
1497
1498 size_t c = channels;
1499 const float* w = weights;
1500 for (; c >= 16; c -= 16) {
1501 __m256 vacc01234567p0 = _mm256_load_ps(w);
1502 __m256 vacc89ABCDEFp0 = _mm256_load_ps(w + 8);
1503
1504
1505 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1506 const __m256 vi0x89ABCDEF = _mm256_loadu_ps(i0 + 8);
1507 i0 += 16;
1508
1509 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1510 const __m256 vk0x89ABCDEF = _mm256_load_ps(w + 24);
1511 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1512 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi0x89ABCDEF, vk0x89ABCDEF, vacc89ABCDEFp0);
1513
1514 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1515 const __m256 vi1x89ABCDEF = _mm256_loadu_ps(i1 + 8);
1516 i1 += 16;
1517
1518 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1519 const __m256 vk1x89ABCDEF = _mm256_load_ps(w + 40);
1520 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1521 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi1x89ABCDEF, vk1x89ABCDEF, vacc89ABCDEFp0);
1522
1523 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1524 const __m256 vi2x89ABCDEF = _mm256_loadu_ps(i2 + 8);
1525 i2 += 16;
1526
1527 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1528 const __m256 vk2x89ABCDEF = _mm256_load_ps(w + 56);
1529 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1530 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi2x89ABCDEF, vk2x89ABCDEF, vacc89ABCDEFp0);
1531
1532 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1533 const __m256 vi3x89ABCDEF = _mm256_loadu_ps(i3 + 8);
1534 i3 += 16;
1535
1536 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1537 const __m256 vk3x89ABCDEF = _mm256_load_ps(w + 72);
1538 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1539 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi3x89ABCDEF, vk3x89ABCDEF, vacc89ABCDEFp0);
1540
1541 const __m256 vi4x01234567 = _mm256_loadu_ps(i4);
1542 const __m256 vi4x89ABCDEF = _mm256_loadu_ps(i4 + 8);
1543 i4 += 16;
1544
1545 const __m256 vk4x01234567 = _mm256_load_ps(w + 80);
1546 const __m256 vk4x89ABCDEF = _mm256_load_ps(w + 88);
1547 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1548 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi4x89ABCDEF, vk4x89ABCDEF, vacc89ABCDEFp0);
1549
1550 const __m256 vi5x01234567 = _mm256_loadu_ps(i5);
1551 const __m256 vi5x89ABCDEF = _mm256_loadu_ps(i5 + 8);
1552 i5 += 16;
1553
1554 const __m256 vk5x01234567 = _mm256_load_ps(w + 96);
1555 const __m256 vk5x89ABCDEF = _mm256_load_ps(w + 104);
1556 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1557 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi5x89ABCDEF, vk5x89ABCDEF, vacc89ABCDEFp0);
1558
1559 const __m256 vi6x01234567 = _mm256_loadu_ps(i6);
1560 const __m256 vi6x89ABCDEF = _mm256_loadu_ps(i6 + 8);
1561 i6 += 16;
1562
1563 const __m256 vk6x01234567 = _mm256_load_ps(w + 112);
1564 const __m256 vk6x89ABCDEF = _mm256_load_ps(w + 120);
1565 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1566 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi6x89ABCDEF, vk6x89ABCDEF, vacc89ABCDEFp0);
1567
1568 const __m256 vi7x01234567 = _mm256_loadu_ps(i7);
1569 const __m256 vi7x89ABCDEF = _mm256_loadu_ps(i7 + 8);
1570 i7 += 16;
1571
1572 const __m256 vk7x01234567 = _mm256_load_ps(w + 128);
1573 const __m256 vk7x89ABCDEF = _mm256_load_ps(w + 136);
1574 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1575 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi7x89ABCDEF, vk7x89ABCDEF, vacc89ABCDEFp0);
1576
1577 const __m256 vi8x01234567 = _mm256_loadu_ps(i8);
1578 const __m256 vi8x89ABCDEF = _mm256_loadu_ps(i8 + 8);
1579 i8 += 16;
1580
1581 const __m256 vk8x01234567 = _mm256_load_ps(w + 144);
1582 const __m256 vk8x89ABCDEF = _mm256_load_ps(w + 152);
1583 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1584 vacc89ABCDEFp0 = _mm256_fmadd_ps(vi8x89ABCDEF, vk8x89ABCDEF, vacc89ABCDEFp0);
1585
1586 w += 160;
1587
1588
1589 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1590 __m256 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEFp0, vmin);
1591 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1592 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vmax);
1593
1594 _mm256_storeu_ps(output, vacc01234567);
1595 _mm256_storeu_ps(output + 8, vacc89ABCDEF);
1596 output += 16;
1597 }
1598 for (; c >= 8; c -= 8) {
1599 __m256 vacc01234567p0 = _mm256_load_ps(w);
1600
1601 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1602 i0 += 8;
1603
1604 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1605 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1606
1607 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1608 i1 += 8;
1609
1610 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1611 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1612
1613 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1614 i2 += 8;
1615
1616 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1617 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1618
1619 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1620 i3 += 8;
1621
1622 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1623 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1624
1625 const __m256 vi4x01234567 = _mm256_loadu_ps(i4);
1626 i4 += 8;
1627
1628 const __m256 vk4x01234567 = _mm256_load_ps(w + 80);
1629 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1630
1631 const __m256 vi5x01234567 = _mm256_loadu_ps(i5);
1632 i5 += 8;
1633
1634 const __m256 vk5x01234567 = _mm256_load_ps(w + 96);
1635 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1636
1637 const __m256 vi6x01234567 = _mm256_loadu_ps(i6);
1638 i6 += 8;
1639
1640 const __m256 vk6x01234567 = _mm256_load_ps(w + 112);
1641 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1642
1643 const __m256 vi7x01234567 = _mm256_loadu_ps(i7);
1644 i7 += 8;
1645
1646 const __m256 vk7x01234567 = _mm256_load_ps(w + 128);
1647 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1648
1649 const __m256 vi8x01234567 = _mm256_loadu_ps(i8);
1650 i8 += 8;
1651
1652 const __m256 vk8x01234567 = _mm256_load_ps(w + 144);
1653 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1654
1655 w += 8;
1656
1657
1658 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1659 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1660
1661 _mm256_storeu_ps(output, vacc01234567);
1662 output += 8;
1663 }
1664 if XNN_UNLIKELY(c != 0) {
1665 assert(c >= 1);
1666 assert(c <= 7);
1667 const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
1668
1669 __m256 vacc01234567p0 = _mm256_load_ps(w);
1670
1671 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
1672 const __m256 vk0x01234567 = _mm256_load_ps(w + 16);
1673 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1674
1675 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
1676 const __m256 vk1x01234567 = _mm256_load_ps(w + 32);
1677 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1678
1679 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
1680 const __m256 vk2x01234567 = _mm256_load_ps(w + 48);
1681 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1682
1683 const __m256 vi3x01234567 = _mm256_maskload_ps(i3, vmask);
1684 const __m256 vk3x01234567 = _mm256_load_ps(w + 64);
1685 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1686
1687 const __m256 vi4x01234567 = _mm256_maskload_ps(i4, vmask);
1688 const __m256 vk4x01234567 = _mm256_load_ps(w + 80);
1689 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1690
1691 const __m256 vi5x01234567 = _mm256_maskload_ps(i5, vmask);
1692 const __m256 vk5x01234567 = _mm256_load_ps(w + 96);
1693 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1694
1695 const __m256 vi6x01234567 = _mm256_maskload_ps(i6, vmask);
1696 const __m256 vk6x01234567 = _mm256_load_ps(w + 112);
1697 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1698
1699 const __m256 vi7x01234567 = _mm256_maskload_ps(i7, vmask);
1700 const __m256 vk7x01234567 = _mm256_load_ps(w + 128);
1701 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1702
1703 const __m256 vi8x01234567 = _mm256_maskload_ps(i8, vmask);
1704 const __m256 vk8x01234567 = _mm256_load_ps(w + 144);
1705 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1706
1707
1708 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
1709 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
1710
1711 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
1712 if (c & 4) {
1713 _mm_storeu_ps(output, vacc0123);
1714 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
1715 output += 4;
1716 }
1717 if (c & 2) {
1718 _mm_storel_pi((__m64*) output, vacc0123);
1719 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
1720 output += 2;
1721 }
1722 if (c & 1) {
1723 _mm_store_ss(output, vacc0123);
1724 output += 1;
1725 }
1726 }
1727
1728 output = (float*) ((uintptr_t) output + output_increment);
1729 } while (--output_width != 0);
1730}
1731
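// 25-tap depthwise convolution: channels are handled 8 at a time with a masked tail,
// accumulating all 25 input rows (i0-i24) before the min/max clamp.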
1732void xnn_f32_dwconv_minmax_ukernel_up8x25__fma3(
1733 size_t channels,
1734 size_t output_width,
1735 const float** input,
1736 const float* weights,
1737 float* output,
1738 size_t input_stride,
1739 size_t output_increment,
1740 size_t input_offset,
1741 const float* zero,
1742 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
1743{
1744 assert(channels != 0);
1745 assert(output_width != 0);
1746
1747 const __m256 vmax = _mm256_load_ps(params->avx.max);
1748 const __m256 vmin = _mm256_load_ps(params->avx.min);
1749 do {
1750 const float* i0 = input[0];
1751 assert(i0 != NULL);
1752 if XNN_UNPREDICTABLE(i0 != zero) {
1753 i0 = (const float*) ((uintptr_t) i0 + input_offset);
1754 }
1755 const float* i1 = input[1];
1756 assert(i1 != NULL);
1757 if XNN_UNPREDICTABLE(i1 != zero) {
1758 i1 = (const float*) ((uintptr_t) i1 + input_offset);
1759 }
1760 const float* i2 = input[2];
1761 assert(i2 != NULL);
1762 if XNN_UNPREDICTABLE(i2 != zero) {
1763 i2 = (const float*) ((uintptr_t) i2 + input_offset);
1764 }
1765 const float* i3 = input[3];
1766 assert(i3 != NULL);
1767 if XNN_UNPREDICTABLE(i3 != zero) {
1768 i3 = (const float*) ((uintptr_t) i3 + input_offset);
1769 }
1770 const float* i4 = input[4];
1771 assert(i4 != NULL);
1772 if XNN_UNPREDICTABLE(i4 != zero) {
1773 i4 = (const float*) ((uintptr_t) i4 + input_offset);
1774 }
1775 const float* i5 = input[5];
1776 assert(i5 != NULL);
1777 if XNN_UNPREDICTABLE(i5 != zero) {
1778 i5 = (const float*) ((uintptr_t) i5 + input_offset);
1779 }
1780 const float* i6 = input[6];
1781 assert(i6 != NULL);
1782 if XNN_UNPREDICTABLE(i6 != zero) {
1783 i6 = (const float*) ((uintptr_t) i6 + input_offset);
1784 }
1785 const float* i7 = input[7];
1786 assert(i7 != NULL);
1787 if XNN_UNPREDICTABLE(i7 != zero) {
1788 i7 = (const float*) ((uintptr_t) i7 + input_offset);
1789 }
1790 const float* i8 = input[8];
1791 assert(i8 != NULL);
1792 if XNN_UNPREDICTABLE(i8 != zero) {
1793 i8 = (const float*) ((uintptr_t) i8 + input_offset);
1794 }
1795 const float* i9 = input[9];
1796 assert(i9 != NULL);
1797 if XNN_UNPREDICTABLE(i9 != zero) {
1798 i9 = (const float*) ((uintptr_t) i9 + input_offset);
1799 }
1800 const float* i10 = input[10];
1801 assert(i10 != NULL);
1802 if XNN_UNPREDICTABLE(i10 != zero) {
1803 i10 = (const float*) ((uintptr_t) i10 + input_offset);
1804 }
1805 const float* i11 = input[11];
1806 assert(i11 != NULL);
1807 if XNN_UNPREDICTABLE(i11 != zero) {
1808 i11 = (const float*) ((uintptr_t) i11 + input_offset);
1809 }
1810 const float* i12 = input[12];
1811 assert(i12 != NULL);
1812 if XNN_UNPREDICTABLE(i12 != zero) {
1813 i12 = (const float*) ((uintptr_t) i12 + input_offset);
1814 }
1815 const float* i13 = input[13];
1816 assert(i13 != NULL);
1817 if XNN_UNPREDICTABLE(i13 != zero) {
1818 i13 = (const float*) ((uintptr_t) i13 + input_offset);
1819 }
1820 const float* i14 = input[14];
1821 assert(i14 != NULL);
1822 if XNN_UNPREDICTABLE(i14 != zero) {
1823 i14 = (const float*) ((uintptr_t) i14 + input_offset);
1824 }
1825 const float* i15 = input[15];
1826 assert(i15 != NULL);
1827 if XNN_UNPREDICTABLE(i15 != zero) {
1828 i15 = (const float*) ((uintptr_t) i15 + input_offset);
1829 }
1830 const float* i16 = input[16];
1831 assert(i16 != NULL);
1832 if XNN_UNPREDICTABLE(i16 != zero) {
1833 i16 = (const float*) ((uintptr_t) i16 + input_offset);
1834 }
1835 const float* i17 = input[17];
1836 assert(i17 != NULL);
1837 if XNN_UNPREDICTABLE(i17 != zero) {
1838 i17 = (const float*) ((uintptr_t) i17 + input_offset);
1839 }
1840 const float* i18 = input[18];
1841 assert(i18 != NULL);
1842 if XNN_UNPREDICTABLE(i18 != zero) {
1843 i18 = (const float*) ((uintptr_t) i18 + input_offset);
1844 }
1845 const float* i19 = input[19];
1846 assert(i19 != NULL);
1847 if XNN_UNPREDICTABLE(i19 != zero) {
1848 i19 = (const float*) ((uintptr_t) i19 + input_offset);
1849 }
1850 const float* i20 = input[20];
1851 assert(i20 != NULL);
1852 if XNN_UNPREDICTABLE(i20 != zero) {
1853 i20 = (const float*) ((uintptr_t) i20 + input_offset);
1854 }
1855 const float* i21 = input[21];
1856 assert(i21 != NULL);
1857 if XNN_UNPREDICTABLE(i21 != zero) {
1858 i21 = (const float*) ((uintptr_t) i21 + input_offset);
1859 }
1860 const float* i22 = input[22];
1861 assert(i22 != NULL);
1862 if XNN_UNPREDICTABLE(i22 != zero) {
1863 i22 = (const float*) ((uintptr_t) i22 + input_offset);
1864 }
1865 const float* i23 = input[23];
1866 assert(i23 != NULL);
1867 if XNN_UNPREDICTABLE(i23 != zero) {
1868 i23 = (const float*) ((uintptr_t) i23 + input_offset);
1869 }
1870 const float* i24 = input[24];
1871 assert(i24 != NULL);
1872 if XNN_UNPREDICTABLE(i24 != zero) {
1873 i24 = (const float*) ((uintptr_t) i24 + input_offset);
1874 }
1875 input = (const float**) ((uintptr_t) input + input_stride);
1876
1877 size_t c = channels;
1878 const float* w = weights;
1879 for (; c >= 8; c -= 8) {
1880 __m256 vacc01234567p0 = _mm256_load_ps(w);
1881
1882
1883 const __m256 vi0x01234567 = _mm256_loadu_ps(i0);
1884 i0 += 8;
1885
1886 const __m256 vk0x01234567 = _mm256_load_ps(w + 8);
1887 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
1888
1889 const __m256 vi1x01234567 = _mm256_loadu_ps(i1);
1890 i1 += 8;
1891
1892 const __m256 vk1x01234567 = _mm256_load_ps(w + 16);
1893 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
1894
1895 const __m256 vi2x01234567 = _mm256_loadu_ps(i2);
1896 i2 += 8;
1897
1898 const __m256 vk2x01234567 = _mm256_load_ps(w + 24);
1899 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
1900
1901 const __m256 vi3x01234567 = _mm256_loadu_ps(i3);
1902 i3 += 8;
1903
1904 const __m256 vk3x01234567 = _mm256_load_ps(w + 32);
1905 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
1906
1907 const __m256 vi4x01234567 = _mm256_loadu_ps(i4);
1908 i4 += 8;
1909
1910 const __m256 vk4x01234567 = _mm256_load_ps(w + 40);
1911 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
1912
1913 const __m256 vi5x01234567 = _mm256_loadu_ps(i5);
1914 i5 += 8;
1915
1916 const __m256 vk5x01234567 = _mm256_load_ps(w + 48);
1917 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
1918
1919 const __m256 vi6x01234567 = _mm256_loadu_ps(i6);
1920 i6 += 8;
1921
1922 const __m256 vk6x01234567 = _mm256_load_ps(w + 56);
1923 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
1924
1925 const __m256 vi7x01234567 = _mm256_loadu_ps(i7);
1926 i7 += 8;
1927
1928 const __m256 vk7x01234567 = _mm256_load_ps(w + 64);
1929 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
1930
1931 const __m256 vi8x01234567 = _mm256_loadu_ps(i8);
1932 i8 += 8;
1933
1934 const __m256 vk8x01234567 = _mm256_load_ps(w + 72);
1935 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
1936
1937 const __m256 vi9x01234567 = _mm256_loadu_ps(i9);
1938 i9 += 8;
1939
1940 const __m256 vk9x01234567 = _mm256_load_ps(w + 80);
1941 vacc01234567p0 = _mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p0);
1942
1943 const __m256 vi10x01234567 = _mm256_loadu_ps(i10);
1944 i10 += 8;
1945
1946 const __m256 vk10x01234567 = _mm256_load_ps(w + 88);
1947 vacc01234567p0 = _mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0);
1948
1949 const __m256 vi11x01234567 = _mm256_loadu_ps(i11);
1950 i11 += 8;
1951
1952 const __m256 vk11x01234567 = _mm256_load_ps(w + 96);
1953 vacc01234567p0 = _mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p0);
1954
1955 const __m256 vi12x01234567 = _mm256_loadu_ps(i12);
1956 i12 += 8;
1957
1958 const __m256 vk12x01234567 = _mm256_load_ps(w + 104);
1959 vacc01234567p0 = _mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0);
1960
1961 const __m256 vi13x01234567 = _mm256_loadu_ps(i13);
1962 i13 += 8;
1963
1964 const __m256 vk13x01234567 = _mm256_load_ps(w + 112);
1965 vacc01234567p0 = _mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p0);
1966
1967 const __m256 vi14x01234567 = _mm256_loadu_ps(i14);
1968 i14 += 8;
1969
1970 const __m256 vk14x01234567 = _mm256_load_ps(w + 120);
1971 vacc01234567p0 = _mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0);
1972
1973 const __m256 vi15x01234567 = _mm256_loadu_ps(i15);
1974 i15 += 8;
1975
1976 const __m256 vk15x01234567 = _mm256_load_ps(w + 128);
1977 vacc01234567p0 = _mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p0);
1978
1979 const __m256 vi16x01234567 = _mm256_loadu_ps(i16);
1980 i16 += 8;
1981
1982 const __m256 vk16x01234567 = _mm256_load_ps(w + 136);
1983 vacc01234567p0 = _mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0);
1984
1985 const __m256 vi17x01234567 = _mm256_loadu_ps(i17);
1986 i17 += 8;
1987
1988 const __m256 vk17x01234567 = _mm256_load_ps(w + 144);
1989 vacc01234567p0 = _mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p0);
1990
1991 const __m256 vi18x01234567 = _mm256_loadu_ps(i18);
1992 i18 += 8;
1993
1994 const __m256 vk18x01234567 = _mm256_load_ps(w + 152);
1995 vacc01234567p0 = _mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0);
1996
1997 const __m256 vi19x01234567 = _mm256_loadu_ps(i19);
1998 i19 += 8;
1999
2000 const __m256 vk19x01234567 = _mm256_load_ps(w + 160);
2001 vacc01234567p0 = _mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p0);
2002
2003 const __m256 vi20x01234567 = _mm256_loadu_ps(i20);
2004 i20 += 8;
2005
2006 const __m256 vk20x01234567 = _mm256_load_ps(w + 168);
2007 vacc01234567p0 = _mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0);
2008
2009 const __m256 vi21x01234567 = _mm256_loadu_ps(i21);
2010 i21 += 8;
2011
2012 const __m256 vk21x01234567 = _mm256_load_ps(w + 176);
2013 vacc01234567p0 = _mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p0);
2014
2015 const __m256 vi22x01234567 = _mm256_loadu_ps(i22);
2016 i22 += 8;
2017
2018 const __m256 vk22x01234567 = _mm256_load_ps(w + 184);
2019 vacc01234567p0 = _mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0);
2020
2021 const __m256 vi23x01234567 = _mm256_loadu_ps(i23);
2022 i23 += 8;
2023
2024 const __m256 vk23x01234567 = _mm256_load_ps(w + 192);
2025 vacc01234567p0 = _mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p0);
2026
2027 const __m256 vi24x01234567 = _mm256_loadu_ps(i24);
2028 i24 += 8;
2029
2030 const __m256 vk24x01234567 = _mm256_load_ps(w + 200);
2031 vacc01234567p0 = _mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0);
2032
2033 w += 208;
2034
2035
2036 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
2037 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
2038
2039 _mm256_storeu_ps(output, vacc01234567);
2040 output += 8;
2041 }
2042 if XNN_UNLIKELY(c != 0) {
2043 assert(c >= 1);
2044 assert(c <= 7);
2045 const __m256i vmask = _mm256_loadu_si256((const __m256i*) &params->avx.mask_table[7 - c]);
2046
2047 __m256 vacc01234567p0 = _mm256_load_ps(w);
2048
2049 const __m256 vi0x01234567 = _mm256_maskload_ps(i0, vmask);
2050 const __m256 vk0x01234567 = _mm256_load_ps(w + 8);
2051 vacc01234567p0 = _mm256_fmadd_ps(vi0x01234567, vk0x01234567, vacc01234567p0);
2052
2053 const __m256 vi1x01234567 = _mm256_maskload_ps(i1, vmask);
2054 const __m256 vk1x01234567 = _mm256_load_ps(w + 16);
2055 vacc01234567p0 = _mm256_fmadd_ps(vi1x01234567, vk1x01234567, vacc01234567p0);
2056
2057 const __m256 vi2x01234567 = _mm256_maskload_ps(i2, vmask);
2058 const __m256 vk2x01234567 = _mm256_load_ps(w + 24);
2059 vacc01234567p0 = _mm256_fmadd_ps(vi2x01234567, vk2x01234567, vacc01234567p0);
2060
2061 const __m256 vi3x01234567 = _mm256_maskload_ps(i3, vmask);
2062 const __m256 vk3x01234567 = _mm256_load_ps(w + 32);
2063 vacc01234567p0 = _mm256_fmadd_ps(vi3x01234567, vk3x01234567, vacc01234567p0);
2064
2065 const __m256 vi4x01234567 = _mm256_maskload_ps(i4, vmask);
2066 const __m256 vk4x01234567 = _mm256_load_ps(w + 40);
2067 vacc01234567p0 = _mm256_fmadd_ps(vi4x01234567, vk4x01234567, vacc01234567p0);
2068
2069 const __m256 vi5x01234567 = _mm256_maskload_ps(i5, vmask);
2070 const __m256 vk5x01234567 = _mm256_load_ps(w + 48);
2071 vacc01234567p0 = _mm256_fmadd_ps(vi5x01234567, vk5x01234567, vacc01234567p0);
2072
2073 const __m256 vi6x01234567 = _mm256_maskload_ps(i6, vmask);
2074 const __m256 vk6x01234567 = _mm256_load_ps(w + 56);
2075 vacc01234567p0 = _mm256_fmadd_ps(vi6x01234567, vk6x01234567, vacc01234567p0);
2076
2077 const __m256 vi7x01234567 = _mm256_maskload_ps(i7, vmask);
2078 const __m256 vk7x01234567 = _mm256_load_ps(w + 64);
2079 vacc01234567p0 = _mm256_fmadd_ps(vi7x01234567, vk7x01234567, vacc01234567p0);
2080
2081 const __m256 vi8x01234567 = _mm256_maskload_ps(i8, vmask);
2082 const __m256 vk8x01234567 = _mm256_load_ps(w + 72);
2083 vacc01234567p0 = _mm256_fmadd_ps(vi8x01234567, vk8x01234567, vacc01234567p0);
2084
2085 const __m256 vi9x01234567 = _mm256_maskload_ps(i9, vmask);
2086 const __m256 vk9x01234567 = _mm256_load_ps(w + 80);
2087 vacc01234567p0 = _mm256_fmadd_ps(vi9x01234567, vk9x01234567, vacc01234567p0);
2088
2089 const __m256 vi10x01234567 = _mm256_maskload_ps(i10, vmask);
2090 const __m256 vk10x01234567 = _mm256_load_ps(w + 88);
2091 vacc01234567p0 = _mm256_fmadd_ps(vi10x01234567, vk10x01234567, vacc01234567p0);
2092
2093 const __m256 vi11x01234567 = _mm256_maskload_ps(i11, vmask);
2094 const __m256 vk11x01234567 = _mm256_load_ps(w + 96);
2095 vacc01234567p0 = _mm256_fmadd_ps(vi11x01234567, vk11x01234567, vacc01234567p0);
2096
2097 const __m256 vi12x01234567 = _mm256_maskload_ps(i12, vmask);
2098 const __m256 vk12x01234567 = _mm256_load_ps(w + 104);
2099 vacc01234567p0 = _mm256_fmadd_ps(vi12x01234567, vk12x01234567, vacc01234567p0);
2100
2101 const __m256 vi13x01234567 = _mm256_maskload_ps(i13, vmask);
2102 const __m256 vk13x01234567 = _mm256_load_ps(w + 112);
2103 vacc01234567p0 = _mm256_fmadd_ps(vi13x01234567, vk13x01234567, vacc01234567p0);
2104
2105 const __m256 vi14x01234567 = _mm256_maskload_ps(i14, vmask);
2106 const __m256 vk14x01234567 = _mm256_load_ps(w + 120);
2107 vacc01234567p0 = _mm256_fmadd_ps(vi14x01234567, vk14x01234567, vacc01234567p0);
2108
2109 const __m256 vi15x01234567 = _mm256_maskload_ps(i15, vmask);
2110 const __m256 vk15x01234567 = _mm256_load_ps(w + 128);
2111 vacc01234567p0 = _mm256_fmadd_ps(vi15x01234567, vk15x01234567, vacc01234567p0);
2112
2113 const __m256 vi16x01234567 = _mm256_maskload_ps(i16, vmask);
2114 const __m256 vk16x01234567 = _mm256_load_ps(w + 136);
2115 vacc01234567p0 = _mm256_fmadd_ps(vi16x01234567, vk16x01234567, vacc01234567p0);
2116
2117 const __m256 vi17x01234567 = _mm256_maskload_ps(i17, vmask);
2118 const __m256 vk17x01234567 = _mm256_load_ps(w + 144);
2119 vacc01234567p0 = _mm256_fmadd_ps(vi17x01234567, vk17x01234567, vacc01234567p0);
2120
2121 const __m256 vi18x01234567 = _mm256_maskload_ps(i18, vmask);
2122 const __m256 vk18x01234567 = _mm256_load_ps(w + 152);
2123 vacc01234567p0 = _mm256_fmadd_ps(vi18x01234567, vk18x01234567, vacc01234567p0);
2124
2125 const __m256 vi19x01234567 = _mm256_maskload_ps(i19, vmask);
2126 const __m256 vk19x01234567 = _mm256_load_ps(w + 160);
2127 vacc01234567p0 = _mm256_fmadd_ps(vi19x01234567, vk19x01234567, vacc01234567p0);
2128
2129 const __m256 vi20x01234567 = _mm256_maskload_ps(i20, vmask);
2130 const __m256 vk20x01234567 = _mm256_load_ps(w + 168);
2131 vacc01234567p0 = _mm256_fmadd_ps(vi20x01234567, vk20x01234567, vacc01234567p0);
2132
2133 const __m256 vi21x01234567 = _mm256_maskload_ps(i21, vmask);
2134 const __m256 vk21x01234567 = _mm256_load_ps(w + 176);
2135 vacc01234567p0 = _mm256_fmadd_ps(vi21x01234567, vk21x01234567, vacc01234567p0);
2136
2137 const __m256 vi22x01234567 = _mm256_maskload_ps(i22, vmask);
2138 const __m256 vk22x01234567 = _mm256_load_ps(w + 184);
2139 vacc01234567p0 = _mm256_fmadd_ps(vi22x01234567, vk22x01234567, vacc01234567p0);
2140
2141 const __m256 vi23x01234567 = _mm256_maskload_ps(i23, vmask);
2142 const __m256 vk23x01234567 = _mm256_load_ps(w + 192);
2143 vacc01234567p0 = _mm256_fmadd_ps(vi23x01234567, vk23x01234567, vacc01234567p0);
2144
2145 const __m256 vi24x01234567 = _mm256_maskload_ps(i24, vmask);
2146 const __m256 vk24x01234567 = _mm256_load_ps(w + 200);
2147 vacc01234567p0 = _mm256_fmadd_ps(vi24x01234567, vk24x01234567, vacc01234567p0);
2148
2149
2150 __m256 vacc01234567 = _mm256_max_ps(vacc01234567p0, vmin);
2151 vacc01234567 = _mm256_min_ps(vacc01234567, vmax);
2152
2153 __m128 vacc0123 = _mm256_castps256_ps128(vacc01234567);
2154 if (c & 4) {
2155 _mm_storeu_ps(output, vacc0123);
2156 vacc0123 = _mm256_extractf128_ps(vacc01234567, 1);
2157 output += 4;
2158 }
2159 if (c & 2) {
2160 _mm_storel_pi((__m64*) output, vacc0123);
2161 vacc0123 = _mm_movehl_ps(vacc0123, vacc0123);
2162 output += 2;
2163 }
2164 if (c & 1) {
2165 _mm_store_ss(output, vacc0123);
2166 output += 1;
2167 }
2168 }
2169
2170 output = (float*) ((uintptr_t) output + output_increment);
2171 } while (--output_width != 0);
2172}
2173
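// GEMM microkernel for 1 row of A and 16 columns of C: each element of A is broadcast
// and multiplied against two 8-wide vectors of packed weights with FMA; the first
// 16 floats of w (the packed bias) seed the accumulators.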
2174void xnn_f32_gemm_minmax_ukernel_1x16__fma3_broadcast(
2175 size_t mr,
2176 size_t nc,
2177 size_t kc,
2178 const float*restrict a,
2179 size_t a_stride,
2180 const float*restrict w,
2181 float*restrict c,
2182 size_t cm_stride,
2183 size_t cn_stride,
2184 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2185{
2186 assert(mr != 0);
2187 assert(mr <= 1);
2188 assert(nc != 0);
2189 assert(kc != 0);
2190 assert(kc % sizeof(float) == 0);
2191 assert(a != NULL);
2192 assert(w != NULL);
2193 assert(c != NULL);
2194
2195 const float* a0 = a;
2196 float* c0 = c;
2197
2198 do {
2199 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2200 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2201 w += 16;
2202
2203 size_t k = kc;
2204 do {
2205 const __m256 va0 = _mm256_broadcast_ss(a0);
2206 a0 += 1;
2207
2208 const __m256 vb01234567 = _mm256_load_ps(w);
2209 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
2210 w += 16;
2211
2212 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
2213 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
2214
2215 k -= sizeof(float);
2216 } while (k != 0);
2217
2218 const __m256 vmin = _mm256_load_ps(params->avx.min);
2219 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2220 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2221
2222 const __m256 vmax = _mm256_load_ps(params->avx.max);
2223 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2224 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2225
2226 if XNN_LIKELY(nc >= 16) {
2227 _mm256_storeu_ps(c0, vacc0x01234567);
2228 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2229 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2230
2231 a0 = (const float*) ((uintptr_t) a0 - kc);
2232
2233 nc -= 16;
2234 } else {
2235 if (nc & 8) {
2236 _mm256_storeu_ps(c0, vacc0x01234567);
2237
2238 vacc0x01234567 = vacc0x89ABCDEF;
2239
2240 c0 += 8;
2241 }
2242 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2243 if (nc & 4) {
2244 _mm_storeu_ps(c0, vacc0x0123);
2245
2246 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2247
2248 c0 += 4;
2249 }
2250 if (nc & 2) {
2251 _mm_storel_pi((__m64*) c0, vacc0x0123);
2252
2253 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2254
2255 c0 += 2;
2256 }
2257 if (nc & 1) {
2258 _mm_store_ss(c0, vacc0x0123);
2259 }
2260
2261 nc = 0;
2262 }
2263 } while (nc != 0);
2264}
2265
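// Shuffled (s4) variant of the GEMM above: four consecutive elements of A are loaded
// into both lanes of a vector and rotated between FMAs, so each of the four k steps
// reuses the same register against its own group of packed weights.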
2266void xnn_f32_gemm_minmax_ukernel_1x16s4__fma3_broadcast(
2267 size_t mr,
2268 size_t nc,
2269 size_t kc,
2270 const float*restrict a,
2271 size_t a_stride,
2272 const float*restrict w,
2273 float*restrict c,
2274 size_t cm_stride,
2275 size_t cn_stride,
2276    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2277{
2278 assert(mr != 0);
2279 assert(mr <= 1);
2280 assert(nc != 0);
2281 assert(kc != 0);
2282 assert(kc % sizeof(float) == 0);
2283 assert(a != NULL);
2284 assert(w != NULL);
2285 assert(c != NULL);
2286
2287 const float* a0 = a;
2288 float* c0 = c;
2289
2290 do {
2291 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2292 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2293 w += 16;
2294
2295 size_t k = kc;
2296 while (k >= 4 * sizeof(float)) {
2297 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2298 a0 += 4;
2299
2300
2301 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2302 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2303
2304 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
2305 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
2306
2307 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2308
2309 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2310 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2311
2312 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
2313 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
2314
2315 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2316
2317 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2318 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2319
2320 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
2321 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
2322
2323 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2324
2325 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2326 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2327
2328 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
2329 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
2330
2331
2332 w += 64;
2333 k -= 4 * sizeof(float);
2334 }
2335 if XNN_UNLIKELY(k != 0) {
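      // 1-3 leftover k values: lanes of va0 whose corresponding packed weight is zero are
      // cleared before the FMA, so elements read past the end of the row (a0 only holds k
      // valid values here) cannot reach the accumulators. This assumes the packing pads
      // the trailing weights with zeros.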
2336      __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2337      a0 = (const float*) ((uintptr_t) a0 + k);
2338
2339      const __m256 vzero = _mm256_setzero_ps();
2340
2341      const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2342      const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2343
2344      vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
2345 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
2346
2347 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2348
2349 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2350 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2351
2352 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
2353 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
2354
2355 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2356
2357 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2358 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2359
2360 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
2361 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
2362
2363 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2364
2365 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2366 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2367
2368 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
2369 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
2370
2371
2372 w += 64;
2373    }
2374
2375 const __m256 vmin = _mm256_load_ps(params->avx.min);
2376 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2377 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2378
2379 const __m256 vmax = _mm256_load_ps(params->avx.max);
2380 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2381 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2382
2383 if XNN_LIKELY(nc >= 16) {
2384 _mm256_storeu_ps(c0, vacc0x01234567);
2385 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2386 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2387
2388 a0 = (const float*) ((uintptr_t) a0 - kc);
2389
2390 nc -= 16;
2391 } else {
2392 if (nc & 8) {
2393 _mm256_storeu_ps(c0, vacc0x01234567);
2394
2395 vacc0x01234567 = vacc0x89ABCDEF;
2396
2397 c0 += 8;
2398 }
2399 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2400 if (nc & 4) {
2401 _mm_storeu_ps(c0, vacc0x0123);
2402
2403 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2404
2405 c0 += 4;
2406 }
2407 if (nc & 2) {
2408 _mm_storel_pi((__m64*) c0, vacc0x0123);
2409
2410 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2411
2412 c0 += 2;
2413 }
2414 if (nc & 1) {
2415 _mm_store_ss(c0, vacc0x0123);
2416 }
2417
2418 nc = 0;
2419 }
2420 } while (nc != 0);
2421}
2422
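// 4-row version of the shuffled (s4) GEMM: the same rotate-and-FMA sequence is applied
// to four rows of A in lockstep, with row pointers clamped when mr < 4.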
2423void xnn_f32_gemm_minmax_ukernel_4x16s4__fma3_broadcast(
2424 size_t mr,
2425 size_t nc,
2426 size_t kc,
2427 const float*restrict a,
2428 size_t a_stride,
2429 const float*restrict w,
2430 float*restrict c,
2431 size_t cm_stride,
2432 size_t cn_stride,
2433    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
2434{
2435 assert(mr != 0);
2436 assert(mr <= 4);
2437 assert(nc != 0);
2438 assert(kc != 0);
2439 assert(kc % sizeof(float) == 0);
2440 assert(a != NULL);
2441 assert(w != NULL);
2442 assert(c != NULL);
2443
2444 const float* a0 = a;
2445 float* c0 = c;
2446 const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
2447 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
2448 if XNN_UNPREDICTABLE(mr < 2) {
2449 a1 = a0;
2450 c1 = c0;
2451 }
2452 const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
2453 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
2454 if XNN_UNPREDICTABLE(mr <= 2) {
2455 a2 = a1;
2456 c2 = c1;
2457 }
2458 const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
2459 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
2460 if XNN_UNPREDICTABLE(mr != 4) {
2461 a3 = a2;
2462 c3 = c2;
2463 }
2464
2465 do {
2466 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2467 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2468 __m256 vacc1x01234567 = vacc0x01234567;
2469 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
2470 __m256 vacc2x01234567 = vacc0x01234567;
2471 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
2472 __m256 vacc3x01234567 = vacc0x01234567;
2473 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
2474 w += 16;
2475
2476 size_t k = kc;
2477 while (k >= 4 * sizeof(float)) {
2478 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2479 a0 += 4;
2480 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
2481 a1 += 4;
2482 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
2483 a2 += 4;
2484 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
2485 a3 += 4;
2486
2487
2488 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2489 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2490
2491 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
2492 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c0, vacc1x01234567);
2493 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c0, vacc2x01234567);
2494 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c0, vacc3x01234567);
2495 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
2496 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc0, vacc1x89ABCDEF);
2497 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc0, vacc2x89ABCDEF);
2498 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc0, vacc3x89ABCDEF);
2499
2500 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2501 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2502 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2503 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2504
2505 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2506 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2507
2508 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
2509 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c1, vacc1x01234567);
2510 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c1, vacc2x01234567);
2511 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c1, vacc3x01234567);
2512 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
2513 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc1, vacc1x89ABCDEF);
2514 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc1, vacc2x89ABCDEF);
2515 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc1, vacc3x89ABCDEF);
2516
2517 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2518 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2519 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2520 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2521
2522 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2523 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2524
2525 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
2526 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c2, vacc1x01234567);
2527 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c2, vacc2x01234567);
2528 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c2, vacc3x01234567);
2529 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
2530 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc2, vacc1x89ABCDEF);
2531 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc2, vacc2x89ABCDEF);
2532 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc2, vacc3x89ABCDEF);
2533
2534 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2535 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2536 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2537 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2538
2539 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2540 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2541
2542 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
2543 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c3, vacc1x01234567);
2544 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c3, vacc2x01234567);
2545 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c3, vacc3x01234567);
2546 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
2547 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc3, vacc1x89ABCDEF);
2548 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc3, vacc2x89ABCDEF);
2549 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc3, vacc3x89ABCDEF);
2550
2551
2552 w += 64;
2553 k -= 4 * sizeof(float);
2554 }
2555 if XNN_UNLIKELY(k != 0) {
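      // Same zero-weight masking as in the 1x16s4 remainder above, applied to all four rows.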
2556      __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
2557 a0 = (const float*) ((uintptr_t) a0 + k);
2558 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
2559 a1 = (const float*) ((uintptr_t) a1 + k);
2560 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
2561 a2 = (const float*) ((uintptr_t) a2 + k);
2562 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
2563 a3 = (const float*) ((uintptr_t) a3 + k);
2564
2565      const __m256 vzero = _mm256_setzero_ps();
2566
2567      const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
2568      const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
2569
2570      vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
2571 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc1x01234567);
2572 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc2x01234567);
2573 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc3x01234567);
2574 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
2575 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc1x89ABCDEF);
2576 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc2x89ABCDEF);
2577 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc3x89ABCDEF);
2578
2579 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2580 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2581 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2582 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2583
2584 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
2585 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
2586
2587 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
2588 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc1x01234567);
2589 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc2x01234567);
2590 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc3x01234567);
2591 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
2592 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc1x89ABCDEF);
2593 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc2x89ABCDEF);
2594 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc3x89ABCDEF);
2595
2596 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2597 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2598 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2599 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2600
2601 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
2602 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
2603
2604 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
2605 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc1x01234567);
2606 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc2x01234567);
2607 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc3x01234567);
2608 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
2609 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc1x89ABCDEF);
2610 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc2x89ABCDEF);
2611 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc3x89ABCDEF);
2612
2613 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
2614 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
2615 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
2616 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
2617
2618 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
2619 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
2620
2621 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
2622 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc1x01234567);
2623 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc2x01234567);
2624 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc3x01234567);
2625 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
2626 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc1x89ABCDEF);
2627 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc2x89ABCDEF);
2628 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc3x89ABCDEF);
2629
2630
2631 w += 64;
2632    }
2633
2634 const __m256 vmin = _mm256_load_ps(params->avx.min);
2635 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2636 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
2637 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
2638 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
2639 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2640 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
2641 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
2642 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
2643
2644 const __m256 vmax = _mm256_load_ps(params->avx.max);
2645 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2646 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
2647 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
2648 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
2649 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2650 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
2651 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
2652 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
2653
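    // A full 16-column tile is stored directly and the output pointers advance by cn_stride;
    // a final partial tile is written out in 8/4/2/1-column pieces.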
2654 if XNN_LIKELY(nc >= 16) {
2655 _mm256_storeu_ps(c3, vacc3x01234567);
2656 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
2657 c3 = (float*) ((uintptr_t) c3 + cn_stride);
2658 _mm256_storeu_ps(c2, vacc2x01234567);
2659 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
2660 c2 = (float*) ((uintptr_t) c2 + cn_stride);
2661 _mm256_storeu_ps(c1, vacc1x01234567);
2662 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
2663 c1 = (float*) ((uintptr_t) c1 + cn_stride);
2664 _mm256_storeu_ps(c0, vacc0x01234567);
2665 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2666 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2667
2668 a3 = (const float*) ((uintptr_t) a3 - kc);
2669 a2 = (const float*) ((uintptr_t) a2 - kc);
2670 a1 = (const float*) ((uintptr_t) a1 - kc);
2671 a0 = (const float*) ((uintptr_t) a0 - kc);
2672
2673 nc -= 16;
2674 } else {
2675 if (nc & 8) {
2676 _mm256_storeu_ps(c3, vacc3x01234567);
2677 _mm256_storeu_ps(c2, vacc2x01234567);
2678 _mm256_storeu_ps(c1, vacc1x01234567);
2679 _mm256_storeu_ps(c0, vacc0x01234567);
2680
2681 vacc3x01234567 = vacc3x89ABCDEF;
2682 vacc2x01234567 = vacc2x89ABCDEF;
2683 vacc1x01234567 = vacc1x89ABCDEF;
2684 vacc0x01234567 = vacc0x89ABCDEF;
2685
2686 c3 += 8;
2687 c2 += 8;
2688 c1 += 8;
2689 c0 += 8;
2690 }
2691 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
2692 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
2693 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
2694 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2695 if (nc & 4) {
2696 _mm_storeu_ps(c3, vacc3x0123);
2697 _mm_storeu_ps(c2, vacc2x0123);
2698 _mm_storeu_ps(c1, vacc1x0123);
2699 _mm_storeu_ps(c0, vacc0x0123);
2700
2701 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
2702 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
2703 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
2704 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2705
2706 c3 += 4;
2707 c2 += 4;
2708 c1 += 4;
2709 c0 += 4;
2710 }
2711 if (nc & 2) {
2712 _mm_storel_pi((__m64*) c3, vacc3x0123);
2713 _mm_storel_pi((__m64*) c2, vacc2x0123);
2714 _mm_storel_pi((__m64*) c1, vacc1x0123);
2715 _mm_storel_pi((__m64*) c0, vacc0x0123);
2716
2717 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
2718 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
2719 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
2720 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2721
2722 c3 += 2;
2723 c2 += 2;
2724 c1 += 2;
2725 c0 += 2;
2726 }
2727 if (nc & 1) {
2728 _mm_store_ss(c3, vacc3x0123);
2729 _mm_store_ss(c2, vacc2x0123);
2730 _mm_store_ss(c1, vacc1x0123);
2731 _mm_store_ss(c0, vacc0x0123);
2732 }
2733
2734 nc = 0;
2735 }
2736 } while (nc != 0);
2737}
2738
2739void xnn_f32_gemm_minmax_ukernel_5x16__fma3_broadcast(
2740 size_t mr,
2741 size_t nc,
2742 size_t kc,
2743 const float*restrict a,
2744 size_t a_stride,
2745 const float*restrict w,
2746 float*restrict c,
2747 size_t cm_stride,
2748 size_t cn_stride,
2749 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2750{
2751 assert(mr != 0);
2752 assert(mr <= 5);
2753 assert(nc != 0);
2754 assert(kc != 0);
2755 assert(kc % sizeof(float) == 0);
2756 assert(a != NULL);
2757 assert(w != NULL);
2758 assert(c != NULL);
2759
2760 const float* a0 = a;
2761 float* c0 = c;
2762 const float* a1 = (const float*) ((uintptr_t) a0 + a_stride);
2763 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
2764 if XNN_UNPREDICTABLE(mr < 2) {
2765 a1 = a0;
2766 c1 = c0;
2767 }
2768 const float* a2 = (const float*) ((uintptr_t) a1 + a_stride);
2769 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
2770 if XNN_UNPREDICTABLE(mr <= 2) {
2771 a2 = a1;
2772 c2 = c1;
2773 }
2774 const float* a3 = (const float*) ((uintptr_t) a2 + a_stride);
2775 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
2776 if XNN_UNPREDICTABLE(mr < 4) {
2777 a3 = a2;
2778 c3 = c2;
2779 }
2780 const float* a4 = (const float*) ((uintptr_t) a3 + a_stride);
2781 float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
2782 if XNN_UNPREDICTABLE(mr <= 4) {
2783 a4 = a3;
2784 c4 = c3;
2785 }
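  // When mr < 5, the trailing rows alias the previous row's a and c pointers, so they
  // recompute duplicate results and store them over memory that is written anyway,
  // instead of branching on the row count in the inner loop.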
2786
2787 do {
2788 __m256 vacc0x01234567 = _mm256_load_ps(w + 0);
2789 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2790 __m256 vacc1x01234567 = vacc0x01234567;
2791 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
2792 __m256 vacc2x01234567 = vacc0x01234567;
2793 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
2794 __m256 vacc3x01234567 = vacc0x01234567;
2795 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
2796 __m256 vacc4x01234567 = vacc0x01234567;
2797 __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
2798 w += 16;
2799
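    // Each k step broadcasts one a element per row and accumulates it against
    // 16 packed weights (two 8-float registers) with FMA.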
2800 size_t k = kc;
2801 do {
2802 const __m256 va0 = _mm256_broadcast_ss(a0);
2803 a0 += 1;
2804 const __m256 va1 = _mm256_broadcast_ss(a1);
2805 a1 += 1;
2806 const __m256 va2 = _mm256_broadcast_ss(a2);
2807 a2 += 1;
2808 const __m256 va3 = _mm256_broadcast_ss(a3);
2809 a3 += 1;
2810 const __m256 va4 = _mm256_broadcast_ss(a4);
2811 a4 += 1;
2812
2813 const __m256 vb01234567 = _mm256_load_ps(w);
2814 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
2815 w += 16;
2816
2817 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
2818 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567, vacc1x01234567);
2819 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567, vacc2x01234567);
2820 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567, vacc3x01234567);
2821 vacc4x01234567 = _mm256_fmadd_ps(va4, vb01234567, vacc4x01234567);
2822 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
2823 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF);
2824 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF);
2825 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF);
2826 vacc4x89ABCDEF = _mm256_fmadd_ps(va4, vb89ABCDEF, vacc4x89ABCDEF);
2827
2828 k -= sizeof(float);
2829 } while (k != 0);
2830
2831 const __m256 vmin = _mm256_load_ps(params->avx.min);
2832 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
2833 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
2834 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
2835 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
2836 vacc4x01234567 = _mm256_max_ps(vacc4x01234567, vmin);
2837 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
2838 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
2839 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
2840 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
2841 vacc4x89ABCDEF = _mm256_max_ps(vacc4x89ABCDEF, vmin);
2842
2843 const __m256 vmax = _mm256_load_ps(params->avx.max);
2844 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
2845 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
2846 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
2847 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
2848 vacc4x01234567 = _mm256_min_ps(vacc4x01234567, vmax);
2849 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
2850 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
2851 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
2852 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
2853 vacc4x89ABCDEF = _mm256_min_ps(vacc4x89ABCDEF, vmax);
2854
2855 if XNN_LIKELY(nc >= 16) {
2856 _mm256_storeu_ps(c4, vacc4x01234567);
2857 _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
2858 c4 = (float*) ((uintptr_t) c4 + cn_stride);
2859 _mm256_storeu_ps(c3, vacc3x01234567);
2860 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
2861 c3 = (float*) ((uintptr_t) c3 + cn_stride);
2862 _mm256_storeu_ps(c2, vacc2x01234567);
2863 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
2864 c2 = (float*) ((uintptr_t) c2 + cn_stride);
2865 _mm256_storeu_ps(c1, vacc1x01234567);
2866 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
2867 c1 = (float*) ((uintptr_t) c1 + cn_stride);
2868 _mm256_storeu_ps(c0, vacc0x01234567);
2869 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
2870 c0 = (float*) ((uintptr_t) c0 + cn_stride);
2871
2872 a4 = (const float*) ((uintptr_t) a4 - kc);
2873 a3 = (const float*) ((uintptr_t) a3 - kc);
2874 a2 = (const float*) ((uintptr_t) a2 - kc);
2875 a1 = (const float*) ((uintptr_t) a1 - kc);
2876 a0 = (const float*) ((uintptr_t) a0 - kc);
2877
2878 nc -= 16;
2879 } else {
2880 if (nc & 8) {
2881 _mm256_storeu_ps(c4, vacc4x01234567);
2882 _mm256_storeu_ps(c3, vacc3x01234567);
2883 _mm256_storeu_ps(c2, vacc2x01234567);
2884 _mm256_storeu_ps(c1, vacc1x01234567);
2885 _mm256_storeu_ps(c0, vacc0x01234567);
2886
2887 vacc4x01234567 = vacc4x89ABCDEF;
2888 vacc3x01234567 = vacc3x89ABCDEF;
2889 vacc2x01234567 = vacc2x89ABCDEF;
2890 vacc1x01234567 = vacc1x89ABCDEF;
2891 vacc0x01234567 = vacc0x89ABCDEF;
2892
2893 c4 += 8;
2894 c3 += 8;
2895 c2 += 8;
2896 c1 += 8;
2897 c0 += 8;
2898 }
2899 __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
2900 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
2901 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
2902 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
2903 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
2904 if (nc & 4) {
2905 _mm_storeu_ps(c4, vacc4x0123);
2906 _mm_storeu_ps(c3, vacc3x0123);
2907 _mm_storeu_ps(c2, vacc2x0123);
2908 _mm_storeu_ps(c1, vacc1x0123);
2909 _mm_storeu_ps(c0, vacc0x0123);
2910
2911 vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);
2912 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
2913 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
2914 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
2915 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
2916
2917 c4 += 4;
2918 c3 += 4;
2919 c2 += 4;
2920 c1 += 4;
2921 c0 += 4;
2922 }
2923 if (nc & 2) {
2924 _mm_storel_pi((__m64*) c4, vacc4x0123);
2925 _mm_storel_pi((__m64*) c3, vacc3x0123);
2926 _mm_storel_pi((__m64*) c2, vacc2x0123);
2927 _mm_storel_pi((__m64*) c1, vacc1x0123);
2928 _mm_storel_pi((__m64*) c0, vacc0x0123);
2929
2930 vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);
2931 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
2932 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
2933 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
2934 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
2935
2936 c4 += 2;
2937 c3 += 2;
2938 c2 += 2;
2939 c1 += 2;
2940 c0 += 2;
2941 }
2942 if (nc & 1) {
2943 _mm_store_ss(c4, vacc4x0123);
2944 _mm_store_ss(c3, vacc3x0123);
2945 _mm_store_ss(c2, vacc2x0123);
2946 _mm_store_ss(c1, vacc1x0123);
2947 _mm_store_ss(c0, vacc0x0123);
2948 }
2949
2950 nc = 0;
2951 }
2952 } while (nc != 0);
2953}
2954
2955void xnn_f32_igemm_minmax_ukernel_1x16__fma3_broadcast(
2956 size_t mr,
2957 size_t nc,
2958 size_t kc,
2959 size_t ks,
2960 const float**restrict a,
2961 const float*restrict w,
2962 float*restrict c,
2963 size_t cm_stride,
2964 size_t cn_stride,
2965 size_t a_offset,
2966 const float* zero,
2967 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
2968{
2969 assert(mr != 0);
2970 assert(mr <= 1);
2971 assert(nc != 0);
2972 assert(kc != 0);
2973 assert(kc % sizeof(float) == 0);
2974 assert(ks != 0);
2975 assert(ks % (1 * sizeof(void*)) == 0);
2976 assert(a_offset % sizeof(float) == 0);
2977 assert(a != NULL);
2978 assert(w != NULL);
2979 assert(c != NULL);
2980
2981 float* c0 = c;
2982
2983 do {
2984 __m256 vacc0x01234567 = _mm256_load_ps(w);
2985 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
2986 w += 16;
2987
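    // a is an indirection buffer of ks row pointers; entries equal to the zero pointer
    // reference a shared zero vector and are not adjusted by a_offset.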
2988 size_t p = ks;
2989 do {
2990 const float* restrict a0 = a[0];
2991 assert(a0 != NULL);
2992 if XNN_UNPREDICTABLE(a0 != zero) {
2993 a0 = (const float*) ((uintptr_t) a0 + a_offset);
2994 }
2995 a += 1;
2996
2997 size_t k = kc;
2998 do {
2999 const __m256 vb01234567 = _mm256_load_ps(w);
3000 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
3001 w += 16;
3002
3003 const __m256 va0 = _mm256_broadcast_ss(a0);
3004 a0 += 1;
3005
3006 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
3007 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
3008 k -= sizeof(float);
3009 } while (k != 0);
3010 p -= 1 * sizeof(void*);
3011 } while (p != 0);
3012
3013 const __m256 vmin = _mm256_load_ps(params->avx.min);
3014 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3015 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3016
3017 const __m256 vmax = _mm256_load_ps(params->avx.max);
3018 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3019 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3020
3021 if XNN_LIKELY(nc >= 16) {
3022 _mm256_storeu_ps(c0, vacc0x01234567);
3023 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3024 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3025
3026 a = (const float**restrict) ((uintptr_t) a - ks);
3027 nc -= 16;
3028 } else {
3029 if (nc & 8) {
3030 _mm256_storeu_ps(c0, vacc0x01234567);
3031
3032 vacc0x01234567 = vacc0x89ABCDEF;
3033
3034 c0 += 8;
3035 }
3036 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3037 if (nc & 4) {
3038 _mm_storeu_ps(c0, vacc0x0123);
3039
3040 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3041
3042 c0 += 4;
3043 }
3044 if (nc & 2) {
3045 _mm_storel_pi((__m64*) c0, vacc0x0123);
3046
3047 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3048
3049 c0 += 2;
3050 }
3051 if (nc & 1) {
3052 _mm_store_ss(c0, vacc0x0123);
3053 }
3054
3055 nc = 0;
3056 }
3057 } while (nc != 0);
3058}
3059
3060void xnn_f32_igemm_minmax_ukernel_1x16s4__fma3_broadcast(
3061 size_t mr,
3062 size_t nc,
3063 size_t kc,
3064 size_t ks,
3065 const float**restrict a,
3066 const float*restrict w,
3067 float*restrict c,
3068 size_t cm_stride,
3069 size_t cn_stride,
3070 size_t a_offset,
3071 const float* zero,
3072    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3073{
3074 assert(mr != 0);
3075 assert(mr <= 1);
3076 assert(nc != 0);
3077 assert(kc != 0);
3078 assert(kc % sizeof(float) == 0);
3079 assert(ks != 0);
3080 assert(ks % (1 * sizeof(void*)) == 0);
3081 assert(a_offset % sizeof(float) == 0);
3082 assert(a != NULL);
3083 assert(w != NULL);
3084 assert(c != NULL);
3085
3086 float* c0 = c;
3087
3088 do {
3089 __m256 vacc0x01234567 = _mm256_load_ps(w);
3090 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3091 w += 16;
3092
3093 size_t p = ks;
3094 do {
3095 const float* restrict a0 = a[0];
3096 assert(a0 != NULL);
3097 if XNN_UNPREDICTABLE(a0 != zero) {
3098 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3099 }
3100 a += 1;
3101
3102 size_t k = kc;
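      // s4 variant: four consecutive a values are loaded into each 128-bit half of va and
      // rotated one position with _mm256_permute_ps between the four packed weight columns.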
3103 while (k >= 4 * sizeof(float)) {
3104 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3105 a0 += 4;
3106
3107
3108 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3109 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3110
3111 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
3112 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
3113
3114 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3115
3116 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3117 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3118
3119 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
3120 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
3121
3122 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3123
3124 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3125 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3126
3127 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
3128 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
3129
3130 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3131
3132 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3133 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3134
3135 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
3136 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
3137
3138
3139 w += 64;
3140 k -= 4 * sizeof(float);
3141 }
3142 if XNN_UNLIKELY(k != 0) {
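        // Remainder of 1-3 k values: the full 4-float load may read past the row
        // (permitted by XNN_OOB_READS), and the zero entries of each padded weight
        // column mask off the matching va lanes so stray lane contents never reach
        // the accumulators.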
3143        __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3144 a0 = (const float*) ((uintptr_t) a0 + k);
3145
3146        const __m256 vzero = _mm256_setzero_ps();
3147
3148        const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3149 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3150
3151        vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
3152 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
3153
3154 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3155
3156 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3157 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3158
3159 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
3160 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
3161
3162 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3163
3164 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3165 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3166
3167 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
3168 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
3169
3170 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3171
3172 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3173 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3174
3175 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
3176 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
3177
3178
3179 w += 64;
3180      }
3181 p -= 1 * sizeof(void*);
3182 } while (p != 0);
3183
3184 const __m256 vmin = _mm256_load_ps(params->avx.min);
3185 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3186 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3187
3188 const __m256 vmax = _mm256_load_ps(params->avx.max);
3189 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3190 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3191
3192 if XNN_LIKELY(nc >= 16) {
3193 _mm256_storeu_ps(c0, vacc0x01234567);
3194 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3195 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3196
3197 a = (const float**restrict) ((uintptr_t) a - ks);
3198 nc -= 16;
3199 } else {
3200 if (nc & 8) {
3201 _mm256_storeu_ps(c0, vacc0x01234567);
3202
3203 vacc0x01234567 = vacc0x89ABCDEF;
3204
3205 c0 += 8;
3206 }
3207 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3208 if (nc & 4) {
3209 _mm_storeu_ps(c0, vacc0x0123);
3210
3211 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3212
3213 c0 += 4;
3214 }
3215 if (nc & 2) {
3216 _mm_storel_pi((__m64*) c0, vacc0x0123);
3217
3218 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3219
3220 c0 += 2;
3221 }
3222 if (nc & 1) {
3223 _mm_store_ss(c0, vacc0x0123);
3224 }
3225
3226 nc = 0;
3227 }
3228 } while (nc != 0);
3229}
3230
3231void xnn_f32_igemm_minmax_ukernel_4x16s4__fma3_broadcast(
3232 size_t mr,
3233 size_t nc,
3234 size_t kc,
3235 size_t ks,
3236 const float**restrict a,
3237 const float*restrict w,
3238 float*restrict c,
3239 size_t cm_stride,
3240 size_t cn_stride,
3241 size_t a_offset,
3242 const float* zero,
3243    const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
3244{
3245 assert(mr != 0);
3246 assert(mr <= 4);
3247 assert(nc != 0);
3248 assert(kc != 0);
3249 assert(kc % sizeof(float) == 0);
3250 assert(ks != 0);
3251 assert(ks % (4 * sizeof(void*)) == 0);
3252 assert(a_offset % sizeof(float) == 0);
3253 assert(a != NULL);
3254 assert(w != NULL);
3255 assert(c != NULL);
3256
3257 float* c0 = c;
3258 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
3259 if XNN_UNPREDICTABLE(mr < 2) {
3260 c1 = c0;
3261 }
3262 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
3263 if XNN_UNPREDICTABLE(mr <= 2) {
3264 c2 = c1;
3265 }
3266 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
3267 if XNN_UNPREDICTABLE(mr != 4) {
3268 c3 = c2;
3269 }
3270
3271 do {
3272 __m256 vacc0x01234567 = _mm256_load_ps(w);
3273 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3274 __m256 vacc1x01234567 = vacc0x01234567;
3275 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
3276 __m256 vacc2x01234567 = vacc0x01234567;
3277 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
3278 __m256 vacc3x01234567 = vacc0x01234567;
3279 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
3280 w += 16;
3281
3282 size_t p = ks;
3283 do {
3284 const float* restrict a0 = a[0];
3285 assert(a0 != NULL);
3286 if XNN_UNPREDICTABLE(a0 != zero) {
3287 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3288 }
3289 const float* restrict a1 = a[1];
3290 assert(a1 != NULL);
3291 if XNN_UNPREDICTABLE(a1 != zero) {
3292 a1 = (const float*) ((uintptr_t) a1 + a_offset);
3293 }
3294 const float* restrict a2 = a[2];
3295 assert(a2 != NULL);
3296 if XNN_UNPREDICTABLE(a2 != zero) {
3297 a2 = (const float*) ((uintptr_t) a2 + a_offset);
3298 }
3299 const float* restrict a3 = a[3];
3300 assert(a3 != NULL);
3301 if XNN_UNPREDICTABLE(a3 != zero) {
3302 a3 = (const float*) ((uintptr_t) a3 + a_offset);
3303 }
3304 a += 4;
3305
3306 size_t k = kc;
3307 while (k >= 4 * sizeof(float)) {
3308 __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3309 a0 += 4;
3310 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
3311 a1 += 4;
3312 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
3313 a2 += 4;
3314 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
3315 a3 += 4;
3316
3317
3318 const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3319 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3320
3321 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c0, vacc0x01234567);
3322 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c0, vacc1x01234567);
3323 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c0, vacc2x01234567);
3324 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c0, vacc3x01234567);
3325 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc0, vacc0x89ABCDEF);
3326 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc0, vacc1x89ABCDEF);
3327 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc0, vacc2x89ABCDEF);
3328 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc0, vacc3x89ABCDEF);
3329
3330 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3331 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3332 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3333 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3334
3335 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3336 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3337
3338 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c1, vacc0x01234567);
3339 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c1, vacc1x01234567);
3340 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c1, vacc2x01234567);
3341 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c1, vacc3x01234567);
3342 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc1, vacc0x89ABCDEF);
3343 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc1, vacc1x89ABCDEF);
3344 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc1, vacc2x89ABCDEF);
3345 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc1, vacc3x89ABCDEF);
3346
3347 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3348 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3349 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3350 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3351
3352 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3353 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3354
3355 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c2, vacc0x01234567);
3356 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c2, vacc1x01234567);
3357 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c2, vacc2x01234567);
3358 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c2, vacc3x01234567);
3359 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc2, vacc0x89ABCDEF);
3360 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc2, vacc1x89ABCDEF);
3361 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc2, vacc2x89ABCDEF);
3362 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc2, vacc3x89ABCDEF);
3363
3364 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3365 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3366 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3367 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3368
3369 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3370 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3371
3372 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567c3, vacc0x01234567);
3373 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567c3, vacc1x01234567);
3374 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567c3, vacc2x01234567);
3375 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567c3, vacc3x01234567);
3376 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEFc3, vacc0x89ABCDEF);
3377 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEFc3, vacc1x89ABCDEF);
3378 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEFc3, vacc2x89ABCDEF);
3379 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEFc3, vacc3x89ABCDEF);
3380
3381
3382 w += 64;
3383 k -= 4 * sizeof(float);
3384 }
3385 if XNN_UNLIKELY(k != 0) {
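        // Same remainder masking as the 1x16s4 kernel above, applied to all four rows.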
3386        __m256 va0 = _mm256_broadcast_ps((const __m128*) a0);
3387 a0 = (const float*) ((uintptr_t) a0 + k);
3388 __m256 va1 = _mm256_broadcast_ps((const __m128*) a1);
3389 a1 = (const float*) ((uintptr_t) a1 + k);
3390 __m256 va2 = _mm256_broadcast_ps((const __m128*) a2);
3391 a2 = (const float*) ((uintptr_t) a2 + k);
3392 __m256 va3 = _mm256_broadcast_ps((const __m128*) a3);
3393 a3 = (const float*) ((uintptr_t) a3 + k);
3394
3395        const __m256 vzero = _mm256_setzero_ps();
3396
3397        const __m256 vb01234567c0 = _mm256_load_ps(w + 0);
3398 const __m256 vb89ABCDEFc0 = _mm256_load_ps(w + 8);
3399
3400        vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc0x01234567);
3401 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc1x01234567);
3402 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc2x01234567);
3403 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c0, vzero, _CMP_NEQ_OQ)), vb01234567c0, vacc3x01234567);
3404 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc0x89ABCDEF);
3405 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc1x89ABCDEF);
3406 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc2x89ABCDEF);
3407 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc0, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc0, vacc3x89ABCDEF);
3408
3409 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3410 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3411 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3412 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3413
3414 const __m256 vb01234567c1 = _mm256_load_ps(w + 16);
3415 const __m256 vb89ABCDEFc1 = _mm256_load_ps(w + 24);
3416
3417 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc0x01234567);
3418 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc1x01234567);
3419 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc2x01234567);
3420 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c1, vzero, _CMP_NEQ_OQ)), vb01234567c1, vacc3x01234567);
3421 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc0x89ABCDEF);
3422 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc1x89ABCDEF);
3423 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc2x89ABCDEF);
3424 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc1, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc1, vacc3x89ABCDEF);
3425
3426 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3427 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3428 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3429 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3430
3431 const __m256 vb01234567c2 = _mm256_load_ps(w + 32);
3432 const __m256 vb89ABCDEFc2 = _mm256_load_ps(w + 40);
3433
3434 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc0x01234567);
3435 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc1x01234567);
3436 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc2x01234567);
3437 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c2, vzero, _CMP_NEQ_OQ)), vb01234567c2, vacc3x01234567);
3438 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc0x89ABCDEF);
3439 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc1x89ABCDEF);
3440 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc2x89ABCDEF);
3441 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc2, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc2, vacc3x89ABCDEF);
3442
3443 va0 = _mm256_permute_ps(va0, _MM_SHUFFLE(0, 3, 2, 1));
3444 va1 = _mm256_permute_ps(va1, _MM_SHUFFLE(0, 3, 2, 1));
3445 va2 = _mm256_permute_ps(va2, _MM_SHUFFLE(0, 3, 2, 1));
3446 va3 = _mm256_permute_ps(va3, _MM_SHUFFLE(0, 3, 2, 1));
3447
3448 const __m256 vb01234567c3 = _mm256_load_ps(w + 48);
3449 const __m256 vb89ABCDEFc3 = _mm256_load_ps(w + 56);
3450
3451 vacc0x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc0x01234567);
3452 vacc1x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc1x01234567);
3453 vacc2x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc2x01234567);
3454 vacc3x01234567 = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb01234567c3, vzero, _CMP_NEQ_OQ)), vb01234567c3, vacc3x01234567);
3455 vacc0x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va0, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc0x89ABCDEF);
3456 vacc1x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va1, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc1x89ABCDEF);
3457 vacc2x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va2, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc2x89ABCDEF);
3458 vacc3x89ABCDEF = _mm256_fmadd_ps(_mm256_and_ps(va3, _mm256_cmp_ps(vb89ABCDEFc3, vzero, _CMP_NEQ_OQ)), vb89ABCDEFc3, vacc3x89ABCDEF);
3459
3460
3461 w += 64;
3462      }
3463 p -= 4 * sizeof(void*);
3464 } while (p != 0);
3465
3466 const __m256 vmin = _mm256_load_ps(params->avx.min);
3467 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3468 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
3469 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
3470 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
3471 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3472 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
3473 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
3474 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
3475
3476 const __m256 vmax = _mm256_load_ps(params->avx.max);
3477 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3478 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
3479 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
3480 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
3481 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3482 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
3483 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
3484 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
3485
3486 if XNN_LIKELY(nc >= 16) {
3487 _mm256_storeu_ps(c3, vacc3x01234567);
3488 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
3489 c3 = (float*) ((uintptr_t) c3 + cn_stride);
3490 _mm256_storeu_ps(c2, vacc2x01234567);
3491 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
3492 c2 = (float*) ((uintptr_t) c2 + cn_stride);
3493 _mm256_storeu_ps(c1, vacc1x01234567);
3494 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
3495 c1 = (float*) ((uintptr_t) c1 + cn_stride);
3496 _mm256_storeu_ps(c0, vacc0x01234567);
3497 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3498 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3499
3500 a = (const float**restrict) ((uintptr_t) a - ks);
3501 nc -= 16;
3502 } else {
3503 if (nc & 8) {
3504 _mm256_storeu_ps(c3, vacc3x01234567);
3505 _mm256_storeu_ps(c2, vacc2x01234567);
3506 _mm256_storeu_ps(c1, vacc1x01234567);
3507 _mm256_storeu_ps(c0, vacc0x01234567);
3508
3509 vacc3x01234567 = vacc3x89ABCDEF;
3510 vacc2x01234567 = vacc2x89ABCDEF;
3511 vacc1x01234567 = vacc1x89ABCDEF;
3512 vacc0x01234567 = vacc0x89ABCDEF;
3513
3514 c3 += 8;
3515 c2 += 8;
3516 c1 += 8;
3517 c0 += 8;
3518 }
3519 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
3520 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
3521 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
3522 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3523 if (nc & 4) {
3524 _mm_storeu_ps(c3, vacc3x0123);
3525 _mm_storeu_ps(c2, vacc2x0123);
3526 _mm_storeu_ps(c1, vacc1x0123);
3527 _mm_storeu_ps(c0, vacc0x0123);
3528
3529 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
3530 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
3531 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
3532 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3533
3534 c3 += 4;
3535 c2 += 4;
3536 c1 += 4;
3537 c0 += 4;
3538 }
3539 if (nc & 2) {
3540 _mm_storel_pi((__m64*) c3, vacc3x0123);
3541 _mm_storel_pi((__m64*) c2, vacc2x0123);
3542 _mm_storel_pi((__m64*) c1, vacc1x0123);
3543 _mm_storel_pi((__m64*) c0, vacc0x0123);
3544
3545 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
3546 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
3547 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
3548 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3549
3550 c3 += 2;
3551 c2 += 2;
3552 c1 += 2;
3553 c0 += 2;
3554 }
3555 if (nc & 1) {
3556 _mm_store_ss(c3, vacc3x0123);
3557 _mm_store_ss(c2, vacc2x0123);
3558 _mm_store_ss(c1, vacc1x0123);
3559 _mm_store_ss(c0, vacc0x0123);
3560 }
3561
3562 nc = 0;
3563 }
3564 } while (nc != 0);
3565}
3566
3567void xnn_f32_igemm_minmax_ukernel_5x16__fma3_broadcast(
3568 size_t mr,
3569 size_t nc,
3570 size_t kc,
3571 size_t ks,
3572 const float**restrict a,
3573 const float*restrict w,
3574 float*restrict c,
3575 size_t cm_stride,
3576 size_t cn_stride,
3577 size_t a_offset,
3578 const float* zero,
3579 const union xnn_f32_minmax_params params[restrict XNN_MIN_ELEMENTS(1)])
3580{
3581 assert(mr != 0);
3582 assert(mr <= 5);
3583 assert(nc != 0);
3584 assert(kc != 0);
3585 assert(kc % sizeof(float) == 0);
3586 assert(ks != 0);
3587 assert(ks % (5 * sizeof(void*)) == 0);
3588 assert(a_offset % sizeof(float) == 0);
3589 assert(a != NULL);
3590 assert(w != NULL);
3591 assert(c != NULL);
3592
3593 float* c0 = c;
3594 float* c1 = (float*) ((uintptr_t) c0 + cm_stride);
3595 if XNN_UNPREDICTABLE(mr < 2) {
3596 c1 = c0;
3597 }
3598 float* c2 = (float*) ((uintptr_t) c1 + cm_stride);
3599 if XNN_UNPREDICTABLE(mr <= 2) {
3600 c2 = c1;
3601 }
3602 float* c3 = (float*) ((uintptr_t) c2 + cm_stride);
3603 if XNN_UNPREDICTABLE(mr < 4) {
3604 c3 = c2;
3605 }
3606 float* c4 = (float*) ((uintptr_t) c3 + cm_stride);
3607 if XNN_UNPREDICTABLE(mr <= 4) {
3608 c4 = c3;
3609 }
3610
3611 do {
3612 __m256 vacc0x01234567 = _mm256_load_ps(w);
3613 __m256 vacc0x89ABCDEF = _mm256_load_ps(w + 8);
3614 __m256 vacc1x01234567 = vacc0x01234567;
3615 __m256 vacc1x89ABCDEF = vacc0x89ABCDEF;
3616 __m256 vacc2x01234567 = vacc0x01234567;
3617 __m256 vacc2x89ABCDEF = vacc0x89ABCDEF;
3618 __m256 vacc3x01234567 = vacc0x01234567;
3619 __m256 vacc3x89ABCDEF = vacc0x89ABCDEF;
3620 __m256 vacc4x01234567 = vacc0x01234567;
3621 __m256 vacc4x89ABCDEF = vacc0x89ABCDEF;
3622 w += 16;
3623
3624 size_t p = ks;
3625 do {
3626 const float* restrict a0 = a[0];
3627 assert(a0 != NULL);
3628 if XNN_UNPREDICTABLE(a0 != zero) {
3629 a0 = (const float*) ((uintptr_t) a0 + a_offset);
3630 }
3631 const float* restrict a1 = a[1];
3632 assert(a1 != NULL);
3633 if XNN_UNPREDICTABLE(a1 != zero) {
3634 a1 = (const float*) ((uintptr_t) a1 + a_offset);
3635 }
3636 const float* restrict a2 = a[2];
3637 assert(a2 != NULL);
3638 if XNN_UNPREDICTABLE(a2 != zero) {
3639 a2 = (const float*) ((uintptr_t) a2 + a_offset);
3640 }
3641 const float* restrict a3 = a[3];
3642 assert(a3 != NULL);
3643 if XNN_UNPREDICTABLE(a3 != zero) {
3644 a3 = (const float*) ((uintptr_t) a3 + a_offset);
3645 }
3646 const float* restrict a4 = a[4];
3647 assert(a4 != NULL);
3648 if XNN_UNPREDICTABLE(a4 != zero) {
3649 a4 = (const float*) ((uintptr_t) a4 + a_offset);
3650 }
3651 a += 5;
3652
3653 size_t k = kc;
3654 do {
3655 const __m256 vb01234567 = _mm256_load_ps(w);
3656 const __m256 vb89ABCDEF = _mm256_load_ps(w + 8);
3657 w += 16;
3658
3659 const __m256 va0 = _mm256_broadcast_ss(a0);
3660 a0 += 1;
3661 const __m256 va1 = _mm256_broadcast_ss(a1);
3662 a1 += 1;
3663 const __m256 va2 = _mm256_broadcast_ss(a2);
3664 a2 += 1;
3665 const __m256 va3 = _mm256_broadcast_ss(a3);
3666 a3 += 1;
3667 const __m256 va4 = _mm256_broadcast_ss(a4);
3668 a4 += 1;
3669
3670 vacc0x01234567 = _mm256_fmadd_ps(va0, vb01234567, vacc0x01234567);
3671 vacc0x89ABCDEF = _mm256_fmadd_ps(va0, vb89ABCDEF, vacc0x89ABCDEF);
3672 vacc1x01234567 = _mm256_fmadd_ps(va1, vb01234567, vacc1x01234567);
3673 vacc1x89ABCDEF = _mm256_fmadd_ps(va1, vb89ABCDEF, vacc1x89ABCDEF);
3674 vacc2x01234567 = _mm256_fmadd_ps(va2, vb01234567, vacc2x01234567);
3675 vacc2x89ABCDEF = _mm256_fmadd_ps(va2, vb89ABCDEF, vacc2x89ABCDEF);
3676 vacc3x01234567 = _mm256_fmadd_ps(va3, vb01234567, vacc3x01234567);
3677 vacc3x89ABCDEF = _mm256_fmadd_ps(va3, vb89ABCDEF, vacc3x89ABCDEF);
3678 vacc4x01234567 = _mm256_fmadd_ps(va4, vb01234567, vacc4x01234567);
3679 vacc4x89ABCDEF = _mm256_fmadd_ps(va4, vb89ABCDEF, vacc4x89ABCDEF);
3680 k -= sizeof(float);
3681 } while (k != 0);
3682 p -= 5 * sizeof(void*);
3683 } while (p != 0);
3684
3685 const __m256 vmin = _mm256_load_ps(params->avx.min);
3686 vacc0x01234567 = _mm256_max_ps(vacc0x01234567, vmin);
3687 vacc1x01234567 = _mm256_max_ps(vacc1x01234567, vmin);
3688 vacc2x01234567 = _mm256_max_ps(vacc2x01234567, vmin);
3689 vacc3x01234567 = _mm256_max_ps(vacc3x01234567, vmin);
3690 vacc4x01234567 = _mm256_max_ps(vacc4x01234567, vmin);
3691 vacc0x89ABCDEF = _mm256_max_ps(vacc0x89ABCDEF, vmin);
3692 vacc1x89ABCDEF = _mm256_max_ps(vacc1x89ABCDEF, vmin);
3693 vacc2x89ABCDEF = _mm256_max_ps(vacc2x89ABCDEF, vmin);
3694 vacc3x89ABCDEF = _mm256_max_ps(vacc3x89ABCDEF, vmin);
3695 vacc4x89ABCDEF = _mm256_max_ps(vacc4x89ABCDEF, vmin);
3696
3697 const __m256 vmax = _mm256_load_ps(params->avx.max);
3698 vacc0x01234567 = _mm256_min_ps(vacc0x01234567, vmax);
3699 vacc1x01234567 = _mm256_min_ps(vacc1x01234567, vmax);
3700 vacc2x01234567 = _mm256_min_ps(vacc2x01234567, vmax);
3701 vacc3x01234567 = _mm256_min_ps(vacc3x01234567, vmax);
3702 vacc4x01234567 = _mm256_min_ps(vacc4x01234567, vmax);
3703 vacc0x89ABCDEF = _mm256_min_ps(vacc0x89ABCDEF, vmax);
3704 vacc1x89ABCDEF = _mm256_min_ps(vacc1x89ABCDEF, vmax);
3705 vacc2x89ABCDEF = _mm256_min_ps(vacc2x89ABCDEF, vmax);
3706 vacc3x89ABCDEF = _mm256_min_ps(vacc3x89ABCDEF, vmax);
3707 vacc4x89ABCDEF = _mm256_min_ps(vacc4x89ABCDEF, vmax);
3708
3709 if XNN_LIKELY(nc >= 16) {
3710 _mm256_storeu_ps(c4, vacc4x01234567);
3711 _mm256_storeu_ps(c4 + 8, vacc4x89ABCDEF);
3712 c4 = (float*) ((uintptr_t) c4 + cn_stride);
3713 _mm256_storeu_ps(c3, vacc3x01234567);
3714 _mm256_storeu_ps(c3 + 8, vacc3x89ABCDEF);
3715 c3 = (float*) ((uintptr_t) c3 + cn_stride);
3716 _mm256_storeu_ps(c2, vacc2x01234567);
3717 _mm256_storeu_ps(c2 + 8, vacc2x89ABCDEF);
3718 c2 = (float*) ((uintptr_t) c2 + cn_stride);
3719 _mm256_storeu_ps(c1, vacc1x01234567);
3720 _mm256_storeu_ps(c1 + 8, vacc1x89ABCDEF);
3721 c1 = (float*) ((uintptr_t) c1 + cn_stride);
3722 _mm256_storeu_ps(c0, vacc0x01234567);
3723 _mm256_storeu_ps(c0 + 8, vacc0x89ABCDEF);
3724 c0 = (float*) ((uintptr_t) c0 + cn_stride);
3725
3726 a = (const float**restrict) ((uintptr_t) a - ks);
3727 nc -= 16;
3728 } else {
3729 if (nc & 8) {
3730 _mm256_storeu_ps(c4, vacc4x01234567);
3731 _mm256_storeu_ps(c3, vacc3x01234567);
3732 _mm256_storeu_ps(c2, vacc2x01234567);
3733 _mm256_storeu_ps(c1, vacc1x01234567);
3734 _mm256_storeu_ps(c0, vacc0x01234567);
3735
3736 vacc4x01234567 = vacc4x89ABCDEF;
3737 vacc3x01234567 = vacc3x89ABCDEF;
3738 vacc2x01234567 = vacc2x89ABCDEF;
3739 vacc1x01234567 = vacc1x89ABCDEF;
3740 vacc0x01234567 = vacc0x89ABCDEF;
3741
3742 c4 += 8;
3743 c3 += 8;
3744 c2 += 8;
3745 c1 += 8;
3746 c0 += 8;
3747 }
3748 __m128 vacc4x0123 = _mm256_castps256_ps128(vacc4x01234567);
3749 __m128 vacc3x0123 = _mm256_castps256_ps128(vacc3x01234567);
3750 __m128 vacc2x0123 = _mm256_castps256_ps128(vacc2x01234567);
3751 __m128 vacc1x0123 = _mm256_castps256_ps128(vacc1x01234567);
3752 __m128 vacc0x0123 = _mm256_castps256_ps128(vacc0x01234567);
3753 if (nc & 4) {
3754 _mm_storeu_ps(c4, vacc4x0123);
3755 _mm_storeu_ps(c3, vacc3x0123);
3756 _mm_storeu_ps(c2, vacc2x0123);
3757 _mm_storeu_ps(c1, vacc1x0123);
3758 _mm_storeu_ps(c0, vacc0x0123);
3759
3760 vacc4x0123 = _mm256_extractf128_ps(vacc4x01234567, 1);
3761 vacc3x0123 = _mm256_extractf128_ps(vacc3x01234567, 1);
3762 vacc2x0123 = _mm256_extractf128_ps(vacc2x01234567, 1);
3763 vacc1x0123 = _mm256_extractf128_ps(vacc1x01234567, 1);
3764 vacc0x0123 = _mm256_extractf128_ps(vacc0x01234567, 1);
3765
3766 c4 += 4;
3767 c3 += 4;
3768 c2 += 4;
3769 c1 += 4;
3770 c0 += 4;
3771 }
3772 if (nc & 2) {
3773 _mm_storel_pi((__m64*) c4, vacc4x0123);
3774 _mm_storel_pi((__m64*) c3, vacc3x0123);
3775 _mm_storel_pi((__m64*) c2, vacc2x0123);
3776 _mm_storel_pi((__m64*) c1, vacc1x0123);
3777 _mm_storel_pi((__m64*) c0, vacc0x0123);
3778
3779 vacc4x0123 = _mm_movehl_ps(vacc4x0123, vacc4x0123);
3780 vacc3x0123 = _mm_movehl_ps(vacc3x0123, vacc3x0123);
3781 vacc2x0123 = _mm_movehl_ps(vacc2x0123, vacc2x0123);
3782 vacc1x0123 = _mm_movehl_ps(vacc1x0123, vacc1x0123);
3783 vacc0x0123 = _mm_movehl_ps(vacc0x0123, vacc0x0123);
3784
3785 c4 += 2;
3786 c3 += 2;
3787 c2 += 2;
3788 c1 += 2;
3789 c0 += 2;
3790 }
3791 if (nc & 1) {
3792 _mm_store_ss(c4, vacc4x0123);
3793 _mm_store_ss(c3, vacc3x0123);
3794 _mm_store_ss(c2, vacc2x0123);
3795 _mm_store_ss(c1, vacc1x0123);
3796 _mm_store_ss(c0, vacc0x0123);
3797 }
3798
3799 nc = 0;
3800 }
3801 } while (nc != 0);
3802}
3803
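// Usage sketch (illustrative, not part of the generated file): n is a byte count, so a
// caller processing 100 floats would pass n = 100 * sizeof(float). The params struct is
// assumed to have been filled in beforehand by the matching AVX init helper with the
// sixth/half/one constants and the mask table used by the tail handling below, e.g.:
//   xnn_f32_vhswish_ukernel__fma3_x16(100 * sizeof(float), input, output, &params);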
3804void xnn_f32_vhswish_ukernel__fma3_x16(
3805 size_t n,
3806 const float* x,
3807 float* y,
3808 const union xnn_f32_hswish_params params[restrict XNN_MIN_ELEMENTS(1)])
3809{
3810 assert(n != 0);
3811 assert(n % sizeof(float) == 0);
3812
3813 const __m256 vsixth = _mm256_load_ps(params->avx.sixth);
3814 const __m256 vhalf = _mm256_load_ps(params->avx.half);
3815 const __m256 vone = _mm256_load_ps(params->avx.one);
3816 const __m256 vzero = _mm256_setzero_ps();
3817
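  // hswish(x) = x * clamp(x/6 + 1/2, 0, 1); the inner term is one FMA followed by max/min.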
3818 for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
3819 const __m256 vx01234567 = _mm256_loadu_ps(x);
3820 const __m256 vx89ABCDEF = _mm256_loadu_ps(x + 8);
3821 x += 16;
3822
3823 __m256 vacc01234567 = _mm256_fmadd_ps(vx01234567, vsixth, vhalf);
3824 __m256 vacc89ABCDEF = _mm256_fmadd_ps(vx89ABCDEF, vsixth, vhalf);
3825
3826 vacc01234567 = _mm256_max_ps(vacc01234567, vzero);
3827 vacc89ABCDEF = _mm256_max_ps(vacc89ABCDEF, vzero);
3828
3829 vacc01234567 = _mm256_min_ps(vacc01234567, vone);
3830 vacc89ABCDEF = _mm256_min_ps(vacc89ABCDEF, vone);
3831
3832 vacc01234567 = _mm256_mul_ps(vacc01234567, vx01234567);
3833 vacc89ABCDEF = _mm256_mul_ps(vacc89ABCDEF, vx89ABCDEF);
3834
3835 _mm256_storeu_ps(y, vacc01234567);
3836 _mm256_storeu_ps(y + 8, vacc89ABCDEF);
3837 y += 16;
3838 }
3839 for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
3840 const __m256 vx = _mm256_loadu_ps(x);
3841 x += 8;
3842 __m256 vacc = _mm256_fmadd_ps(vx, vsixth, vhalf);
3843 vacc = _mm256_max_ps(vacc, vzero);
3844 vacc = _mm256_min_ps(vacc, vone);
3845 vacc = _mm256_mul_ps(vacc, vx);
3846 _mm256_storeu_ps(y, vacc);
3847 y += 8;
3848 }
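  // 1-7 leftover elements: maskload keeps the read inside the buffer, and the result
  // is written back in 4/2/1-float pieces.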
3849 if XNN_UNLIKELY(n != 0) {
3850 assert(n >= 1 * sizeof(float));
3851 assert(n <= 7 * sizeof(float));
3852 const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->avx.mask_table[7] - n));
3853
3854 const __m256 vx = _mm256_maskload_ps(x, vmask);
3855 __m256 vacc = _mm256_fmadd_ps(vx, vsixth, vhalf);
3856 vacc = _mm256_max_ps(vacc, vzero);
3857 vacc = _mm256_min_ps(vacc, vone);
3858 vacc = _mm256_mul_ps(vacc, vx);
3859
3860 __m128 vacc_lo = _mm256_castps256_ps128(vacc);
3861 if (n & (4 * sizeof(float))) {
3862 _mm_storeu_ps(y, vacc_lo);
3863 vacc_lo = _mm256_extractf128_ps(vacc, 1);
3864 y += 4;
3865 }
3866 if (n & (2 * sizeof(float))) {
3867 _mm_storel_pi((__m64*) y, vacc_lo);
3868 vacc_lo = _mm_movehl_ps(vacc_lo, vacc_lo);
3869 y += 2;
3870 }
3871 if (n & (1 * sizeof(float))) {
3872 _mm_store_ss(y, vacc_lo);
3873 }
3874 }
3875}