// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

$assert REQUANTIZATION in ["GEMMLOWP", "FP32"]
$assert not CHANNELWISE or REQUANTIZATION == "FP32"
$assert MR <= 4
#include <assert.h>

#include <immintrin.h>

#include <xnnpack/igemm.h>
#include <xnnpack/intrinsics-polyfill.h>
#include <xnnpack/math.h>


$DATATYPE = "qc8" if CHANNELWISE else "qs8"
$PARAMS_STRUCT = "avx2" if CHANNELWISE else REQUANTIZATION.lower() + "_avx2"
$CONV_PARAMS = "xnn_qs8_minmax_params" if CHANNELWISE else "xnn_qs8_conv_minmax_params"
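// IGEMM microkernel template: each generated kernel computes up to MR rows by 8 columns of
// int8 output per call. A-row pointers come from the indirection buffer `a` (entries equal to
// `zero` denote the padding row and are not offset by `a_offset`), products are accumulated in
// int32, and the results are requantized to int8 using either GEMMLOWP or FP32 requantization
// (with per-channel scales when CHANNELWISE is set).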
void xnn_${DATATYPE}_igemm_minmax_${REQUANTIZATION.lower()}_ukernel_${MR}x8c8__avx2(
    size_t mr,
    size_t nc,
    size_t kc,
    size_t ks,
    const int8_t** restrict a,
    const void* restrict w,
    int8_t* restrict c,
    size_t cm_stride,
    size_t cn_stride,
    size_t a_offset,
    const int8_t* zero,
    const union ${CONV_PARAMS} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_DISABLE_TSAN XNN_DISABLE_MSAN
{
  assert(mr != 0);
  assert(mr <= ${MR});
  assert(nc != 0);
  assert(kc != 0);
  assert(ks != 0);
  assert(ks % (${MR} * sizeof(void*)) == 0);
  assert(a_offset % sizeof(int8_t) == 0);
  assert(a != NULL);
  assert(w != NULL);
  assert(c != NULL);

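  // A and the packed weights are consumed 8 int8 elements at a time along K, so round KC up
  // to a multiple of 8 (the packed weights are assumed to be padded to this granularity).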
  kc = round_up_po2(kc, 8);
  int8_t* c0 = c;
  $for M in range(1, MR):
    int8_t* c${M} = (int8_t*) ((uintptr_t) c${M-1} + cm_stride);
    $if M % 2 == 0:
      if XNN_UNPREDICTABLE(mr <= ${M}) {
        c${M} = c${M-1};
      }
    $elif M + 1 == MR:
      if XNN_UNPREDICTABLE(mr != ${M+1}) {
        c${M} = c${M-1};
      }
    $else:
      if XNN_UNPREDICTABLE(mr < ${M+1}) {
        c${M} = c${M-1};
      }

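  // Loop over blocks of 8 output columns. Each iteration seeds the int32 accumulators from the
  // packed bias values, then accumulates over all KS indirection entries and the full KC.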
  do {
    const __m128i vbias0x0 = _mm_loadu_si32(w);
    const __m128i vbias0x1 = _mm_loadu_si32((const void*) ((uintptr_t) w + sizeof(int32_t)));
    __m256i vacc0x01 = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x0), vbias0x1, 1);
    $for N in range(2, 8, 2):
      const __m128i vbias0x${N} = _mm_loadu_si32((const void*) ((uintptr_t) w + ${N} * sizeof(int32_t)));
      const __m128i vbias0x${N+1} = _mm_loadu_si32((const void*) ((uintptr_t) w + ${N+1} * sizeof(int32_t)));
      __m256i vacc0x${N}${N+1} = _mm256_inserti128_si256(_mm256_castsi128_si256(vbias0x${N}), vbias0x${N+1}, 1);
    $for M in range(1, MR):
      $for N in range(0, 8, 2):
        __m256i vacc${M}x${N}${N+1} = vacc0x${N}${N+1};
    w = (const void*) ((uintptr_t) w + 8 * sizeof(int32_t));

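    // Loop over the KS indirection pointers for this output position. Pointers that equal
    // `zero` reference the shared zero/padding row and are not adjusted by a_offset.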
    size_t p = ks;
    do {
      $for M in range(MR):
        const int8_t* restrict a${M} = a[${M}];
        if XNN_UNPREDICTABLE(a${M} != zero) {
          a${M} = (const int8_t*) ((uintptr_t) a${M} + a_offset);
        }
      a += ${MR};

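      // Multiply-accumulate along K: 8 int8 values from each A row against an 8x8 int8 block
      // of B, sign-extended to int16 and accumulated pairwise into int32 via _mm256_madd_epi16.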
      size_t k = 0;
      while (k < kc) {
        $for M in range(MR):
          const __m128i va${M} = _mm_broadcastq_epi64(_mm_loadl_epi64((const __m128i*) a${M}));
          const __m256i vxa${M} = _mm256_cvtepi8_epi16(va${M});
          a${M} += 8;

        $for N in range(0, 8, 2):
          $if N == 0:
            const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) w);
          $else:
            const __m128i vb${N}${N+1} = _mm_load_si128((const __m128i*) ((uintptr_t) w + ${N * 8} * sizeof(int8_t)));
          const __m256i vxb${N}${N+1} = _mm256_cvtepi8_epi16(vb${N}${N+1});

          $for M in range(MR):
            vacc${M}x${N}${N+1} = _mm256_add_epi32(vacc${M}x${N}${N+1}, _mm256_madd_epi16(vxa${M}, vxb${N}${N+1}));

        w = (const void*) ((uintptr_t) w + 64 * sizeof(int8_t));
        k += 8 * sizeof(int8_t);
      }
      p -= ${MR} * sizeof(void*);
    } while (p != 0);

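    // Each accumulator holds partial sums for a pair of columns spread across int32 lanes.
    // Two rounds of horizontal adds collapse them into 8 per-row sums, and a cross-lane
    // permute restores the natural column order 0..7.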
    $for M in range(MR):
      const __m256i vacc${M}x0213 = _mm256_hadd_epi32(vacc${M}x01, vacc${M}x23);
      const __m256i vacc${M}x4657 = _mm256_hadd_epi32(vacc${M}x45, vacc${M}x67);

    $for M in range(MR):
      const __m256i vacc${M}x02461357 = _mm256_hadd_epi32(vacc${M}x0213, vacc${M}x4657);

    const __m256i vpermute_mask = _mm256_set_epi32(7, 3, 6, 2, 5, 1, 4, 0);
    $for M in range(MR):
      __m256i vacc${M}x01234567 = _mm256_permutevar8x32_epi32(vacc${M}x02461357, vpermute_mask);

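    // Requantize the int32 accumulators to the quantized output scale.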
    $if REQUANTIZATION == "GEMMLOWP":
      const __m256i vmultiplier = _mm256_load_si256((const __m256i*) params->gemmlowp_avx2.multiplier);
      const __m256i vrounding = _mm256_load_si256((const __m256i*) params->gemmlowp_avx2.rounding);

      $for M in range(MR):
        const __m256i vacc${M}x11335577 = _mm256_shuffle_epi32(vacc${M}x01234567, _MM_SHUFFLE(3, 3, 1, 1));

      $for M in range(MR):
        const __m256i vprod${M}x0246 = _mm256_add_epi64(_mm256_mul_epi32(vacc${M}x01234567, vmultiplier), vrounding);

      $for M in range(MR):
        const __m256i vprod${M}x1357 = _mm256_add_epi64(_mm256_mul_epi32(vacc${M}x11335577, vmultiplier), vrounding);

      $for M in range(MR):
        const __m256i vq31prod${M}x0246 = _mm256_srli_epi64(vprod${M}x0246, 31);
        const __m256i vq31prod${M}x1357 = _mm256_add_epi64(vprod${M}x1357, vprod${M}x1357);

      $for M in range(MR):
        const __m256i vq31prod${M}x01234567 = _mm256_blend_epi16(vq31prod${M}x0246, vq31prod${M}x1357, 0xCC);

      const __m256i vremainder_mask = _mm256_load_si256((const __m256i*) params->gemmlowp_avx2.remainder_mask);
      $for M in range(MR):
        const __m256i vrem${M}x01234567 =
          _mm256_add_epi32(_mm256_and_si256(vq31prod${M}x01234567, vremainder_mask), _mm256_cmpgt_epi32(_mm256_setzero_si256(), vq31prod${M}x01234567));

      const __m256i vremainder_threshold = _mm256_load_si256((const __m256i*) params->gemmlowp_avx2.remainder_threshold);
      $if M > 1:
        const __m128i vshift = _mm_loadl_epi64((const __m128i*) params->gemmlowp_avx2.shift);
      $else:
        const __m128i vshift = _mm_load_si128((const __m128i*) params->gemmlowp_avx2.shift);
      $for M in range(MR):
        vacc${M}x01234567 =
          _mm256_sub_epi32(_mm256_sra_epi32(vq31prod${M}x01234567, vshift), _mm256_cmpgt_epi32(vrem${M}x01234567, vremainder_threshold));
    $else:
      $for M in range(MR):
        __m256 vscaled${M}x01234567 = _mm256_cvtepi32_ps(vacc${M}x01234567);

      $if CHANNELWISE:
        const __m256 vscale01234567 = _mm256_load_ps(w);
        w = (const void*) ((uintptr_t) w + 8 * sizeof(float));
        $for M in range(MR):
          vscaled${M}x01234567 = _mm256_mul_ps(vscaled${M}x01234567, vscale01234567);
      $else:
        const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
        $for M in range(MR):
          vscaled${M}x01234567 = _mm256_mul_ps(vscaled${M}x01234567, vscale);

      $for M in range(MR):
        vacc${M}x01234567 = _mm256_cvtps_epi32(vscaled${M}x01234567);

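    // Convert to int8: add the output zero point while saturating to int16, permute so that each
    // 128-bit lane holds one full row, pack to int8, and clamp to [output_min, output_max].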
    const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point);
    $for M in range(0, MR, 2):
      __m256i vacc${M}${min(M+1, MR-1)}x01234567 = _mm256_adds_epi16(_mm256_packs_epi32(vacc${M}x01234567, vacc${min(M+1, MR-1)}x01234567), voutput_zero_point);

    $for M in range(0, MR, 2):
      vacc${M}${min(M+1, MR-1)}x01234567 = _mm256_permute4x64_epi64(vacc${M}${min(M+1, MR-1)}x01234567, _MM_SHUFFLE(3, 1, 2, 0));

    $if MR > 2:
      __m256i vout = _mm256_packs_epi16(vacc0${min(1, MR-1)}x01234567, vacc${min(2, MR-1)}${min(3, MR-1)}x01234567);
    $else:
      __m256i vout = _mm256_packs_epi16(vacc0${min(1, MR-1)}x01234567, vacc0${min(1, MR-1)}x01234567);

    vout = _mm256_max_epi8(vout, _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_min));
    vout = _mm256_min_epi8(vout, _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_max));

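    // After the packs, vout_lo holds rows 0 and 2 and vout_hi holds rows 1 and 3 (8 int8 values per row).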
    __m128i vout_lo = _mm256_castsi256_si128(vout);
    __m128i vout_hi = _mm256_extracti128_si256(vout, 1);

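    // Full case: at least 8 output columns remain, so store 8 values per row, advance the
    // output pointers by cn_stride, and rewind the indirection buffer for the next column block.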
    if (nc >= 8) {
      $if MR > 3:
        _mm_storeh_pi((__m64*) c3, _mm_castsi128_ps(vout_hi));
      $if MR > 2:
        _mm_storeh_pi((__m64*) c2, _mm_castsi128_ps(vout_lo));
      $if MR > 1:
        _mm_storel_epi64((__m128i*) c1, vout_hi);
      _mm_storel_epi64((__m128i*) c0, vout_lo);

      $for M in reversed(range(MR)):
        c${M} = (int8_t*) ((uintptr_t) c${M} + cn_stride);

      a = (const int8_t**restrict) ((uintptr_t) a - ks);

      nc -= 8;
    } else {
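      // Remainder: fewer than 8 columns remain. Store 4, 2, and then 1 column as indicated by
      // the bits of nc, shifting the consumed bytes out of vout_lo/vout_hi after each step.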
      if (nc & 4) {
        $if MR > 3:
          *((uint32_t*) c3) = (uint32_t) _mm_extract_epi32(vout_hi, 2);
        $if MR > 2:
          *((uint32_t*) c2) = (uint32_t) _mm_extract_epi32(vout_lo, 2);
        $if MR > 1:
          _mm_storeu_si32(c1, vout_hi);
        _mm_storeu_si32(c0, vout_lo);

        $for M in reversed(range(MR)):
          c${M} += 4;

        vout_lo = _mm_srli_epi64(vout_lo, 32);
        vout_hi = _mm_srli_epi64(vout_hi, 32);
      }
      if (nc & 2) {
        $if MR > 3:
          *((uint16_t*) c3) = (uint16_t) _mm_extract_epi16(vout_hi, 4);
        $if MR > 2:
          *((uint16_t*) c2) = (uint16_t) _mm_extract_epi16(vout_lo, 4);
        $if MR > 1:
          *((uint16_t*) c1) = (uint16_t) _mm_extract_epi16(vout_hi, 0);
        *((uint16_t*) c0) = (uint16_t) _mm_extract_epi16(vout_lo, 0);

        $for M in reversed(range(MR)):
          c${M} += 2;

        vout_lo = _mm_srli_epi32(vout_lo, 16);
        vout_hi = _mm_srli_epi32(vout_hi, 16);
      }
      if (nc & 1) {
        $if MR > 3:
          *c3 = (uint8_t) _mm_extract_epi8(vout_hi, 8);
        $if MR > 2:
          *c2 = (uint8_t) _mm_extract_epi8(vout_lo, 8);
        $if MR > 1:
          *c1 = (uint8_t) _mm_extract_epi8(vout_hi, 0);
        *c0 = (int8_t) _mm_extract_epi8(vout_lo, 0);
      }

      nc = 0;
    }
  } while (nc != 0);
}