blob: 3e129c03d5331d92a602ccda5480a04ebf9468ad [file] [log] [blame]
Marat Dukhan881ab022021-07-28 13:49:26 -07001// Copyright 2021 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6$ABC = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZ"
7$assert REQUANTIZATION == "FP32"
8$assert DATATYPE in ["QC8", "QS8"]
9$assert CHANNEL_TILE % 16 == 0
10$assert CHANNEL_TILE >= 16
11$assert KERNEL_TILE >= 2
12#include <assert.h>
13
14#include <immintrin.h>
15
16#include <xnnpack/dwconv.h>
17
18
19$PARAMS_STRUCT = "avx2" if DATATYPE == "QC8" else REQUANTIZATION.lower() + "_avx2"
20$PARAMS_UNION = "xnn_qs8_minmax_params" if DATATYPE == "QC8" else "xnn_qs8_conv_minmax_params"
Marat Dukhan60bb7ec2021-07-28 18:51:28 -070021void xnn_${DATATYPE.lower()}_dwconv_minmax_${REQUANTIZATION.lower()}_ukernel_up${CHANNEL_TILE}x${KERNEL_TILE}__avx2_mul16${"_add16" if ADD16 else ""}_vpunpck(
Marat Dukhan881ab022021-07-28 13:49:26 -070022 size_t channels,
23 size_t output_width,
24 const int8_t** input,
25 const void* weights,
26 int8_t* output,
27 size_t input_stride,
28 size_t output_increment,
29 size_t input_offset,
30 const int8_t* zero,
Marat Dukhan7be427a2021-12-13 23:38:20 -080031 const union ${PARAMS_UNION} params[restrict XNN_MIN_ELEMENTS(1)]) XNN_OOB_READS
Marat Dukhan881ab022021-07-28 13:49:26 -070032{
33 assert(channels != 0);
34 assert(output_width != 0);
35
36 do {
37 $for K in range(KERNEL_TILE):
38 const int8_t* i${K} = input[${K}];
39 assert(i${K} != NULL);
40 if XNN_UNPREDICTABLE(i${K} != zero) {
41 i${K} = (const int8_t*) ((uintptr_t) i${K} + input_offset);
42 }
43 input = (const int8_t**) ((uintptr_t) input + input_stride);
44
45 size_t c = channels;
46 const void* w = weights;
47 for (; c >= ${CHANNEL_TILE}; c -= ${CHANNEL_TILE}) {
48 __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
49 $for C in range(8, CHANNEL_TILE, 8):
50 __m256i vacc${ABC[C:C+8]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + ${C} * sizeof(int32_t)));
51
52 $for C in range(0, CHANNEL_TILE, 16):
53 __m256i vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_inserti128_si256(vacc${ABC[C:C+8]}, _mm256_castsi256_si128(vacc${ABC[C+8:C+16]}), 1);
54 __m256i vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}, 0x31);
55
56 $for K in range(KERNEL_TILE):
57
58 $for C in range(0, CHANNEL_TILE, 16):
59 $if C == 0:
60 const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
61 $else:
62 const __m256i vi${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (i${K} + ${C})));
63 const __m256i vk${K}x${ABC[C:C+16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE + C} * sizeof(int8_t))));
64 i${K} += ${CHANNEL_TILE};
65
Marat Dukhan60bb7ec2021-07-28 18:51:28 -070066 $if ADD16:
67 $for C in range(0, CHANNEL_TILE, 16):
68 $if K == 0:
69 __m256i vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
70 $elif K % 2 == 0 or K + 1 == KERNEL_TILE:
71 vacc${ABC[C:C+16]} = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
72 $else:
73 vacc${ABC[C:C+16]} = _mm256_add_epi16(vacc${ABC[C:C+16]}, _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]}));
Marat Dukhan881ab022021-07-28 13:49:26 -070074
Marat Dukhan60bb7ec2021-07-28 18:51:28 -070075 $if K % 2 == 1 or K + 1 == KERNEL_TILE:
76 $for C in range(0, CHANNEL_TILE, 16):
77 $if K == 1:
78 __m256i vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15);
79 $else:
80 vsignacc${ABC[C:C+16]} = _mm256_srai_epi16(vacc${ABC[C:C+16]}, 15);
81 vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]}));
82 vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vacc${ABC[C:C+16]}, vsignacc${ABC[C:C+16]}));
83 $else:
84 $for C in range(0, CHANNEL_TILE, 16):
85 const __m256i vprod${K}x${ABC[C:C+16]}lo = _mm256_mullo_epi16(vi${K}x${ABC[C:C+16]}, vk${K}x${ABC[C:C+16]});
86 const __m256i vprod${K}x${ABC[C:C+16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[C:C+16]}lo, 15);
87
88 $for C in range(0, CHANNEL_TILE, 16):
89 vacc${ABC[C:C+4]}${ABC[C+8:C+12]} = _mm256_add_epi32(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi));
90 vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_add_epi32(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[C:C+16]}lo, vprod${K}x${ABC[C:C+16]}hi));
Marat Dukhan881ab022021-07-28 13:49:26 -070091
92 w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${KERNEL_TILE * CHANNEL_TILE} * sizeof(int8_t));
93
94 $for C in range(0, CHANNEL_TILE, 16):
95 vacc${ABC[C:C+8]} = _mm256_inserti128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, _mm256_castsi256_si128(vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}), 1);
96 vacc${ABC[C+8:C+16]} = _mm256_permute2x128_si256(vacc${ABC[C:C+4]}${ABC[C+8:C+12]}, vacc${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 0x31);
97
98 $for C in range(0, CHANNEL_TILE, 8):
99 __m256 vfpacc${ABC[C:C+8]} = _mm256_cvtepi32_ps(vacc${ABC[C:C+8]});
100
101 $if DATATYPE == "QC8":
102 const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) w);
103 $for C in range(8, CHANNEL_TILE, 8):
104 const __m256 vscale${ABC[C:C+8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${C} * sizeof(float)));
105 w = (const void*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(float));
106 $for C in range(0, CHANNEL_TILE, 8):
107 vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale${ABC[C:C+8]});
108 $else:
109 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
110 $for C in range(0, CHANNEL_TILE, 8):
111 vfpacc${ABC[C:C+8]} = _mm256_mul_ps(vfpacc${ABC[C:C+8]}, vscale);
112
Marat Dukhan13c9f8d2021-12-06 02:21:03 -0800113 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
114 $for C in range(0, CHANNEL_TILE, 8):
115 vfpacc${ABC[C:C+8]} = _mm256_min_ps(vfpacc${ABC[C:C+8]}, voutput_max_less_zero_point);
116
Marat Dukhan881ab022021-07-28 13:49:26 -0700117 $for C in range(0, CHANNEL_TILE, 8):
118 vacc${ABC[C:C+8]} = _mm256_cvtps_epi32(vfpacc${ABC[C:C+8]});
119
120 const __m256i voutput_zero_point = _mm256_load_si256((const __m256i*) params->${PARAMS_STRUCT}.output_zero_point);
121 $for C in range(0, CHANNEL_TILE, 16):
122 const __m256i vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]} = _mm256_adds_epi16(_mm256_packs_epi32(vacc${ABC[C:C+8]}, vacc${ABC[C+8:C+16]}), voutput_zero_point);
123
124 $for C in range(0, CHANNEL_TILE, 16):
125 __m128i vout${ABC[C:C+16]} = _mm_shuffle_epi32(_mm_packs_epi16(_mm256_castsi256_si128(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}), _mm256_extracti128_si256(vout${ABC[C:C+4]}${ABC[C+8:C+12]}${ABC[C+4:C+8]}${ABC[C+12:C+16]}, 1)), _MM_SHUFFLE(3, 1, 2, 0));
126
127 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
Marat Dukhan881ab022021-07-28 13:49:26 -0700128 $for C in range(0, CHANNEL_TILE, 16):
129 vout${ABC[C:C+16]} = _mm_max_epi8(vout${ABC[C:C+16]}, voutput_min);
Marat Dukhan881ab022021-07-28 13:49:26 -0700130
131 _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
132 $for C in range(16, CHANNEL_TILE, 16):
133 _mm_storeu_si128((__m128i*) (output + ${C}), vout${ABC[C:C+16]});
134 output += ${CHANNEL_TILE};
135 }
136 if XNN_UNLIKELY(c != 0) {
137 $if CHANNEL_TILE > 16:
138 const int8_t* k = (const int8_t*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t));
139 ${"do " if CHANNEL_TILE > 16 else ""}{
140 __m256i vacc${ABC[0:8]} = _mm256_loadu_si256((const __m256i*) w);
141 __m256i vacc${ABC[8:16]} = _mm256_loadu_si256((const __m256i*) ((uintptr_t) w + 8 * sizeof(int32_t)));
142
143 __m256i vacc${ABC[0:4]}${ABC[8:12]} = _mm256_inserti128_si256(vacc${ABC[0:8]}, _mm256_castsi256_si128(vacc${ABC[8:16]}), 1);
144 __m256i vacc${ABC[4:8]}${ABC[12:16]} = _mm256_permute2x128_si256(vacc${ABC[0:8]}, vacc${ABC[8:16]}, 0x31);
145
146 $for K in range(KERNEL_TILE):
147
148 const __m256i vi${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) i${K}));
149 $if CHANNEL_TILE > 16:
150 $if K == 0:
151 const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) k));
152 $else:
153 const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) (k + ${K * CHANNEL_TILE})));
154 $else:
155 const __m256i vk${K}x${ABC[0:16]} = _mm256_cvtepi8_epi16(_mm_loadu_si128((const __m128i*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${K * CHANNEL_TILE} * sizeof(int8_t))));
156 $if CHANNEL_TILE > 16:
157 i${K} += 16;
158
159 const __m256i vprod${K}x${ABC[0:16]}lo = _mm256_mullo_epi16(vi${K}x${ABC[0:16]}, vk${K}x${ABC[0:16]});
160 const __m256i vprod${K}x${ABC[0:16]}hi = _mm256_srai_epi16(vprod${K}x${ABC[0:16]}lo, 15);
161
162 vacc${ABC[0:4]}${ABC[8:12]} = _mm256_add_epi32(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_unpacklo_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi));
163 vacc${ABC[4:8]}${ABC[12:16]} = _mm256_add_epi32(vacc${ABC[4:8]}${ABC[12:16]}, _mm256_unpackhi_epi16(vprod${K}x${ABC[0:16]}lo, vprod${K}x${ABC[0:16]}hi));
164
165 vacc${ABC[0:8]} = _mm256_inserti128_si256(vacc${ABC[0:4]}${ABC[8:12]}, _mm256_castsi256_si128(vacc${ABC[4:8]}${ABC[12:16]}), 1);
166 vacc${ABC[8:16]} = _mm256_permute2x128_si256(vacc${ABC[0:4]}${ABC[8:12]}, vacc${ABC[4:8]}${ABC[12:16]}, 0x31);
167
168 $if CHANNEL_TILE > 16:
169 k += 16;
170
171 __m256 vfpacc${ABC[0:8]} = _mm256_cvtepi32_ps(vacc${ABC[0:8]});
172 __m256 vfpacc${ABC[8:16]} = _mm256_cvtepi32_ps(vacc${ABC[8:16]});
173
174 $if DATATYPE == "QC8":
175 const __m256 vscale${ABC[0:8]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t)));
176 const __m256 vscale${ABC[8:16]} = _mm256_loadu_ps((const float*) ((uintptr_t) w + ${CHANNEL_TILE} * sizeof(int32_t) + ${CHANNEL_TILE * KERNEL_TILE} * sizeof(int8_t) + 8 * sizeof(float)));
177 vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale${ABC[0:8]});
178 vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale${ABC[8:16]});
179 $else:
180 const __m256 vscale = _mm256_load_ps(params->fp32_avx2.scale);
181 vfpacc${ABC[0:8]} = _mm256_mul_ps(vfpacc${ABC[0:8]}, vscale);
182 vfpacc${ABC[8:16]} = _mm256_mul_ps(vfpacc${ABC[8:16]}, vscale);
183
Marat Dukhan13c9f8d2021-12-06 02:21:03 -0800184 const __m256 voutput_max_less_zero_point = _mm256_load_ps(params->${PARAMS_STRUCT}.output_max_less_zero_point);
185 vfpacc${ABC[0:8]} = _mm256_min_ps(vfpacc${ABC[0:8]}, voutput_max_less_zero_point);
186 vfpacc${ABC[8:16]} = _mm256_min_ps(vfpacc${ABC[8:16]}, voutput_max_less_zero_point);
187
Marat Dukhan881ab022021-07-28 13:49:26 -0700188 vacc${ABC[0:8]} = _mm256_cvtps_epi32(vfpacc${ABC[0:8]});
189 vacc${ABC[8:16]} = _mm256_cvtps_epi32(vfpacc${ABC[8:16]});
190
191 $if CHANNEL_TILE > 16:
192 w = (const void*) ((uintptr_t) w + 16 * sizeof(int32_t));
193
194 const __m128i voutput_zero_point = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_zero_point);
195 __m128i vout${ABC[0:8]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[0:8]}), _mm256_extracti128_si256(vacc${ABC[0:8]}, 1)), voutput_zero_point);
196 __m128i vout${ABC[8:16]} = _mm_adds_epi16(_mm_packs_epi32(_mm256_castsi256_si128(vacc${ABC[8:16]}), _mm256_extracti128_si256(vacc${ABC[8:16]}, 1)), voutput_zero_point);
197
198 const __m128i voutput_min = _mm_load_si128((const __m128i*) params->${PARAMS_STRUCT}.output_min);
Marat Dukhan881ab022021-07-28 13:49:26 -0700199
200 __m128i vout${ABC[0:16]} = _mm_packs_epi16(vout${ABC[0:8]}, vout${ABC[8:16]});
Marat Dukhan13c9f8d2021-12-06 02:21:03 -0800201 vout${ABC[0:16]} = _mm_max_epi8(vout${ABC[0:16]}, voutput_min);
Marat Dukhan881ab022021-07-28 13:49:26 -0700202
203 $if CHANNEL_TILE > 16:
204 if XNN_LIKELY(c >= 16) {
205 _mm_storeu_si128((__m128i*) output, vout${ABC[0:16]});
206 output += 16;
207 c -= 16;
208 } else {
209 if (c & 8) {
210 _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
211 vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
212 output += 8;
213 }
214 if (c & 4) {
215 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]});
216 vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
217 output += 4;
218 }
219 if (c & 2) {
220 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0);
221 vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
222 output += 2;
223 }
224 if (c & 1) {
225 *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
226 output += 1;
227 }
228 c = 0;
229 }
230 $else:
231 if (c & 8) {
232 _mm_storel_epi64((__m128i*) output, vout${ABC[0:16]});
233 vout${ABC[0:16]} = _mm_unpackhi_epi64(vout${ABC[0:16]}, vout${ABC[0:16]});
234 output += 8;
235 }
236 if (c & 4) {
237 *((uint32_t*) output) = (uint32_t) _mm_cvtsi128_si32(vout${ABC[0:16]});
238 vout${ABC[0:16]} = _mm_srli_epi64(vout${ABC[0:16]}, 32);
239 output += 4;
240 }
241 if (c & 2) {
242 *((uint16_t*) output) = (uint16_t) _mm_extract_epi16(vout${ABC[0:16]}, 0);
243 vout${ABC[0:16]} = _mm_srli_epi32(vout${ABC[0:16]}, 16);
244 output += 2;
245 }
246 if (c & 1) {
247 *output = (int8_t) _mm_extract_epi8(vout${ABC[0:16]}, 0);
248 output += 1;
249 }
250 }${" while (c != 0);" if CHANNEL_TILE > 16 else ""}
251 }
252
253 output = (int8_t*) ((uintptr_t) output + output_increment);
254 } while (--output_width != 0);
255}