blob: 1b1a744bcdbfd11576743332353c11eb2d8aedc9 [file] [log] [blame]
Logan Chiendf4f7662019-09-04 16:45:23 -07001/*===--------- avx512vlbf16intrin.h - AVX512_BF16 intrinsics ---------------===
2 *
3 * Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
4 * See https://llvm.org/LICENSE.txt for license information.
5 * SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
6 *
7 *===-----------------------------------------------------------------------===
8 */
9#ifndef __IMMINTRIN_H
10#error "Never use <avx512vlbf16intrin.h> directly; include <immintrin.h> instead."
11#endif
12
13#ifndef __AVX512VLBF16INTRIN_H
14#define __AVX512VLBF16INTRIN_H
15
/* Vector of eight 16-bit lanes that carries BF16 values as raw bit patterns
   (the element type is `short`). */
typedef short __m128bh __attribute__((__vector_size__(16), __aligned__(16)));

/* Attributes shared by the intrinsics in this header: always inline, omit
   debug info, enable the AVX512VL and AVX512_BF16 target features, and
   declare the minimum vector width each function operates on. */
#define __DEFAULT_FN_ATTRS128                                                  \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("avx512vl, avx512bf16"),                           \
                 __min_vector_width__(128)))
#define __DEFAULT_FN_ATTRS256                                                  \
  __attribute__((__always_inline__, __nodebug__,                               \
                 __target__("avx512vl, avx512bf16"),                           \
                 __min_vector_width__(256)))
