// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#include <assert.h>

#include <immintrin.h>

#include <xnnpack/common.h>
#include <xnnpack/gavgpool.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>
#include <xnnpack/vbinary.h>
#include <xnnpack/vcvt.h>
#include <xnnpack/vunary.h>

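// f16 -> f32 conversion microkernel: expands IEEE binary16 values to
// binary32 with the F16C VCVTPH2PS instruction (_mm256_cvtph_ps), 16
// elements per main-loop iteration. The 1-7 element tail reuses a full
// 8-element load (the kernel is annotated XNN_OOB_READS) but stores only
// the in-bounds results.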
void xnn_f16_f32_vcvt_ukernel__f16c_x16(
    size_t n,
    const void* input,
    float* output,
    const union xnn_f16_f32_cvt_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(input != NULL);
  assert(output != NULL);

  const uint16_t* i = (const uint16_t*) input;
  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    const __m256 vacc0 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
    const __m256 vacc1 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (i + 8)));
    i += 16;

    _mm256_storeu_ps(output, vacc0);
    _mm256_storeu_ps(output + 8, vacc1);
    output += 16;
  }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));
    i += 8;

    _mm256_storeu_ps(output, vacc);
    output += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(uint16_t));
    assert(n <= 7 * sizeof(uint16_t));
    const __m256 vacc = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i));

    __m128 vacc_lo = _mm256_castps256_ps128(vacc);
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storeu_ps(output, vacc_lo);
      vacc_lo = _mm256_extractf128_ps(vacc, 1);
      output += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      _mm_storel_pi((__m64*) output, vacc_lo);
      vacc_lo = _mm_movehl_ps(vacc_lo, vacc_lo);
      output += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      _mm_store_ss(output, vacc_lo);
    }
  }
}

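// Multi-pass (rows > 7) global average pooling with min/max clamping on f16
// data. The first pass sums 7 rows per channel into the caller-provided
// scratch `buffer`; middle passes add 7 more rows into that buffer; the
// final pass adds the remaining 1-7 rows (substituting `zero` for missing
// rows), multiplies by the scale from params, and clamps to [min, max].
// Arithmetic is done in f32 via F16C conversions, with the accumulator
// rounded back to f16 after every addition.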
void xnn_f16_gavgpool_minmax_ukernel_7p7x__f16c_c8(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* buffer,
    void* output,
    const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(rows > 7);
  assert(channels != 0);

  const uint16_t* i0 = input;
  const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride);
  const uint16_t* i2 = (const uint16_t*) ((uintptr_t) i1 + input_stride);
  const uint16_t* i3 = (const uint16_t*) ((uintptr_t) i2 + input_stride);
  const uint16_t* i4 = (const uint16_t*) ((uintptr_t) i3 + input_stride);
  const uint16_t* i5 = (const uint16_t*) ((uintptr_t) i4 + input_stride);
  const uint16_t* i6 = (const uint16_t*) ((uintptr_t) i5 + input_stride);
  const size_t input_increment = 7 * input_stride - round_up_po2(channels, 8) * sizeof(uint16_t);

  uint16_t* b = buffer;
  size_t c = channels;
  for (; c != 0; c = doz(c, 8)) {
    const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8;
    const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); i1 += 8;

    const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); i2 += 8;
    __m128i vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(vi0x01234567, vi1x01234567), _MM_FROUND_NO_EXC);

    const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); i3 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi2x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4)); i4 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi3x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5)); i5 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi4x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6)); i6 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi5x01234567), _MM_FROUND_NO_EXC);
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi6x01234567), _MM_FROUND_NO_EXC);

    _mm_store_si128((__m128i*) b, vacc01234567); b += 8;
  }

  for (rows -= 7; rows > 7; rows -= 7) {
    i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment);
    i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment);
    i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment);
    i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment);
    i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment);
    i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment);
    i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment);

    uint16_t* b = buffer;
    size_t c = channels;
    for (; c != 0; c = doz(c, 8)) {
      __m128i vacc01234567 = _mm_loadu_si128((const __m128i*) b);

      const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8;

      const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); i1 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi0x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); i2 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi1x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); i3 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi2x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4)); i4 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi3x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5)); i5 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi4x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6)); i6 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi5x01234567), _MM_FROUND_NO_EXC);
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi6x01234567), _MM_FROUND_NO_EXC);

      _mm_store_si128((__m128i*) b, vacc01234567); b += 8;
    }
  }

  i0 = (const uint16_t*) ((uintptr_t) i0 + input_increment);
  i1 = (const uint16_t*) ((uintptr_t) i1 + input_increment);
  if XNN_UNPREDICTABLE(rows < 2) {
    i1 = (const uint16_t*) zero;
  }
  i2 = (const uint16_t*) ((uintptr_t) i2 + input_increment);
  if XNN_UNPREDICTABLE(rows <= 2) {
    i2 = (const uint16_t*) zero;
  }
  i3 = (const uint16_t*) ((uintptr_t) i3 + input_increment);
  if XNN_UNPREDICTABLE(rows < 4) {
    i3 = (const uint16_t*) zero;
  }
  i4 = (const uint16_t*) ((uintptr_t) i4 + input_increment);
  if XNN_UNPREDICTABLE(rows <= 4) {
    i4 = (const uint16_t*) zero;
  }
  i5 = (const uint16_t*) ((uintptr_t) i5 + input_increment);
  if XNN_UNPREDICTABLE(rows < 6) {
    i5 = (const uint16_t*) zero;
  }
  i6 = (const uint16_t*) ((uintptr_t) i6 + input_increment);
  if XNN_UNPREDICTABLE(rows <= 6) {
    i6 = (const uint16_t*) zero;
  }

  const __m256 vscale = _mm256_load_ps(params->avx.scale);
  const __m256 vmin = _mm256_load_ps(params->avx.min);
  const __m256 vmax = _mm256_load_ps(params->avx.max);
  for (; channels >= 8; channels -= 8) {
    __m128i vacc01234567 = _mm_loadu_si128((const __m128i*) buffer); buffer = (uint16_t*) buffer + 8;

    const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8;

    const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); i1 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi0x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); i2 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi1x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); i3 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi2x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4)); i4 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi3x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5)); i5 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi4x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6)); i6 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi5x01234567), _MM_FROUND_NO_EXC);
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi6x01234567), _MM_FROUND_NO_EXC);

    vacc01234567 = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc01234567), vscale), _MM_FROUND_NO_EXC);

    __m256 vout01234567 = _mm256_max_ps(_mm256_cvtph_ps(vacc01234567), vmin);

    vout01234567 = _mm256_min_ps(vout01234567, vmax);

    _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vout01234567, _MM_FROUND_NO_EXC));
    output = (uint16_t*) output + 8;
  }
  if XNN_UNLIKELY(channels != 0) {
    {
      __m128i vacc01234567 = _mm_loadu_si128((const __m128i*) buffer); buffer = (uint16_t*) buffer + 8;

      const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0)); i0 += 8;
      const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1)); i1 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi0x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2)); i2 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi1x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3)); i3 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi2x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4)); i4 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi3x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5)); i5 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi4x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6)); i6 += 8;
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi5x01234567), _MM_FROUND_NO_EXC);
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi6x01234567), _MM_FROUND_NO_EXC);

      vacc01234567 = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc01234567), vscale), _MM_FROUND_NO_EXC);
      __m256 vout01234567 = _mm256_max_ps(_mm256_cvtph_ps(vacc01234567), vmin);
      vout01234567 = _mm256_min_ps(vout01234567, vmax);

      __m128i vh01234567 = _mm256_cvtps_ph(vout01234567, _MM_FROUND_NO_EXC);
      if (channels & 4) {
        _mm_storel_epi64((__m128i*) output, vh01234567);
        output = (uint16_t*) output + 4;
        vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
      }
      if (channels & 2) {
        *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vh01234567);
        output = (uint16_t*) output + 2;
        vh01234567 = _mm_srli_epi64(vh01234567, 32);
      }
      if (channels & 1) {
        *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vh01234567, 0);
      }
    }
  }
}

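// Single-pass (rows <= 7) variant of the global average pooling microkernel
// above: row pointers past `rows` are redirected to the `zero` buffer, so the
// unrolled 7-row sum needs no per-row branches in the channel loop.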
void xnn_f16_gavgpool_minmax_ukernel_7x__f16c_c8(
    size_t rows,
    size_t channels,
    const void* input,
    size_t input_stride,
    const void* zero,
    void* output,
    const union xnn_f16_scaleminmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(rows != 0);
  assert(rows <= 7);
  assert(channels != 0);

  const uint16_t* i0 = input;
  const uint16_t* i1 = (const uint16_t*) ((uintptr_t) i0 + input_stride);
  if XNN_UNPREDICTABLE(rows < 2) {
    i1 = (const uint16_t*) zero;
  }
  const uint16_t* i2 = (const uint16_t*) ((uintptr_t) i1 + input_stride);
  if XNN_UNPREDICTABLE(rows <= 2) {
    i2 = (const uint16_t*) zero;
  }
  const uint16_t* i3 = (const uint16_t*) ((uintptr_t) i2 + input_stride);
  if XNN_UNPREDICTABLE(rows < 4) {
    i3 = (const uint16_t*) zero;
  }
  const uint16_t* i4 = (const uint16_t*) ((uintptr_t) i3 + input_stride);
  if XNN_UNPREDICTABLE(rows <= 4) {
    i4 = (const uint16_t*) zero;
  }
  const uint16_t* i5 = (const uint16_t*) ((uintptr_t) i4 + input_stride);
  if XNN_UNPREDICTABLE(rows < 6) {
    i5 = (const uint16_t*) zero;
  }
  const uint16_t* i6 = (const uint16_t*) ((uintptr_t) i5 + input_stride);
  if XNN_UNPREDICTABLE(rows <= 6) {
    i6 = (const uint16_t*) zero;
  }

  const __m256 vscale = _mm256_load_ps(params->avx.scale);
  const __m256 vmin = _mm256_load_ps(params->avx.min);
  const __m256 vmax = _mm256_load_ps(params->avx.max);
  for (; channels >= 8; channels -= 8) {
    const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
    i0 += 8;
    const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));
    i1 += 8;

    const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
    __m128i vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(vi0x01234567, vi1x01234567), _MM_FROUND_NO_EXC);
    i2 += 8;

    const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
    i3 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi2x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
    i4 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi3x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
    i5 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi4x01234567), _MM_FROUND_NO_EXC);
    const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
    i6 += 8;
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi5x01234567), _MM_FROUND_NO_EXC);
    vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi6x01234567), _MM_FROUND_NO_EXC);

    vacc01234567 = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc01234567), vscale), _MM_FROUND_NO_EXC);

    __m256 vout01234567 = _mm256_max_ps(_mm256_cvtph_ps(vacc01234567), vmin);

    vout01234567 = _mm256_min_ps(vout01234567, vmax);

    _mm_storeu_si128((__m128i*) output, _mm256_cvtps_ph(vout01234567, _MM_FROUND_NO_EXC));
    output = (uint16_t*) output + 8;
  }
  if XNN_UNLIKELY(channels != 0) {
    {
      const __m256 vi0x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i0));
      const __m256 vi1x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i1));

      const __m256 vi2x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i2));
      __m128i vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(vi0x01234567, vi1x01234567), _MM_FROUND_NO_EXC);

      const __m256 vi3x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i3));
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi2x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi4x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i4));
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi3x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi5x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i5));
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi4x01234567), _MM_FROUND_NO_EXC);
      const __m256 vi6x01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) i6));
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi5x01234567), _MM_FROUND_NO_EXC);
      vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(_mm256_cvtph_ps(vacc01234567), vi6x01234567), _MM_FROUND_NO_EXC);

      vacc01234567 = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc01234567), vscale), _MM_FROUND_NO_EXC);
      __m256 vout01234567 = _mm256_max_ps(_mm256_cvtph_ps(vacc01234567), vmin);
      vout01234567 = _mm256_min_ps(vout01234567, vmax);

      __m128i vh01234567 = _mm256_cvtps_ph(vout01234567, _MM_FROUND_NO_EXC);
      if (channels & 4) {
        _mm_storel_epi64((__m128i*) output, vh01234567);
        output = (uint16_t*) output + 4;
        vh01234567 = _mm_unpackhi_epi64(vh01234567, vh01234567);
      }
      if (channels & 2) {
        *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vh01234567);
        output = (uint16_t*) output + 2;
        vh01234567 = _mm_srli_epi64(vh01234567, 32);
      }
      if (channels & 1) {
        *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vh01234567, 0);
      }
    }
  }
}

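// Element-wise f16 addition with output clamping: loads both operands,
// widens to f32, adds, rounds the sum back to f16 precision, then clamps to
// [min, max] before storing. Processes 16 elements per main-loop iteration,
// with 8-element and 1-7 element tails.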
void xnn_f16_vadd_minmax_ukernel__f16c_x16(
    size_t n,
    const void* restrict a_ptr,
    const void* restrict b_ptr,
    void* restrict y_ptr,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(a_ptr != NULL);
  assert(b_ptr != NULL);
  assert(y_ptr != NULL);

  const uint16_t* a = (const uint16_t*) a_ptr;
  const uint16_t* b = (const uint16_t*) b_ptr;
  uint16_t* y = (uint16_t*) y_ptr;

  const __m256 vy_min = _mm256_load_ps(params->avx.min);
  const __m256 vy_max = _mm256_load_ps(params->avx.max);

  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    const __m256 va01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 vb01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
    const __m256 va456789AB = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (a + 8)));
    const __m256 vb456789AB = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (b + 8)));
    a += 16;
    b += 16;

    __m256 vy01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va01234567, vb01234567), _MM_FROUND_NO_EXC));
    __m256 vy456789AB = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va456789AB, vb456789AB), _MM_FROUND_NO_EXC));


    vy01234567 = _mm256_max_ps(vy01234567, vy_min);
    vy456789AB = _mm256_max_ps(vy456789AB, vy_min);

    vy01234567 = _mm256_min_ps(vy01234567, vy_max);
    vy456789AB = _mm256_min_ps(vy456789AB, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy01234567, _MM_FROUND_NO_EXC));
    _mm_storeu_si128((__m128i*) (y + 8), _mm256_cvtps_ph(vy456789AB, _MM_FROUND_NO_EXC));
    y += 16;
  }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 vb = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
    a += 8;
    b += 8;

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC));
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 vb = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC);
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) y, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      y += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vh);
      vh = _mm_srli_epi64(vh, 32);
      y += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      *y = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}

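// Variant of the f16 addition kernel above where the second operand is a
// single scalar (*b) broadcast across the vector.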
void xnn_f16_vaddc_minmax_ukernel__f16c_x16(
    size_t n,
    const void* restrict a_ptr,
    const void* restrict b_ptr,
    void* restrict y_ptr,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(a_ptr != NULL);
  assert(b_ptr != NULL);
  assert(y_ptr != NULL);

  const uint16_t* a = (const uint16_t*) a_ptr;
  const uint16_t* b = (const uint16_t*) b_ptr;
  uint16_t* y = (uint16_t*) y_ptr;

  const __m256 vy_min = _mm256_load_ps(params->avx.min);
  const __m256 vy_max = _mm256_load_ps(params->avx.max);

  const __m256 vb = _mm256_cvtph_ps(_mm_set1_epi16((short) *b));
  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    const __m256 va01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 va456789AB = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (a + 8)));
    a += 16;

    __m256 vy01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va01234567, vb), _MM_FROUND_NO_EXC));
    __m256 vy456789AB = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va456789AB, vb), _MM_FROUND_NO_EXC));


    vy01234567 = _mm256_max_ps(vy01234567, vy_min);
    vy456789AB = _mm256_max_ps(vy456789AB, vy_min);

    vy01234567 = _mm256_min_ps(vy01234567, vy_max);
    vy456789AB = _mm256_min_ps(vy456789AB, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy01234567, _MM_FROUND_NO_EXC));
    _mm_storeu_si128((__m128i*) (y + 8), _mm256_cvtps_ph(vy456789AB, _MM_FROUND_NO_EXC));
    y += 16;
  }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    a += 8;

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC));
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_add_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC);
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) y, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      y += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vh);
      vh = _mm_srli_epi64(vh, 32);
      y += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      *y = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}

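// Element-wise f16 multiplication with output clamping; identical structure
// to the f16 addition kernel above, with _mm256_mul_ps in place of
// _mm256_add_ps.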
void xnn_f16_vmul_minmax_ukernel__f16c_x16(
    size_t n,
    const void* restrict a_ptr,
    const void* restrict b_ptr,
    void* restrict y_ptr,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(a_ptr != NULL);
  assert(b_ptr != NULL);
  assert(y_ptr != NULL);

  const uint16_t* a = (const uint16_t*) a_ptr;
  const uint16_t* b = (const uint16_t*) b_ptr;
  uint16_t* y = (uint16_t*) y_ptr;

  const __m256 vy_min = _mm256_load_ps(params->avx.min);
  const __m256 vy_max = _mm256_load_ps(params->avx.max);

  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    const __m256 va01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 vb01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
    const __m256 va456789AB = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (a + 8)));
    const __m256 vb456789AB = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (b + 8)));
    a += 16;
    b += 16;

    __m256 vy01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va01234567, vb01234567), _MM_FROUND_NO_EXC));
    __m256 vy456789AB = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va456789AB, vb456789AB), _MM_FROUND_NO_EXC));


    vy01234567 = _mm256_max_ps(vy01234567, vy_min);
    vy456789AB = _mm256_max_ps(vy456789AB, vy_min);

    vy01234567 = _mm256_min_ps(vy01234567, vy_max);
    vy456789AB = _mm256_min_ps(vy456789AB, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy01234567, _MM_FROUND_NO_EXC));
    _mm_storeu_si128((__m128i*) (y + 8), _mm256_cvtps_ph(vy456789AB, _MM_FROUND_NO_EXC));
    y += 16;
  }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 vb = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));
    a += 8;
    b += 8;

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC));
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 vb = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) b));

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC);
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) y, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      y += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vh);
      vh = _mm_srli_epi64(vh, 32);
      y += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      *y = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}

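// Variant of the f16 multiplication kernel above where the second operand is
// a single scalar (*b) broadcast across the vector.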
void xnn_f16_vmulc_minmax_ukernel__f16c_x16(
    size_t n,
    const void* restrict a_ptr,
    const void* restrict b_ptr,
    void* restrict y_ptr,
    const union xnn_f16_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);
  assert(a_ptr != NULL);
  assert(b_ptr != NULL);
  assert(y_ptr != NULL);

  const uint16_t* a = (const uint16_t*) a_ptr;
  const uint16_t* b = (const uint16_t*) b_ptr;
  uint16_t* y = (uint16_t*) y_ptr;

  const __m256 vy_min = _mm256_load_ps(params->avx.min);
  const __m256 vy_max = _mm256_load_ps(params->avx.max);

  const __m256 vb = _mm256_cvtph_ps(_mm_set1_epi16((short) *b));
  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    const __m256 va01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    const __m256 va456789AB = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (a + 8)));
    a += 16;

    __m256 vy01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va01234567, vb), _MM_FROUND_NO_EXC));
    __m256 vy456789AB = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va456789AB, vb), _MM_FROUND_NO_EXC));


    vy01234567 = _mm256_max_ps(vy01234567, vy_min);
    vy456789AB = _mm256_max_ps(vy456789AB, vy_min);

    vy01234567 = _mm256_min_ps(vy01234567, vy_max);
    vy456789AB = _mm256_min_ps(vy456789AB, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy01234567, _MM_FROUND_NO_EXC));
    _mm_storeu_si128((__m128i*) (y + 8), _mm256_cvtps_ph(vy456789AB, _MM_FROUND_NO_EXC));
    y += 16;
  }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));
    a += 8;

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    _mm_storeu_si128((__m128i*) y, _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC));
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    const __m256 va = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) a));

    __m256 vy = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(va, vb), _MM_FROUND_NO_EXC));

    vy = _mm256_max_ps(vy, vy_min);
    vy = _mm256_min_ps(vy, vy_max);

    __m128i vh = _mm256_cvtps_ph(vy, _MM_FROUND_NO_EXC);
    if (n & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) y, vh);
      vh = _mm_unpackhi_epi64(vh, vh);
      y += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vh);
      vh = _mm_srli_epi64(vh, 32);
      y += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      *y = (uint16_t) _mm_extract_epi16(vh, 0);
    }
  }
}

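// f16 HardSwish: y = x * min(max(x + 3, 0), 6) / 6, computed here as
// (x * sixth) * clamp(x + 3, 0, 6) with sixth, three, and six taken from
// params. The clamp runs directly on the f16 bit patterns via 16-bit integer
// min/max: non-negative finite f16 values order the same as their bit
// patterns, and negative inputs (sign bit set) compare as negative integers
// and are clamped to zero.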
void xnn_f16_vhswish_ukernel__f16c_x16(
    size_t n,
    const void* restrict x_ptr,
    void* restrict y_ptr,
    const union xnn_f16_hswish_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
{
  assert(n != 0);
  assert(n % sizeof(uint16_t) == 0);

  const uint16_t* x = (const uint16_t*) x_ptr;
  uint16_t* y = (uint16_t*) y_ptr;

  const __m256 vsixth = _mm256_load_ps(params->avx.sixth);
  const __m256 vthree = _mm256_load_ps(params->avx.three);
  const __m128i vsix = _mm_load_si128((const __m128i*) params->avx.six);
  const __m128i vzero = _mm_setzero_si128();

  for (; n >= 16 * sizeof(uint16_t); n -= 16 * sizeof(uint16_t)) {
    __m256 vx01234567 = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) x));
    __m256 vx89ABCDEF = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) (x + 8)));
    x += 16;

    __m128i vacc01234567 = _mm256_cvtps_ph(_mm256_add_ps(vx01234567, vthree), _MM_FROUND_NO_EXC);
    vx01234567 = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vx01234567, vsixth), _MM_FROUND_NO_EXC));
    __m128i vacc89ABCDEF = _mm256_cvtps_ph(_mm256_add_ps(vx89ABCDEF, vthree), _MM_FROUND_NO_EXC);
    vx89ABCDEF = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vx89ABCDEF, vsixth), _MM_FROUND_NO_EXC));

    vacc01234567 = _mm_max_epi16(vacc01234567, vzero);
    vacc89ABCDEF = _mm_max_epi16(vacc89ABCDEF, vzero);

    vacc01234567 = _mm_min_epi16(vacc01234567, vsix);
    vacc89ABCDEF = _mm_min_epi16(vacc89ABCDEF, vsix);

    vacc01234567 = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc01234567), vx01234567), _MM_FROUND_NO_EXC);
    vacc89ABCDEF = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc89ABCDEF), vx89ABCDEF), _MM_FROUND_NO_EXC);

    _mm_storeu_si128((__m128i*) y, vacc01234567);
    _mm_storeu_si128((__m128i*) (y + 8), vacc89ABCDEF);
    y += 16;
  }
  for (; n >= 8 * sizeof(uint16_t); n -= 8 * sizeof(uint16_t)) {
    __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) x));
    x += 8;
    __m128i vacc = _mm256_cvtps_ph(_mm256_add_ps(vx, vthree), _MM_FROUND_NO_EXC);
    vx = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vx, vsixth), _MM_FROUND_NO_EXC));
    vacc = _mm_max_epi16(vacc, vzero);
    vacc = _mm_min_epi16(vacc, vsix);
    vacc = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc), vx), _MM_FROUND_NO_EXC);
    _mm_storeu_si128((__m128i*) y, vacc);
    y += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    __m256 vx = _mm256_cvtph_ps(_mm_loadu_si128((const __m128i*) x));
    __m128i vacc = _mm256_cvtps_ph(_mm256_add_ps(vx, vthree), _MM_FROUND_NO_EXC);
    vx = _mm256_cvtph_ps(_mm256_cvtps_ph(_mm256_mul_ps(vx, vsixth), _MM_FROUND_NO_EXC));
    vacc = _mm_max_epi16(vacc, vzero);
    vacc = _mm_min_epi16(vacc, vsix);
    vacc = _mm256_cvtps_ph(_mm256_mul_ps(_mm256_cvtph_ps(vacc), vx), _MM_FROUND_NO_EXC);

    if (n & (4 * sizeof(uint16_t))) {
      _mm_storel_epi64((__m128i*) y, vacc);
      vacc = _mm_unpackhi_epi64(vacc, vacc);
      y += 4;
    }
    if (n & (2 * sizeof(uint16_t))) {
      *((uint32_t*) y) = (uint32_t) _mm_cvtsi128_si32(vacc);
      vacc = _mm_srli_epi64(vacc, 32);
      y += 2;
    }
    if (n & (1 * sizeof(uint16_t))) {
      *y = (uint16_t) _mm_extract_epi16(vacc, 0);
    }
  }
}

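// f32 -> f16 conversion microkernel: narrows binary32 values to binary16
// with the F16C VCVTPS2PH instruction (_mm256_cvtps_ph), 16 elements per
// main-loop iteration. Unlike the f16 -> f32 kernel above, the 1-7 element
// tail uses a masked load driven by params->f16c.mask_table, so it does not
// read past the end of the input.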
void xnn_f32_f16_vcvt_ukernel__f16c_x16(
    size_t n,
    const float* input,
    void* output,
    const union xnn_f32_f16_cvt_params params[restrict XNN_MIN_ELEMENTS(1)])
{
  assert(n != 0);
  assert(n % sizeof(float) == 0);
  assert(input != NULL);
  assert(output != NULL);

  uint16_t* o = (uint16_t*) output;
  for (; n >= 16 * sizeof(float); n -= 16 * sizeof(float)) {
    const __m256 vf0 = _mm256_loadu_ps(input);
    const __m256 vf1 = _mm256_loadu_ps(input + 8);
    input += 16;

    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf0, _MM_FROUND_NO_EXC));
    _mm_storeu_si128((__m128i*) (o + 8), _mm256_cvtps_ph(vf1, _MM_FROUND_NO_EXC));
    o += 16;
  }
  for (; n >= 8 * sizeof(float); n -= 8 * sizeof(float)) {
    const __m256 vf = _mm256_loadu_ps(input);
    input += 8;

    _mm_storeu_si128((__m128i*) o, _mm256_cvtps_ph(vf, _MM_FROUND_NO_EXC));
    o += 8;
  }
  if XNN_UNLIKELY(n != 0) {
    assert(n >= 1 * sizeof(float));
    assert(n <= 7 * sizeof(float));
    const __m256i vmask = _mm256_loadu_si256((const __m256i*) ((uintptr_t) &params->f16c.mask_table[7] - n));

    const __m256 vf = _mm256_maskload_ps(input, vmask);

    __m128 vf_lo = _mm256_castps256_ps128(vf);
    if (n & (4 * sizeof(float))) {
      _mm_storel_epi64((__m128i*) o, _mm_cvtps_ph(vf_lo, _MM_FROUND_NO_EXC));
      vf_lo = _mm256_extractf128_ps(vf, 1);
      o += 4;
    }
    __m128i vh = _mm_cvtps_ph(vf_lo, _MM_FROUND_NO_EXC);
    if (n & (2 * sizeof(float))) {
      _mm_storeu_si32(o, vh);
      vh = _mm_srli_epi64(vh, 32);
      o += 2;
    }
    if (n & (1 * sizeof(float))) {
      *((uint16_t*) o) = _mm_extract_epi16(vh, 0);
    }
  }
}