// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

Marat Dukhan76e78c82021-07-20 21:11:23 -07006$assert DATATYPE in ["QS8", "QU8"]
Marat Dukhane6dc0b62020-09-08 23:57:14 -07007$assert BATCH_TILE % 8 == 0
8$assert BATCH_TILE >= 8
9$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
#include <assert.h>
#include <string.h>

#include <immintrin.h>

#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/vadd.h>


Marat Dukhan3eac69c2021-07-21 01:42:29 -070018$XINT8_T = {"QS8": "int8_t", "QU8": "uint8_t"}[DATATYPE]
19$_MM256_CVTEPX8_EPI32 = {"QS8": "_mm256_cvtepi8_epi32", "QU8": "_mm256_cvtepu8_epi32"}[DATATYPE]
20$_MM_PACKXS_EPI16 = {"QS8": "_mm_packs_epi16", "QU8": "_mm_packus_epi16"}[DATATYPE]
21$_MM_MIN_EPX8 = {"QS8": "_mm_min_epi8", "QU8": "_mm_min_epu8"}[DATATYPE]
22$_MM_MAX_EPX8 = {"QS8": "_mm_max_epi8", "QU8": "_mm_max_epu8"}[DATATYPE]
23void xnn_${DATATYPE.lower()}_vaddc_minmax_ukernel__avx2_mul32_ld64_x${BATCH_TILE}(
Marat Dukhane6dc0b62020-09-08 23:57:14 -070024 size_t n,
Marat Dukhan3eac69c2021-07-21 01:42:29 -070025 const ${XINT8_T}* input_a,
26 const ${XINT8_T}* input_b,
27 ${XINT8_T}* output,
Marat Dukhan6c7b9e82021-07-22 13:14:07 -070028 const union xnn_${DATATYPE.lower()}_add_minmax_params params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
Marat Dukhane6dc0b62020-09-08 23:57:14 -070029{
Marat Dukhan7679b1e2021-07-20 18:32:23 -070030 const __m256i va_multiplier = _mm256_load_si256((const __m256i*) params->avx2.a_multiplier);
31 const __m256i vrounding = _mm256_load_si256((const __m256i*) params->avx2.rounding);
32 const __m128i vshift = _mm_loadu_si32(params->avx2.shift);
Marat Dukhane6dc0b62020-09-08 23:57:14 -070033 $if BATCH_TILE > 8:
Marat Dukhan7679b1e2021-07-20 18:32:23 -070034 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->avx2.output_zero_point);
Marat Dukhane6dc0b62020-09-08 23:57:14 -070035 $else:
Marat Dukhan7679b1e2021-07-20 18:32:23 -070036 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->avx2.output_zero_point);
37 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->avx2.output_min);
38 const __m128i voutput_max = _mm_load_si128((const __m128i*) params->avx2.output_max);
Marat Dukhane6dc0b62020-09-08 23:57:14 -070039
Marat Dukhan7679b1e2021-07-20 18:32:23 -070040 __m256i vbias = _mm256_add_epi32(
41 _mm256_broadcastd_epi32(_mm_cvtsi32_si128(params->avx2.b_multiplier[0] * (int32_t) *input_b)),
42 _mm256_load_si256((const __m256i*) params->avx2.bias));
Marat Dukhan3eac69c2021-07-21 01:42:29 -070043 for (; n >= ${BATCH_TILE} * sizeof(${XINT8_T}); n -= ${BATCH_TILE} * sizeof(${XINT8_T})) {
44 const __m256i va${ABC[0:8]} = ${_MM256_CVTEPX8_EPI32}(_mm_loadl_epi64((const __m128i*) input_a));
Marat Dukhane6dc0b62020-09-08 23:57:14 -070045 $for N in range(8, BATCH_TILE, 8):
Marat Dukhan3eac69c2021-07-21 01:42:29 -070046 const __m256i va${ABC[N:N+8]} = ${_MM256_CVTEPX8_EPI32}(_mm_loadl_epi64((const __m128i*) (input_a + ${N})));
Marat Dukhan076bcfe2021-07-19 19:24:42 -070047 input_a += ${BATCH_TILE};
Marat Dukhane6dc0b62020-09-08 23:57:14 -070048
49 $for N in range(0, BATCH_TILE, 8):
Marat Dukhana842fef2021-07-19 21:07:40 -070050 __m256i vacc${ABC[N:N+8]} = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va${ABC[N:N+8]}, va_multiplier));
Marat Dukhane6dc0b62020-09-08 23:57:14 -070051
52 $for N in range(0, BATCH_TILE, 8):
Marat Dukhan8a045652021-07-20 14:07:21 -070053 vacc${ABC[N:N+8]} = _mm256_sra_epi32(_mm256_add_epi32(vacc${ABC[N:N+8]}, vrounding), vshift);
Marat Dukhane6dc0b62020-09-08 23:57:14 -070054
55 $for N in range(0, BATCH_TILE, 16):
56 $if N + 8 < BATCH_TILE:
57 __m256i vout${ABC[N:N+4]}${ABC[N+8:N+12]}${ABC[N+4:N+8]}${ABC[N+12:N+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[N:N+8]}, vacc${ABC[N+8:N+16]}), voutput_zero_point);
58 $elif BATCH_TILE > 8:
59 __m128i vout${ABC[N:N+8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[N:N+8]}), _mm256_extracti128_si256(vacc${ABC[N:N+8]}, 1)), _mm256_castsi256_si128(voutput_zero_point));
60 $else:
61 __m128i vout${ABC[N:N+8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[N:N+8]}), _mm256_extracti128_si256(vacc${ABC[N:N+8]}, 1)), voutput_zero_point);
62
63 $for N in range(0, BATCH_TILE, 16):
64 $if N + 8 < BATCH_TILE:
Marat Dukhan3eac69c2021-07-21 01:42:29 -070065 __m128i vout${ABC[N:N+16]} = _mm_shuffle_epi32(${_MM_PACKXS_EPI16}(_mm256_castsi256_si128(vout${ABC[N:N+4]}${ABC[N+8:N+12]}${ABC[N+4:N+8]}${ABC[N+12:N+16]}), _mm256_extracti128_si256(vout${ABC[N:N+4]}${ABC[N+8:N+12]}${ABC[N+4:N+8]}${ABC[N+12:N+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0));
Marat Dukhane6dc0b62020-09-08 23:57:14 -070066 $else:
Marat Dukhan3eac69c2021-07-21 01:42:29 -070067 __m128i vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_PACKXS_EPI16}(vout${ABC[N:N+8]}, vout${ABC[N:N+8]});
Marat Dukhane6dc0b62020-09-08 23:57:14 -070068
Marat Dukhan7679b1e2021-07-20 18:32:23 -070069 $for N in range(0, BATCH_TILE, 16):
70 $if N + 8 < BATCH_TILE:
Marat Dukhan3eac69c2021-07-21 01:42:29 -070071 vout${ABC[N:N+16]} = ${_MM_MAX_EPX8}(vout${ABC[N:N+16]}, voutput_min);
Marat Dukhan7679b1e2021-07-20 18:32:23 -070072 $else:
Marat Dukhan3eac69c2021-07-21 01:42:29 -070073 vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_MAX_EPX8}(vout${ABC[N:N+8]}${ABC[N:N+8]}, voutput_min);
Marat Dukhan7679b1e2021-07-20 18:32:23 -070074
75 $for N in range(0, BATCH_TILE, 16):
76 $if N + 8 < BATCH_TILE:
Marat Dukhan3eac69c2021-07-21 01:42:29 -070077 vout${ABC[N:N+16]} = ${_MM_MIN_EPX8}(vout${ABC[N:N+16]}, voutput_max);
Marat Dukhan7679b1e2021-07-20 18:32:23 -070078 $else:
Marat Dukhan3eac69c2021-07-21 01:42:29 -070079 vout${ABC[N:N+8]}${ABC[N:N+8]} = ${_MM_MIN_EPX8}(vout${ABC[N:N+8]}${ABC[N:N+8]}, voutput_max);
Marat Dukhan7679b1e2021-07-20 18:32:23 -070080
Marat Dukhane6dc0b62020-09-08 23:57:14 -070081 $if BATCH_TILE >= 16:
82 _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
83 $else:
84 _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
85 $for N in range(16, BATCH_TILE, 16):
86 $if N + 8 < BATCH_TILE:
87 _mm_storeu_si128((__m128i*) (output + ${N}), vout${ABC[N:N+16]});
88 $else:
89 _mm_storel_epi64((__m128i*) (output + ${N}), vout${ABC[N:N+8]}${ABC[N:N+8]});
90 output += ${BATCH_TILE};
91 }
92 if XNN_UNLIKELY(n != 0) {
93 ${"do " if BATCH_TILE > 8 else ""}{
Marat Dukhan3eac69c2021-07-21 01:42:29 -070094 const __m256i va${ABC[0:8]} = ${_MM256_CVTEPX8_EPI32}(_mm_loadl_epi64((const __m128i*) input_a));
Marat Dukhane6dc0b62020-09-08 23:57:14 -070095 $if BATCH_TILE > 8:
Marat Dukhan076bcfe2021-07-19 19:24:42 -070096 input_a += 8;
Marat Dukhane6dc0b62020-09-08 23:57:14 -070097
Marat Dukhana842fef2021-07-19 21:07:40 -070098 __m256i vacc${ABC[0:8]} = _mm256_add_epi32(vbias, _mm256_mullo_epi32(va${ABC[0:8]}, va_multiplier));
Marat Dukhane6dc0b62020-09-08 23:57:14 -070099
Marat Dukhan8a045652021-07-20 14:07:21 -0700100 vacc${ABC[0:8]} = _mm256_sra_epi32(_mm256_add_epi32(vacc${ABC[0:8]}, vrounding), vshift);
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700101
102 $if BATCH_TILE > 8:
103 __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), _mm256_castsi256_si128(voutput_zero_point));
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700104 $else:
105 __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point);
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700106 __m128i vout${ABC[0:8]}${ABC[0:8]} = ${_MM_PACKXS_EPI16}(vout${ABC[0:8]}, vout${ABC[0:8]});
107 vout${ABC[0:8]}${ABC[0:8]} = ${_MM_MAX_EPX8}(vout${ABC[0:8]}${ABC[0:8]}, voutput_min);
108 vout${ABC[0:8]}${ABC[0:8]} = ${_MM_MIN_EPX8}(vout${ABC[0:8]}${ABC[0:8]}, voutput_max);
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700109
110 $if BATCH_TILE > 8:
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700111 if XNN_LIKELY(n >= (8 * sizeof(${XINT8_T}))) {
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700112 _mm_storel_epi64((__m128i*) output, vout${ABC[0:8]}${ABC[0:8]});
113 output += 8;
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700114 n -= 8 * sizeof(${XINT8_T});
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700115 } else {
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700116 if (n & (4 * sizeof(${XINT8_T}))) {
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700117 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
118 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
119 output += 4;
120 }
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700121 if (n & (2 * sizeof(${XINT8_T}))) {
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700122 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0);
123 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
124 output += 2;
125 }
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700126 if (n & (1 * sizeof(${XINT8_T}))) {
127 *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700128 }
129 n = 0;
130 }
131 $else:
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700132 if (n & (4 * sizeof(${XINT8_T}))) {
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700133 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:8]}${ABC[0:8]});
134 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi64(vout${ABC[0:8]}${ABC[0:8]}, 32);
135 output += 4;
136 }
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700137 if (n & (2 * sizeof(${XINT8_T}))) {
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700138 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:8]}${ABC[0:8]}, 0);
139 vout${ABC[0:8]}${ABC[0:8]} = _mm_srli_epi32(vout${ABC[0:8]}${ABC[0:8]}, 16);
140 output += 2;
141 }
Marat Dukhan3eac69c2021-07-21 01:42:29 -0700142 if (n & (1 * sizeof(${XINT8_T}))) {
143 *output = (${XINT8_T}) _mm_extract_epi8(vout${ABC[0:8]}${ABC[0:8]}, 0);
Marat Dukhane6dc0b62020-09-08 23:57:14 -0700144 }
145 }${" while (n != 0);" if BATCH_TILE > 8 else ""}
146 }
147}