24
25/// Convert Two Packed Single Data to One Packed BF16 Data.
26///
27/// \headerfile <x86intrin.h>
28///
29/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
30///
31/// \param __A
32/// A 128-bit vector of [4 x float].
33/// \param __B
34/// A 128-bit vector of [4 x float].
35/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
36/// conversion of __B, and higher 64 bits come from conversion of __A.
37static __inline__ __m128bh __DEFAULT_FN_ATTRS128
38_mm_cvtne2ps_pbh(__m128 __A, __m128 __B) {
39 return (__m128bh)__builtin_ia32_cvtne2ps2bf16_128((__v4sf) __A,
40 (__v4sf) __B);
41}
42
43/// Convert Two Packed Single Data to One Packed BF16 Data.
44///
45/// \headerfile <x86intrin.h>
46///
47/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
48///
49/// \param __A
50/// A 128-bit vector of [4 x float].
51/// \param __B
52/// A 128-bit vector of [4 x float].
53/// \param __W
54/// A 128-bit vector of [8 x bfloat].
55/// \param __U
56/// A 8-bit mask value specifying what is chosen for each element.
57/// A 1 means conversion of __A or __B. A 0 means element from __W.
58/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
59/// conversion of __B, and higher 64 bits come from conversion of __A.
60static __inline__ __m128bh __DEFAULT_FN_ATTRS128
61_mm_mask_cvtne2ps_pbh(__m128bh __W, __mmask8 __U, __m128 __A, __m128 __B) {
62 return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
63 (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
64 (__v8hi)__W);
65}
66
67/// Convert Two Packed Single Data to One Packed BF16 Data.
68///
69/// \headerfile <x86intrin.h>
70///
71/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
72///
73/// \param __A
74/// A 128-bit vector of [4 x float].
75/// \param __B
76/// A 128-bit vector of [4 x float].
77/// \param __U
78/// A 8-bit mask value specifying what is chosen for each element.
79/// A 1 means conversion of __A or __B. A 0 means element is zero.
80/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
81/// conversion of __B, and higher 64 bits come from conversion of __A.
82static __inline__ __m128bh __DEFAULT_FN_ATTRS128
83_mm_maskz_cvtne2ps_pbh(__mmask8 __U, __m128 __A, __m128 __B) {
84 return (__m128bh)__builtin_ia32_selectw_128((__mmask8)__U,
85 (__v8hi)_mm_cvtne2ps_pbh(__A, __B),
86 (__v8hi)_mm_setzero_si128());
87}
88
89/// Convert Two Packed Single Data to One Packed BF16 Data.
90///
91/// \headerfile <x86intrin.h>
92///
93/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
94///
95/// \param __A
96/// A 256-bit vector of [8 x float].
97/// \param __B
98/// A 256-bit vector of [8 x float].
99/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
100/// conversion of __B, and higher 128 bits come from conversion of __A.
101static __inline__ __m256bh __DEFAULT_FN_ATTRS256
102_mm256_cvtne2ps_pbh(__m256 __A, __m256 __B) {
103 return (__m256bh)__builtin_ia32_cvtne2ps2bf16_256((__v8sf) __A,
104 (__v8sf) __B);
105}
106
107/// Convert Two Packed Single Data to One Packed BF16 Data.
108///
109/// \headerfile <x86intrin.h>
110///
111/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
112///
113/// \param __A
114/// A 256-bit vector of [8 x float].
115/// \param __B
116/// A 256-bit vector of [8 x float].
117/// \param __W
118/// A 256-bit vector of [16 x bfloat].
119/// \param __U
120/// A 16-bit mask value specifying what is chosen for each element.
121/// A 1 means conversion of __A or __B. A 0 means element from __W.
122/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
123/// conversion of __B, and higher 128 bits come from conversion of __A.
124static __inline__ __m256bh __DEFAULT_FN_ATTRS256
125_mm256_mask_cvtne2ps_pbh(__m256bh __W, __mmask16 __U, __m256 __A, __m256 __B) {
126 return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
127 (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
128 (__v16hi)__W);
129}
130
131/// Convert Two Packed Single Data to One Packed BF16 Data.
132///
133/// \headerfile <x86intrin.h>
134///
135/// This intrinsic corresponds to the <c> VCVTNE2PS2BF16 </c> instructions.
136///
137/// \param __A
138/// A 256-bit vector of [8 x float].
139/// \param __B
140/// A 256-bit vector of [8 x float].
141/// \param __U
142/// A 16-bit mask value specifying what is chosen for each element.
143/// A 1 means conversion of __A or __B. A 0 means element is zero.
144/// \returns A 256-bit vector of [16 x bfloat] whose lower 128 bits come from
145/// conversion of __B, and higher 128 bits come from conversion of __A.
146static __inline__ __m256bh __DEFAULT_FN_ATTRS256
147_mm256_maskz_cvtne2ps_pbh(__mmask16 __U, __m256 __A, __m256 __B) {
148 return (__m256bh)__builtin_ia32_selectw_256((__mmask16)__U,
149 (__v16hi)_mm256_cvtne2ps_pbh(__A, __B),
150 (__v16hi)_mm256_setzero_si256());
151}
152
153/// Convert Packed Single Data to Packed BF16 Data.
154///
155/// \headerfile <x86intrin.h>
156///
157/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
158///
159/// \param __A
160/// A 128-bit vector of [4 x float].
161/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
162/// conversion of __A, and higher 64 bits are 0.
163static __inline__ __m128bh __DEFAULT_FN_ATTRS128
164_mm_cvtneps_pbh(__m128 __A) {
165 return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
166 (__v8hi)_mm_undefined_si128(),
167 (__mmask8)-1);
168}
169
170/// Convert Packed Single Data to Packed BF16 Data.
171///
172/// \headerfile <x86intrin.h>
173///
174/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
175///
176/// \param __A
177/// A 128-bit vector of [4 x float].
178/// \param __W
179/// A 128-bit vector of [8 x bfloat].
180/// \param __U
181/// A 4-bit mask value specifying what is chosen for each element.
182/// A 1 means conversion of __A. A 0 means element from __W.
183/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
184/// conversion of __A, and higher 64 bits are 0.
185static __inline__ __m128bh __DEFAULT_FN_ATTRS128
186_mm_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m128 __A) {
187 return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
188 (__v8hi)__W,
189 (__mmask8)__U);
190}
191
192/// Convert Packed Single Data to Packed BF16 Data.
193///
194/// \headerfile <x86intrin.h>
195///
196/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
197///
198/// \param __A
199/// A 128-bit vector of [4 x float].
200/// \param __U
201/// A 4-bit mask value specifying what is chosen for each element.
202/// A 1 means conversion of __A. A 0 means element is zero.
203/// \returns A 128-bit vector of [8 x bfloat] whose lower 64 bits come from
204/// conversion of __A, and higher 64 bits are 0.
205static __inline__ __m128bh __DEFAULT_FN_ATTRS128
206_mm_maskz_cvtneps_pbh(__mmask8 __U, __m128 __A) {
207 return (__m128bh)__builtin_ia32_cvtneps2bf16_128_mask((__v4sf) __A,
208 (__v8hi)_mm_setzero_si128(),
209 (__mmask8)__U);
210}
211
212/// Convert Packed Single Data to Packed BF16 Data.
213///
214/// \headerfile <x86intrin.h>
215///
216/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
217///
218/// \param __A
219/// A 256-bit vector of [8 x float].
220/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
221static __inline__ __m128bh __DEFAULT_FN_ATTRS256
222_mm256_cvtneps_pbh(__m256 __A) {
223 return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
224 (__v8hi)_mm_undefined_si128(),
225 (__mmask8)-1);
226}
227
228/// Convert Packed Single Data to Packed BF16 Data.
229///
230/// \headerfile <x86intrin.h>
231///
232/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
233///
234/// \param __A
235/// A 256-bit vector of [8 x float].
236/// \param __W
237/// A 256-bit vector of [8 x bfloat].
238/// \param __U
239/// A 8-bit mask value specifying what is chosen for each element.
240/// A 1 means conversion of __A. A 0 means element from __W.
241/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
242static __inline__ __m128bh __DEFAULT_FN_ATTRS256
243_mm256_mask_cvtneps_pbh(__m128bh __W, __mmask8 __U, __m256 __A) {
244 return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
245 (__v8hi)__W,
246 (__mmask8)__U);
247}
248
249/// Convert Packed Single Data to Packed BF16 Data.
250///
251/// \headerfile <x86intrin.h>
252///
253/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
254///
255/// \param __A
256/// A 256-bit vector of [8 x float].
257/// \param __U
258/// A 8-bit mask value specifying what is chosen for each element.
259/// A 1 means conversion of __A. A 0 means element is zero.
260/// \returns A 128-bit vector of [8 x bfloat] comes from conversion of __A.
261static __inline__ __m128bh __DEFAULT_FN_ATTRS256
262_mm256_maskz_cvtneps_pbh(__mmask8 __U, __m256 __A) {
263 return (__m128bh)__builtin_ia32_cvtneps2bf16_256_mask((__v8sf)__A,
264 (__v8hi)_mm_setzero_si128(),
265 (__mmask8)__U);
266}
267
268/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
269///
270/// \headerfile <x86intrin.h>
271///
272/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
273///
274/// \param __A
275/// A 128-bit vector of [8 x bfloat].
276/// \param __B
277/// A 128-bit vector of [8 x bfloat].
278/// \param __D
279/// A 128-bit vector of [4 x float].
280/// \returns A 128-bit vector of [4 x float] comes from Dot Product of
281/// __A, __B and __D
282static __inline__ __m128 __DEFAULT_FN_ATTRS128
283_mm_dpbf16_ps(__m128 __D, __m128bh __A, __m128bh __B) {
284 return (__m128)__builtin_ia32_dpbf16ps_128((__v4sf)__D,
285 (__v4si)__A,
286 (__v4si)__B);
287}
288
289/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
290///
291/// \headerfile <x86intrin.h>
292///
293/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
294///
295/// \param __A
296/// A 128-bit vector of [8 x bfloat].
297/// \param __B
298/// A 128-bit vector of [8 x bfloat].
299/// \param __D
300/// A 128-bit vector of [4 x float].
301/// \param __U
302/// A 8-bit mask value specifying what is chosen for each element.
303/// A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
304/// \returns A 128-bit vector of [4 x float] comes from Dot Product of
305/// __A, __B and __D
306static __inline__ __m128 __DEFAULT_FN_ATTRS128
307_mm_mask_dpbf16_ps(__m128 __D, __mmask8 __U, __m128bh __A, __m128bh __B) {
308 return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
309 (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
310 (__v4sf)__D);
311}
312
313/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
314///
315/// \headerfile <x86intrin.h>
316///
317/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
318///
319/// \param __A
320/// A 128-bit vector of [8 x bfloat].
321/// \param __B
322/// A 128-bit vector of [8 x bfloat].
323/// \param __D
324/// A 128-bit vector of [4 x float].
325/// \param __U
326/// A 8-bit mask value specifying what is chosen for each element.
327/// A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
328/// \returns A 128-bit vector of [4 x float] comes from Dot Product of
329/// __A, __B and __D
330static __inline__ __m128 __DEFAULT_FN_ATTRS128
331_mm_maskz_dpbf16_ps(__mmask8 __U, __m128 __D, __m128bh __A, __m128bh __B) {
332 return (__m128)__builtin_ia32_selectps_128((__mmask8)__U,
333 (__v4sf)_mm_dpbf16_ps(__D, __A, __B),
334 (__v4sf)_mm_setzero_si128());
335}
336
337/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
338///
339/// \headerfile <x86intrin.h>
340///
341/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
342///
343/// \param __A
344/// A 256-bit vector of [16 x bfloat].
345/// \param __B
346/// A 256-bit vector of [16 x bfloat].
347/// \param __D
348/// A 256-bit vector of [8 x float].
349/// \returns A 256-bit vector of [8 x float] comes from Dot Product of
350/// __A, __B and __D
351static __inline__ __m256 __DEFAULT_FN_ATTRS256
352_mm256_dpbf16_ps(__m256 __D, __m256bh __A, __m256bh __B) {
353 return (__m256)__builtin_ia32_dpbf16ps_256((__v8sf)__D,
354 (__v8si)__A,
355 (__v8si)__B);
356}
357
358/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
359///
360/// \headerfile <x86intrin.h>
361///
362/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
363///
364/// \param __A
365/// A 256-bit vector of [16 x bfloat].
366/// \param __B
367/// A 256-bit vector of [16 x bfloat].
368/// \param __D
369/// A 256-bit vector of [8 x float].
370/// \param __U
371/// A 16-bit mask value specifying what is chosen for each element.
372/// A 1 means __A and __B's dot product accumulated with __D. A 0 means __D.
373/// \returns A 256-bit vector of [8 x float] comes from Dot Product of
374/// __A, __B and __D
375static __inline__ __m256 __DEFAULT_FN_ATTRS256
376_mm256_mask_dpbf16_ps(__m256 __D, __mmask8 __U, __m256bh __A, __m256bh __B) {
377 return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
378 (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
379 (__v8sf)__D);
380}
381
382/// Dot Product of BF16 Pairs Accumulated into Packed Single Precision.
383///
384/// \headerfile <x86intrin.h>
385///
386/// This intrinsic corresponds to the <c> VDPBF16PS </c> instructions.
387///
388/// \param __A
389/// A 256-bit vector of [16 x bfloat].
390/// \param __B
391/// A 256-bit vector of [16 x bfloat].
392/// \param __D
393/// A 256-bit vector of [8 x float].
394/// \param __U
395/// A 8-bit mask value specifying what is chosen for each element.
396/// A 1 means __A and __B's dot product accumulated with __D. A 0 means 0.
397/// \returns A 256-bit vector of [8 x float] comes from Dot Product of
398/// __A, __B and __D
399static __inline__ __m256 __DEFAULT_FN_ATTRS256
400_mm256_maskz_dpbf16_ps(__mmask8 __U, __m256 __D, __m256bh __A, __m256bh __B) {
401 return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
402 (__v8sf)_mm256_dpbf16_ps(__D, __A, __B),
403 (__v8sf)_mm256_setzero_si256());
404}
405
406/// Convert One Single float Data to One BF16 Data.
407///
408/// \headerfile <x86intrin.h>
409///
410/// This intrinsic corresponds to the <c> VCVTNEPS2BF16 </c> instructions.
411///
412/// \param __A
413/// A float data.
414/// \returns A bf16 data whose sign field and exponent field keep unchanged,
415/// and fraction field is truncated to 7 bits.
416static __inline__ __bfloat16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
417 __v4sf __V = {__A, 0, 0, 0};
418 __v8hi __R = __builtin_ia32_cvtneps2bf16_128_mask(
419 (__v4sf)__V, (__v8hi)_mm_undefined_si128(), (__mmask8)-1);
420 return __R[0];
421}
422
423/// Convert Packed BF16 Data to Packed float Data.
424///
425/// \headerfile <x86intrin.h>
426///
427/// \param __A
428/// A 128-bit vector of [8 x bfloat].
429/// \returns A 256-bit vector of [8 x float] come from convertion of __A
430static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
431 return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
432 (__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
433}
434
435/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
436///
437/// \headerfile <x86intrin.h>
438///
439/// \param __U
440/// A 8-bit mask. Elements are zeroed out when the corresponding mask
441/// bit is not set.
442/// \param __A
443/// A 128-bit vector of [8 x bfloat].
444/// \returns A 256-bit vector of [8 x float] come from convertion of __A
445static __inline__ __m256 __DEFAULT_FN_ATTRS256
446_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
447 return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
448 (__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
449}
450
451/// Convert Packed BF16 Data to Packed float Data using merging mask.
452///
453/// \headerfile <x86intrin.h>
454///
455/// \param __S
456/// A 256-bit vector of [8 x float]. Elements are copied from __S when
457/// the corresponding mask bit is not set.
458/// \param __U
459/// A 8-bit mask. Elements are zeroed out when the corresponding mask
460/// bit is not set.
461/// \param __A
462/// A 128-bit vector of [8 x bfloat].
463/// \returns A 256-bit vector of [8 x float] come from convertion of __A
464static __inline__ __m256 __DEFAULT_FN_ATTRS256
465_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
466 return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
467 (__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
468 16));
469}
470
/* The attribute macros are internal to this header; remove them so they do
   not leak into code that includes it. */
#undef __DEFAULT_FN_ATTRS256
#undef __DEFAULT_FN_ATTRS128
473
474#endif