// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/AlignedAllocator.h>
#include <xnnpack/params.h>
#include <xnnpack/pack.h>
#include <xnnpack/requantization.h>

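// Tester for XNNPACK GEMM/IGEMM/PPMM micro-kernels. Dimensions and strides are
// configured through the fluent setters below; each Test() overload packs the
// weights, runs the micro-kernel under test, and checks its output against a
// scalar reference computation.
//
// Illustrative usage (the kernel symbol below is a hypothetical example, not a
// specific micro-kernel name from the library):
//
//   GemmMicrokernelTester()
//     .mr(4).nr(8).kr(1)
//     .m(4).n(8).k(16)
//     .Test(xnn_f32_gemm_ukernel_4x8__example);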
class GemmMicrokernelTester {
 public:
  enum class Variant {
    Native,
    Scalar,
  };

  inline GemmMicrokernelTester& mr(size_t mr) {
    this->mr_ = mr;
    return *this;
  }

  inline size_t mr() const {
    return this->mr_;
  }

  inline GemmMicrokernelTester& nr(size_t nr) {
    this->nr_ = nr;
    return *this;
  }

  inline size_t nr() const {
    return this->nr_;
  }

  inline GemmMicrokernelTester& kr(size_t kr) {
    this->kr_ = kr;
    return *this;
  }

  inline size_t kr() const {
    return this->kr_;
  }

  inline GemmMicrokernelTester& sr(size_t sr) {
    this->sr_ = sr;
    return *this;
  }

  inline size_t sr() const {
    return this->sr_;
  }

  inline GemmMicrokernelTester& m(size_t m) {
    this->m_ = m;
    return *this;
  }

  inline size_t m() const {
    return this->m_;
  }

  inline GemmMicrokernelTester& n(size_t n) {
    this->n_ = n;
    return *this;
  }

  inline size_t n() const {
    return this->n_;
  }

  inline GemmMicrokernelTester& k(size_t k) {
    this->k_ = k;
    return *this;
  }

  inline size_t k() const {
    return this->k_;
  }

  inline GemmMicrokernelTester& ks(size_t ks) {
    this->ks_ = ks;
    return *this;
  }

  inline size_t ks() const {
    return this->ks_;
  }

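  // packed_k(), packed_n(), and bias_n() round k() and n() up to multiples of
  // kr() and nr(), matching the layout produced by the xnn_pack_*_w routines.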
  inline size_t packed_k() const {
    return k() % kr() == 0 ? k() : (k() / kr() + 1) * kr();
  }

  inline size_t packed_n() const {
    return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
  }

  inline size_t bias_n() const {
    return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
  }

  inline GemmMicrokernelTester& a_stride(size_t a_stride) {
    this->a_stride_ = a_stride;
    return *this;
  }

  inline size_t a_stride() const {
    return this->a_stride_ == 0 ? k() : this->a_stride_;
  }

  inline GemmMicrokernelTester& cm_stride(size_t cm_stride) {
    this->cm_stride_ = cm_stride;
    return *this;
  }

  inline size_t cm_stride() const {
    return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_;
  }

  inline GemmMicrokernelTester& cn_stride(size_t cn_stride) {
    this->cn_stride_ = cn_stride;
    return *this;
  }

  inline size_t cn_stride() const {
    return this->cn_stride_ == 0 ? nr() : this->cn_stride_;
  }

  inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) {
    this->a_zero_point_ = a_zero_point;
    return *this;
  }

  inline uint8_t a_zero_point() const {
    return this->a_zero_point_;
  }

  inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) {
    this->b_zero_point_ = b_zero_point;
    return *this;
  }

  inline uint8_t b_zero_point() const {
    return this->b_zero_point_;
  }

  inline GemmMicrokernelTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline GemmMicrokernelTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline GemmMicrokernelTester& a_offset(size_t a_offset) {
    this->a_offset_ = a_offset;
    return *this;
  }

  inline size_t a_offset() const {
    return this->a_offset_;
  }

  inline GemmMicrokernelTester& zero_index(size_t zero_index) {
    this->zero_index_ = zero_index;
    return *this;
  }

  inline size_t zero_index() const {
    return this->zero_index_;
  }

  inline GemmMicrokernelTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

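  // Tests a Q8 (asymmetric quantized uint8) GEMM micro-kernel against a 32-bit
  // integer reference accumulation followed by scalar requantization.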
  void Test(xnn_q8_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> b(n() * k());
    std::vector<int32_t> bias(n());
    std::vector<uint8_t, AlignedAllocator<uint8_t, 32>> packed_w(packed_n() * packed_k() + bias_n() * sizeof(uint32_t) / sizeof(uint8_t));
    std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<int32_t> acc(m() * n());
    std::vector<uint8_t> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      do {
        std::generate(a.begin(), a.end(), std::ref(u8rng));
      } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
      do {
        std::generate(b.begin(), b.end(), std::ref(u8rng));
      } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      std::fill(c.begin(), c.end(), 0xA5);

      std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
      xnn_pack_q8_gemm_goi_w(1, n(), k(), nr(), kr(),
        a_zero_point(), b_zero_point(),
        b.data(), bias.data(), packed_w.data());

      // Compute 32-bit results and output quantization arguments.
      std::fill(acc.begin(), acc.end(), 0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            acc[m_index * n() + n_index] +=
              (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
              (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
          }
          acc[m_index * n() + n_index] += bias[n_index];
        }
      }

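      // Derive an output scale and zero point that map the observed accumulator
      // range onto the uint8 output range, then build the kernel's quantization
      // parameters for the requested variant.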
      const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
      const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
      const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
      const uint8_t c_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      const float requantization_scale = 1.0f / float(c_scale);
      union xnn_q8_gemm_params quantization_params = { };
      switch (variant) {
        case Variant::Native:
          quantization_params = xnn_compute_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
        case Variant::Scalar:
          quantization_params = xnn_compute_scalar_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
      }
      const union xnn_q31_requantization_params scalar_requantization_params =
        xnn_compute_scalar_requantization_params(
          requantization_scale, c_zero_point, qmin(), qmax());

      gemm(
        m(), n(), k(),
        a.data(), a_stride() * sizeof(uint8_t),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
        &quantization_params);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = xnn_q31_requantize(acc[m_index * n() + n_index], scalar_requantization_params);
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
          ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
          ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
              << "at " << i << ", " << j << ": reference = " << (uint32_t) c_ref[i * n() + j]
              << " (accumulator = " << acc[i * n() + j]
              << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
              << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
              << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
        }
      }
    }
  }

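  // Tests a Q8 IGEMM (indirect GEMM) micro-kernel: input rows are addressed
  // through an im2col-style indirection buffer rather than a dense A matrix.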
  void Test(xnn_q8_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> b(n() * ks() * k());
    std::vector<uint8_t, AlignedAllocator<uint8_t, 32>> packed_w(ks() * packed_n() * packed_k() + bias_n() * sizeof(uint32_t) / sizeof(uint8_t));
    std::vector<int32_t> bias(n());
    std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<int32_t> acc(m() * n());
    std::vector<uint8_t> c_ref(m() * n());
    std::vector<uint8_t> junk(k() + 8);
    std::vector<const uint8_t*> im2col(mr() * ks());

    std::fill(junk.begin(), junk.end(), 0xA5);

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      do {
        std::generate(a.begin(), a.end(), std::ref(u8rng));
      } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
      do {
        std::generate(b.begin(), b.end(), std::ref(u8rng));
      } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      std::fill(c.begin(), c.end(), 0xA5);

      std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
      xnn_pack_q8_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(),
        a_zero_point(), b_zero_point(),
        b.data(), bias.data(), packed_w.data());

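      // Build the indirection buffer: each entry points into A (shifted by
      // a_offset()), entries selected by zero_index() point at the zero buffer,
      // and rows beyond m() point at junk that must not affect the result.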
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      // Compute 32-bit results and output quantization arguments.
      std::fill(acc.begin(), acc.end(), 0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
              for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
                if (im2col[ks_index * mr() + m_index] == a.data()) {
                  acc[m_index * n() + n_index] +=
                    (int32_t(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset]) - int32_t(a_zero_point())) *
                    (int32_t(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]) - int32_t(b_zero_point()));
                } else {
                  acc[m_index * n() + n_index] +=
                    (int32_t(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset + a_offset()]) - int32_t(a_zero_point())) *
                    (int32_t(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]) - int32_t(b_zero_point()));
                }
              }
            }
          }
          acc[m_index * n() + n_index] += bias[n_index];
        }
      }

      const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
      const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
      const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
      const uint8_t c_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      const float requantization_scale = 1.0f / float(c_scale);
      union xnn_q8_gemm_params quantization_params = { };
      switch (variant) {
        case Variant::Native:
          quantization_params = xnn_compute_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
        case Variant::Scalar:
          quantization_params = xnn_compute_scalar_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
      }
      const union xnn_q31_requantization_params scalar_requantization_params =
        xnn_compute_scalar_requantization_params(
          requantization_scale, c_zero_point, qmin(), qmax());

      const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

      igemm(
        m(), n(), k(), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
        a_offset() * sizeof(uint8_t), zero_pointer,
        &quantization_params);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = xnn_q31_requantize(acc[m_index * n() + n_index], scalar_requantization_params);
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
          ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
          ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
              << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
              << " (accumulator = " << acc[i * n() + j]
              << "), optimized = " << (uint32_t) c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x "
              << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
              << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
        }
      }
    }
  }

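  // Tests an F16 (IEEE half-precision) GEMM micro-kernel against an FP32
  // reference computation, with a relative tolerance of 1% per element.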
  void Test(xnn_f16_gemm_ukernel_function gemm, Variant variant = Variant::Native) const
  {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> b(n() * k());
    std::vector<uint16_t, AlignedAllocator<uint16_t, 32>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<uint16_t> bias(n());
    std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    xnn_f16_output_params output_params;
    output_params.scale = UINT16_C(0x3C00) /* 1.0 */;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f16rng));
      std::generate(b.begin(), b.end(), std::ref(f16rng));
      std::generate(bias.begin(), bias.end(), std::ref(f16rng));
      std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0);
      xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), b.data(), bias.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
            for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
              ASSERT_LE(n(), packed_n());
              ASSERT_LT(m_index * n() + n_index, c_ref.size());
              ASSERT_LT(m_index * k() + k_block_start + k_block_offset, a.size());

              c_ref[m_index * n() + n_index] +=
                fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_block_start + k_block_offset]) *
                fp16_ieee_to_fp32_value(b[n_index * k() + k_block_start + k_block_offset]);
            }
          }
          c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
      const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));
      output_params.max = fp16_ieee_from_fp32_value(c_max);
      output_params.min = fp16_ieee_from_fp32_value(c_min);

      for (float& c_value : c_ref) {
        c_value = std::max(std::min(c_value, c_max), c_min);
      }

      gemm(m(), n(), k() * sizeof(uint16_t),
        a.data(), a_stride() * sizeof(uint16_t),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]),
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-2f)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

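  // Tests an F32 PPMM micro-kernel, which consumes a pre-packed A panel laid
  // out as [k][mr] instead of reading A with a row stride.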
  void Test(xnn_f32_ppmm_ukernel_function ppmm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a(packed_k() * mr());
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 32>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data());

      for (size_t i = m(); i < mr(); i++) {
        for (size_t l = 0; l < k(); l++) {
          a[l * mr() + i] = a[l * mr() + m() - 1];
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          for (size_t l = 0; l < k(); l++) {
            c_ref[i * n() + j] +=
              a[l * mr() + i] *
              b[j * k() + l];
          }
          c_ref[i * n() + j] += bias[j];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_compute_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_compute_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (float& c_value : c_ref) {
        c_value = std::max(std::min(c_value, c_max), c_min);
      }

      ppmm(m(), n(), k() * sizeof(float),
        a.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

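  // Tests an F32 GEMM micro-kernel against a scalar FP32 reference with
  // clamping to the [c_min, c_max] range derived below.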
  void Test(xnn_f32_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 32>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

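      // Map qmin()/qmax() onto a float clamp range inside the observed span of
      // reference outputs, so the activation bounds are actually exercised.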
      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_compute_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_compute_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemm(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

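  // Tests an F32 GEMMINC micro-kernel, which adds into externally supplied
  // accumulators (the acc buffer) instead of loading a bias from the packed
  // weights.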
  void Test(xnn_f32_gemminc_ukernel_function gemminc, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 32>> packed_w(packed_n() * packed_k());  // no bias_n()
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float, AlignedAllocator<float, 32>> acc(mr() * packed_n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      std::generate(acc.begin(), acc.end(), std::ref(f32rng));

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_compute_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_compute_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemminc(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        acc.data(),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

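  // Tests an F32 IGEMM micro-kernel through an im2col-style indirection
  // buffer, with an optional zero pointer for padding rows.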
  void Test(xnn_f32_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * ks() * k());
    std::vector<float, AlignedAllocator<float, 32>> packed_w(ks() * packed_k() * packed_n() + bias_n());
    std::vector<float> bias(n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<const float*> im2col(mr() * ks());
    std::fill(junk.begin(), junk.end(), nanf(""));

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data());

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      std::fill(c_ref.begin(), c_ref.end(), 0.0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
              for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
                ASSERT_LT(ks_index * mr() + m_index, im2col.size());
                ASSERT_LT(k_block_start + k_block_offset, k());
                ASSERT_LT(k_block_start + k_block_offset, a_stride());
                if (im2col[ks_index * mr() + m_index] == a.data()) {
                  c_ref[m_index * n() + n_index] +=
                    double(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset]) *
                    double(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]);
                } else {
                  c_ref[m_index * n() + n_index] +=
                    double(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset + a_offset()]) *
                    double(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]);
                }
              }
            }
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
          c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
        }
      }

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_compute_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_compute_scalar_f32_output_params(c_min, c_max);
          break;
      }

      const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

      igemm(
        m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        a_offset() * sizeof(float), zero_pointer,
        &output_params);

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
            << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
            << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
            << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        }
      }
    }
  }

 private:
  size_t mr_{1};
  size_t nr_{1};
  size_t kr_{1};
  size_t sr_{1};
  size_t m_{1};
  size_t n_{1};
  size_t k_{1};
  size_t ks_{1};
  size_t a_stride_{0};
  size_t cm_stride_{0};
  size_t cn_stride_{0};
  uint8_t a_zero_point_{127};
  uint8_t b_zero_point_{127};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t a_offset_{0};
  size_t zero_index_{SIZE_MAX};
  size_t iterations_{15};
};