blob: ca59dc74cd1907e10e7901070c93b22e70eede83 [file] [log] [blame]
Alex Stark6180f1f2020-01-10 13:24:08 -05001/* Copyright 2019 Google LLC. All Rights Reserved.
2
3Licensed under the Apache License, Version 2.0 (the "License");
4you may not use this file except in compliance with the License.
5You may obtain a copy of the License at
6
7 http://www.apache.org/licenses/LICENSE-2.0
8
9Unless required by applicable law or agreed to in writing, software
10distributed under the License is distributed on an "AS IS" BASIS,
11WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12See the License for the specific language governing permissions and
13limitations under the License.
14==============================================================================*/
15
16#include <cstdint>
17#include <cstring>
18
Alex Stark6180f1f2020-01-10 13:24:08 -050019#include "check_macros.h"
20#include "matrix.h"
21#include "opt_set.h"
22#include "pack.h"
23#include "path.h"
24#include "platform.h"
Benoit Jacobd1a14aa2020-01-14 13:28:47 -050025#include "profiler/instrumentation.h"
Alex Stark6180f1f2020-01-10 13:24:08 -050026
27#if RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_INTRINSICS)
28#include <immintrin.h> // IWYU pragma: keep
29#endif
30
31namespace ruy {
32
33#if !(RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM))
34
35void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor,
36 const std::int8_t* zerobuf, int src_stride,
37 int remaining_src_cols, int src_rows,
38 std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
39 // CPU-ID-based checks should disable the path that would reach this point.
40 RUY_DCHECK(false);
41}
42
43void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride,
44 int remaining_src_cols, int src_rows, float* packed_ptr) {
45 // CPU-ID-based checks should disable the path that would reach this point.
46 RUY_DCHECK(false);
47}
48
49#else // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
50
51// The first int8_t template parameter is arbitrary: this routine is common to
52// all 8-bit source matrix types.
53using PackImpl8bitSse42 =
54 PackImpl<Path::kSse42, FixedKernelLayout<Order::kColMajor, 4, 8>,
55 std::int8_t, std::int8_t, std::int32_t>;
56
57using PackImplFloatSse42 =
58 PackImpl<Path::kSse42, FixedKernelLayout<Order::kRowMajor, 1, 8>, float,
59 float, float>;
60
61namespace {
62
63inline void Pack8bitSse42Packer(const std::int8_t* src_ptr,
64 std::int8_t input_xor,
65 const std::int8_t* zerobuf, int src_stride,
66 int remaining_src_cols, int src_rows,
67 std::int8_t* packed_ptr, std::int32_t* sums_ptr,
68 std::int8_t* trailing_buf) {
69 using Layout = PackImpl8bitSse42::Layout;
70 RUY_DCHECK_EQ(Layout::kCols, 8);
71 RUY_DCHECK_EQ(Layout::kRows, 4);
72 // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
73 // We process 8 of these chunks at a time, padding short input chunks.
74 constexpr int kNumRowChunks = 8;
75 constexpr int kNumChunkedSrcRows = kNumRowChunks * Layout::kRows;
76
77 std::int8_t in_data[Layout::kCols][kNumRowChunks][Layout::kRows];
78
79 const std::int8_t* src_ptr0 = src_ptr;
80 const std::int8_t* src_ptr1 = src_ptr0 + src_stride;
81 const std::int8_t* src_ptr2 = src_ptr1 + src_stride;
82 const std::int8_t* src_ptr3 = src_ptr2 + src_stride;
83 const std::int8_t* src_ptr4 = src_ptr3 + src_stride;
84 const std::int8_t* src_ptr5 = src_ptr4 + src_stride;
85 const std::int8_t* src_ptr6 = src_ptr5 + src_stride;
86 const std::int8_t* src_ptr7 = src_ptr6 + src_stride;
87 std::int64_t src_inc0 = kNumChunkedSrcRows;
88 std::int64_t src_inc1 = kNumChunkedSrcRows;
89 std::int64_t src_inc2 = kNumChunkedSrcRows;
90 std::int64_t src_inc3 = kNumChunkedSrcRows;
91 std::int64_t src_inc4 = kNumChunkedSrcRows;
92 std::int64_t src_inc5 = kNumChunkedSrcRows;
93 std::int64_t src_inc6 = kNumChunkedSrcRows;
94 std::int64_t src_inc7 = kNumChunkedSrcRows;
95 // Handle cases where source does not have Layout::kCols (8) columns.
96 if (remaining_src_cols < 8) {
97 if (remaining_src_cols <= 0) {
98 src_ptr0 = zerobuf;
99 src_inc0 = 0;
100 }
101 if (remaining_src_cols <= 1) {
102 src_ptr1 = zerobuf;
103 src_inc1 = 0;
104 }
105 if (remaining_src_cols <= 2) {
106 src_ptr2 = zerobuf;
107 src_inc2 = 0;
108 }
109 if (remaining_src_cols <= 3) {
110 src_ptr3 = zerobuf;
111 src_inc3 = 0;
112 }
113 if (remaining_src_cols <= 4) {
114 src_ptr4 = zerobuf;
115 src_inc4 = 0;
116 }
117 if (remaining_src_cols <= 5) {
118 src_ptr5 = zerobuf;
119 src_inc5 = 0;
120 }
121 if (remaining_src_cols <= 6) {
122 src_ptr6 = zerobuf;
123 src_inc6 = 0;
124 }
125 src_ptr7 = zerobuf;
126 src_inc7 = 0;
127 }
128
129 const std::int8_t zero_point = zerobuf[0];
130
131 if (sums_ptr) {
132 // i: Layout::kCols.
133 for (int i = 0; i < 8; ++i) {
134 sums_ptr[i] = 0;
135 }
136 }
137
138 // The overall packing effectively pads the source rows to
139 // (src_rows + 63) & ~63. The iteration over k may skip when m=1, and then we
140 // only pack for (src_rows + 31) & ~31. When there is an incomplete
141 // destination block, this is stored into trailing_buf instead of packed_ptr.
142 for (int k = 0; k < src_rows; k += kNumChunkedSrcRows) {
143 // Available source rows.
144 // If this is less than 0 (for m=1), we skip, having filled trailing
145 // buffer for m=0. Also, if source rows is zero on m=1, then we filled
146 // exactly to the end of the column in the packed buffer.
147 const int available_src_rows = src_rows - k;
148 // Effectively,
149 // available rows = std::max(0, std::min(8, src_rows - k));
150 // treat each case separately.
151 if (available_src_rows >= kNumChunkedSrcRows) {
152 // i: chunks, s: Layout::Rows.
153 for (int i = 0; i < 8; ++i) {
154 for (int s = 0; s < 4; ++s) {
155 in_data[0][i][s] = src_ptr0[i * 4 + s];
156 in_data[1][i][s] = src_ptr1[i * 4 + s];
157 in_data[2][i][s] = src_ptr2[i * 4 + s];
158 in_data[3][i][s] = src_ptr3[i * 4 + s];
159 in_data[4][i][s] = src_ptr4[i * 4 + s];
160 in_data[5][i][s] = src_ptr5[i * 4 + s];
161 in_data[6][i][s] = src_ptr6[i * 4 + s];
162 in_data[7][i][s] = src_ptr7[i * 4 + s];
163 }
164 }
165 // i: chunks, j: Layout::kCols, s: Layout::Rows.
166 for (int i = 0; i < 8; ++i) {
167 for (int j = 0; j < 8; ++j) {
168 for (int s = 0; s < 4; ++s) {
169 // 8 * 4 * i is offset for each block, that is
170 // (Layout::kCols * Layout::kRows * i)
171 packed_ptr[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
172 }
173 if (sums_ptr) {
174 for (int s = 0; s < 4; ++s) {
175 sums_ptr[j] += in_data[j][i][s] ^ input_xor;
176 }
177 }
178 }
179 }
180 } else if (available_src_rows > 0) {
181 RUY_DCHECK_LT(available_src_rows, kNumChunkedSrcRows);
182 int i = 0;
183 // Consume chunks of 4 rows that are complete.
184 for (; i < (available_src_rows >> 2); ++i) {
185 for (int s = 0; s < 4; ++s) {
186 in_data[0][i][s] = src_ptr0[i * 4 + s];
187 in_data[1][i][s] = src_ptr1[i * 4 + s];
188 in_data[2][i][s] = src_ptr2[i * 4 + s];
189 in_data[3][i][s] = src_ptr3[i * 4 + s];
190 in_data[4][i][s] = src_ptr4[i * 4 + s];
191 in_data[5][i][s] = src_ptr5[i * 4 + s];
192 in_data[6][i][s] = src_ptr6[i * 4 + s];
193 in_data[7][i][s] = src_ptr7[i * 4 + s];
194 }
195 }
196 // Consume any incomplete chunk.
197 if (i < ((available_src_rows + 3) >> 2)) {
198 int s = 0;
199 for (; s < (available_src_rows & 3); ++s) {
200 in_data[0][i][s] = src_ptr0[i * 4 + s];
201 in_data[1][i][s] = src_ptr1[i * 4 + s];
202 in_data[2][i][s] = src_ptr2[i * 4 + s];
203 in_data[3][i][s] = src_ptr3[i * 4 + s];
204 in_data[4][i][s] = src_ptr4[i * 4 + s];
205 in_data[5][i][s] = src_ptr5[i * 4 + s];
206 in_data[6][i][s] = src_ptr6[i * 4 + s];
207 in_data[7][i][s] = src_ptr7[i * 4 + s];
208 }
209 RUY_DCHECK_LE(s, 4);
210 for (; s < 4; ++s) {
211 // j: Layout::kCols.
212 for (int j = 0; j < 8; ++j) {
213 in_data[j][i][s] = zero_point;
214 }
215 }
216 ++i;
217 }
218 // We do not care what goes into the trailing buffer, but we want
219 // in_data[...] ^ input_xor == 0 for irrelevant values in the summation.
220 //
221 // It might prove better in optimized code to pad uniformly with
222 // zero_point, and compensate by initializing the summations with the
223 // compensating offset, effectively
224 // ((input_xor - zero_point) ^ input_xor) *
225 // 4 * (8 - ((available_src_rows + 3) >> 2)).
226 for (; i < 8; ++i) {
227 for (int s = 0; s < 4; ++s) {
228 for (int j = 0; j < 8; ++j) {
229 in_data[j][i][s] = input_xor;
230 }
231 }
232 }
233 // We loop through [0, 8) rather than
234 // [0, (available_src_rows + 3) >> 2), since that emulates what we might
235 // do in fully-optimized code.
236 //
237 // i: chunks, j: Layout::kCols, s: Layout::Rows.
238 if (sums_ptr) {
239 for (int i = 0; i < 8; ++i) {
240 for (int j = 0; j < 8; ++j) {
241 for (int s = 0; s < 4; ++s) {
242 trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
243 sums_ptr[j] = sums_ptr[j] + (in_data[j][i][s] ^ input_xor);
244 }
245 }
246 }
247 } else {
248 for (int i = 0; i < 8; ++i) {
249 for (int j = 0; j < 8; ++j) {
250 for (int s = 0; s < 4; ++s) {
251 trailing_buf[(8 * i + j) * 4 + s] = in_data[j][i][s] ^ input_xor;
252 }
253 }
254 }
255 }
256 }
257
258 packed_ptr += 8 * kNumChunkedSrcRows;
259 src_ptr0 += src_inc0;
260 src_ptr1 += src_inc1;
261 src_ptr2 += src_inc2;
262 src_ptr3 += src_inc3;
263 src_ptr4 += src_inc4;
264 src_ptr5 += src_inc5;
265 src_ptr6 += src_inc6;
266 src_ptr7 += src_inc7;
267 }
268}
269
270inline void PackFloatSse42Packer(const float* src_ptr, const float* zerobuf,
271 int src_stride, int remaining_src_cols,
272 int src_rows, float* packed_ptr,
273 float* trailing_buf) {
274 using Layout = PackImplFloatSse42::Layout;
275 RUY_DCHECK_EQ(Layout::kCols, 8);
276 RUY_DCHECK_EQ(Layout::kRows, 1);
277
278 // This packing amounts to tranposition of 8x8 blocks.
279 static constexpr int kPackCols = 8; // Source cols packed together.
280 static constexpr int kPackRows = 8; // Short input is padded.
281
282 float in_data[kPackCols][kPackRows];
283
284 const float* src_ptr0 = src_ptr;
285 const float* src_ptr1 = src_ptr0 + src_stride;
286 const float* src_ptr2 = src_ptr1 + src_stride;
287 const float* src_ptr3 = src_ptr2 + src_stride;
288 const float* src_ptr4 = src_ptr3 + src_stride;
289 const float* src_ptr5 = src_ptr4 + src_stride;
290 const float* src_ptr6 = src_ptr5 + src_stride;
291 const float* src_ptr7 = src_ptr6 + src_stride;
292 std::int64_t src_inc0 = 8;
293 std::int64_t src_inc1 = 8;
294 std::int64_t src_inc2 = 8;
295 std::int64_t src_inc3 = 8;
296 std::int64_t src_inc4 = 8;
297 std::int64_t src_inc5 = 8;
298 std::int64_t src_inc6 = 8;
299 std::int64_t src_inc7 = 8;
300 // Handle cases where source does not have kPackDim (8) columns.
301 if (remaining_src_cols < kPackCols) {
302 if (remaining_src_cols <= 0) {
303 src_ptr0 = zerobuf;
304 src_inc0 = 0;
305 }
306 if (remaining_src_cols <= 1) {
307 src_ptr1 = zerobuf;
308 src_inc1 = 0;
309 }
310 if (remaining_src_cols <= 2) {
311 src_ptr2 = zerobuf;
312 src_inc2 = 0;
313 }
314 if (remaining_src_cols <= 3) {
315 src_ptr3 = zerobuf;
316 src_inc3 = 0;
317 }
318 if (remaining_src_cols <= 4) {
319 src_ptr4 = zerobuf;
320 src_inc4 = 0;
321 }
322 if (remaining_src_cols <= 5) {
323 src_ptr5 = zerobuf;
324 src_inc5 = 0;
325 }
326 if (remaining_src_cols <= 6) {
327 src_ptr6 = zerobuf;
328 src_inc6 = 0;
329 }
330 src_ptr7 = zerobuf;
331 src_inc7 = 0;
332 }
333
334 for (int k = 0; k < src_rows; k += kPackRows) {
335 const int available_src_rows = src_rows - k;
336 // Effectively,
337 // available_src_rows = std::max(0, std::min(kPackDim, src_rows - k));
338 // but treat each case separately.
339 if (available_src_rows >= kPackRows) {
340 for (int i = 0; i < 8; ++i) {
341 in_data[0][i] = src_ptr0[i];
342 in_data[1][i] = src_ptr1[i];
343 in_data[2][i] = src_ptr2[i];
344 in_data[3][i] = src_ptr3[i];
345 in_data[4][i] = src_ptr4[i];
346 in_data[5][i] = src_ptr5[i];
347 in_data[6][i] = src_ptr6[i];
348 in_data[7][i] = src_ptr7[i];
349 }
350 for (int i = 0; i < 8; ++i) {
351 for (int j = 0; j < 8; ++j) {
352 packed_ptr[8 * i + j] = in_data[j][i];
353 }
354 }
355 } else if (available_src_rows > 0) {
356 for (int i = 0; i < available_src_rows; ++i) {
357 in_data[0][i] = src_ptr0[i];
358 in_data[1][i] = src_ptr1[i];
359 in_data[2][i] = src_ptr2[i];
360 in_data[3][i] = src_ptr3[i];
361 in_data[4][i] = src_ptr4[i];
362 in_data[5][i] = src_ptr5[i];
363 in_data[6][i] = src_ptr6[i];
364 in_data[7][i] = src_ptr7[i];
365 }
366 for (int i = available_src_rows; i < kPackRows; ++i) {
367 in_data[0][i] = 0.0f;
368 in_data[1][i] = 0.0f;
369 in_data[2][i] = 0.0f;
370 in_data[3][i] = 0.0f;
371 in_data[4][i] = 0.0f;
372 in_data[5][i] = 0.0f;
373 in_data[6][i] = 0.0f;
374 in_data[7][i] = 0.0f;
375 }
376 // We loop through [0, 7) rather than [0, packed_rows), since that
377 // emulates what we might do in fully-optimized code.
378 // i: (kPackRows - 1), j: kPackCols.
379 for (int i = 0; i < 7; ++i) {
380 for (int j = 0; j < 8; ++j) {
381 trailing_buf[kPackRows * i + j] = in_data[j][i];
382 }
383 }
384 }
385
386 packed_ptr += kPackRows * kPackCols;
387 src_ptr0 += src_inc0;
388 src_ptr1 += src_inc1;
389 src_ptr2 += src_inc2;
390 src_ptr3 += src_inc3;
391 src_ptr4 += src_inc4;
392 src_ptr5 += src_inc5;
393 src_ptr6 += src_inc6;
394 src_ptr7 += src_inc7;
395 }
396}
397
398} // namespace.
399
400// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
401// Optimization is not finished. In particular the dimensions of the kernel
402// blocks can be changed as desired.
403//
404// When removing this comment, update profiling label below.
405void Pack8bitSse42(const std::int8_t* src_ptr, std::int8_t input_xor,
406 const std::int8_t* zerobuf, int src_stride,
407 int remaining_src_cols, int src_rows,
408 std::int8_t* packed_ptr, std::int32_t* sums_ptr) {
Benoit Jacobd1a14aa2020-01-14 13:28:47 -0500409 profiler::ScopeLabel label("Pack kSse42 8bit (UNFINISHED)");
Alex Stark6180f1f2020-01-10 13:24:08 -0500410
411 using Layout = PackImpl8bitSse42::Layout;
412 RUY_DCHECK_EQ(Layout::kCols, 8);
413 RUY_DCHECK_EQ(Layout::kRows, 4);
414
415 // Each Layout::Rows is 4 contiguous input, contiguous packed elements.
416 // We process 8 of these chunks at a time, padding short input chunks.
417 static constexpr int kNumRowChunks = 8; // Short input is padded.
418
419 // Each packed block is 4*8, and there are normally 8. The trailing block is
420 // only slightly shorter.
421 constexpr int kTrailingBufSize =
422 kNumRowChunks * Layout::kCols * Layout::kRows;
423 std::int8_t trailing_buf[kTrailingBufSize];
424 memset(trailing_buf, 0, kTrailingBufSize * sizeof(std::int8_t));
425
426 Pack8bitSse42Packer(src_ptr, input_xor, zerobuf, src_stride,
427 remaining_src_cols, src_rows, packed_ptr, sums_ptr,
428 trailing_buf);
429
430 constexpr int kChunkedRowMask = kNumRowChunks * Layout::kRows - 1;
431 const bool trailing_data = (src_rows & kChunkedRowMask) > 0;
432 // If the number of source rows is not a multiple of kChunkedRowMask, there
433 // will be data in the trailing buffer,
434 if (trailing_data > 0) {
435 const int non_trailing_rows = src_rows & ~kChunkedRowMask;
436 // Destination "rows" are padded to next highest multiple of Layout::kRows.
437 const int dst_rows = (src_rows + 3) & ~3;
438 const int trailing_rows = dst_rows - non_trailing_rows;
439 memcpy(packed_ptr + Layout::kCols * non_trailing_rows, trailing_buf,
440 Layout::kCols * trailing_rows * sizeof(std::int8_t));
441 }
442}
443
444// TODO(b/147376783): SSE 4.2 and AVX-VNNI support is incomplete / placeholder.
445// Optimization is not finished. In particular the dimensions of the kernel
446// blocks can be changed as desired.
447//
448// When removing this comment, update profiling label below.
449void PackFloatSse42(const float* src_ptr, const float* zerobuf, int src_stride,
450 int remaining_src_cols, int src_rows, float* packed_ptr) {
Benoit Jacobd1a14aa2020-01-14 13:28:47 -0500451 profiler::ScopeLabel label("Pack kSse42 float (UNFINISHED)");
Alex Stark6180f1f2020-01-10 13:24:08 -0500452 static constexpr int kPackCols = 8; // Source cols packed together.
453 static constexpr int kPackRows = 8; // Short input is padded.
454 float trailing_buf[(kPackRows - 1) * kPackCols];
455 if (remaining_src_cols < 8) {
456 memset(trailing_buf, 0, sizeof(trailing_buf));
457 }
458 PackFloatSse42Packer(src_ptr, zerobuf, src_stride, remaining_src_cols,
459 src_rows, packed_ptr, trailing_buf);
460
461 const int trailing_rows = src_rows & (kPackRows - 1);
462 if (trailing_rows > 0) {
463 const int non_trailing_rows = src_rows & ~(kPackRows - 1);
464 memcpy(packed_ptr + kPackCols * non_trailing_rows, trailing_buf,
465 kPackCols * trailing_rows * sizeof(float));
466 }
467}
468
#endif  // RUY_PLATFORM(SSE42) && RUY_OPT_ENABLED(RUY_OPT_ASM)
470
471} // namespace ruy