// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>
#include <xnnpack/AlignedAllocator.h>
#include <xnnpack/pack.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
#include <xnnpack/requantization.h>

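// Tests a GEMM/IGEMM/PPMM micro-kernel against a scalar reference computation.
// Tile sizes (mr/nr/kr/sr) and problem sizes (m/n/k/ks) are configured through
// the chainable setters below; each Test() overload packs the weights, invokes
// the micro-kernel under test, and verifies every output element.
//
// Illustrative usage (a sketch only; the micro-kernel symbol below is an
// assumed example and is not declared in this header):
//
//   GemmMicrokernelTester()
//     .mr(4).nr(4).kr(1)
//     .m(4).n(4).k(8)
//     .Test(xnn_f32_gemm_ukernel_4x4__scalar);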
class GemmMicrokernelTester {
 public:
  enum class Variant {
    Native,
    Scalar,
  };

  inline GemmMicrokernelTester& mr(size_t mr) {
    this->mr_ = mr;
    return *this;
  }

  inline size_t mr() const {
    return this->mr_;
  }

  inline GemmMicrokernelTester& nr(size_t nr) {
    this->nr_ = nr;
    return *this;
  }

  inline size_t nr() const {
    return this->nr_;
  }

  inline GemmMicrokernelTester& kr(size_t kr) {
    this->kr_ = kr;
    return *this;
  }

  inline size_t kr() const {
    return this->kr_;
  }

  inline GemmMicrokernelTester& sr(size_t sr) {
    this->sr_ = sr;
    return *this;
  }

  inline size_t sr() const {
    return this->sr_;
  }

  inline GemmMicrokernelTester& m(size_t m) {
    this->m_ = m;
    return *this;
  }

  inline size_t m() const {
    return this->m_;
  }

  inline GemmMicrokernelTester& n(size_t n) {
    this->n_ = n;
    return *this;
  }

  inline size_t n() const {
    return this->n_;
  }

  inline GemmMicrokernelTester& k(size_t k) {
    this->k_ = k;
    return *this;
  }

  inline size_t k() const {
    return this->k_;
  }

  inline GemmMicrokernelTester& ks(size_t ks) {
    this->ks_ = ks;
    return *this;
  }

  inline size_t ks() const {
    return this->ks_;
  }

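  // Dimensions of the packed-weights buffer: k() and n() rounded up to whole
  // kr()- and nr()-sized blocks, matching the layout produced by the
  // xnn_pack_*_w routines used in the Test() overloads below.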
  inline size_t packed_k() const {
    return k() % kr() == 0 ? k() : (k() / kr() + 1) * kr();
  }

  inline size_t packed_n() const {
    return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
  }

  inline size_t bias_n() const {
    return n() % nr() == 0 ? n() : (n() / nr() + 1) * nr();
  }

  inline GemmMicrokernelTester& a_stride(size_t a_stride) {
    this->a_stride_ = a_stride;
    return *this;
  }

  inline size_t a_stride() const {
    return this->a_stride_ == 0 ? k() : this->a_stride_;
  }

  inline GemmMicrokernelTester& cm_stride(size_t cm_stride) {
    this->cm_stride_ = cm_stride;
    return *this;
  }

  inline size_t cm_stride() const {
    return this->cm_stride_ == 0 ? cn_stride() * ((n() - 1) / nr()) + (n() - 1) % nr() + 1 : this->cm_stride_;
  }

  inline GemmMicrokernelTester& cn_stride(size_t cn_stride) {
    this->cn_stride_ = cn_stride;
    return *this;
  }

  inline size_t cn_stride() const {
    return this->cn_stride_ == 0 ? nr() : this->cn_stride_;
  }

  inline GemmMicrokernelTester& a_zero_point(uint8_t a_zero_point) {
    this->a_zero_point_ = a_zero_point;
    return *this;
  }

  inline uint8_t a_zero_point() const {
    return this->a_zero_point_;
  }

  inline GemmMicrokernelTester& b_zero_point(uint8_t b_zero_point) {
    this->b_zero_point_ = b_zero_point;
    return *this;
  }

  inline uint8_t b_zero_point() const {
    return this->b_zero_point_;
  }

  inline GemmMicrokernelTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline GemmMicrokernelTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline GemmMicrokernelTester& a_offset(size_t a_offset) {
    this->a_offset_ = a_offset;
    return *this;
  }

  inline size_t a_offset() const {
    return this->a_offset_;
  }

  inline GemmMicrokernelTester& zero_index(size_t zero_index) {
    this->zero_index_ = zero_index;
    return *this;
  }

  inline size_t zero_index() const {
    return this->zero_index_;
  }

  inline GemmMicrokernelTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

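  // Tests a Q8 (uint8 data, 32-bit accumulation) GEMM micro-kernel: packs B and
  // the bias, computes int32 reference accumulators, derives the requantization
  // scale and output zero point from the observed accumulator range, and then
  // compares the micro-kernel output against the requantized reference.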
  void Test(xnn_q8_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    std::vector<uint8_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> b(n() * k());
    std::vector<int32_t> bias(n());
    std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(packed_n() * packed_k() + bias_n() * sizeof(uint32_t) / sizeof(uint8_t));
    std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<int32_t> acc(m() * n());
    std::vector<uint8_t> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      do {
        std::generate(a.begin(), a.end(), std::ref(u8rng));
      } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
      do {
        std::generate(b.begin(), b.end(), std::ref(u8rng));
      } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      std::fill(c.begin(), c.end(), 0xA5);

      std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
      xnn_pack_q8_gemm_goi_w(1, n(), k(), nr(), kr(),
        a_zero_point(), b_zero_point(),
        b.data(), bias.data(), packed_w.data());

      // Compute 32-bit results and output quantization arguments.
      std::fill(acc.begin(), acc.end(), 0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            acc[m_index * n() + n_index] +=
              (int32_t(a[m_index * a_stride() + k_index]) - int32_t(a_zero_point())) *
              (int32_t(b[n_index * k() + k_index]) - int32_t(b_zero_point()));
          }
          acc[m_index * n() + n_index] += bias[n_index];
        }
      }

      const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
      const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
      const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
      const uint8_t c_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      const float requantization_scale = 1.0f / float(c_scale);
      union xnn_q8_gemm_params quantization_params = { };
      switch (variant) {
        case Variant::Native:
          quantization_params = xnn_init_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
        case Variant::Scalar:
          quantization_params = xnn_init_scalar_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
      }
      const union xnn_q31_requantization_params scalar_requantization_params =
        xnn_init_scalar_requantization_params(
          requantization_scale, c_zero_point, qmin(), qmax());

      gemm(
        m(), n(), k(),
        a.data(), a_stride() * sizeof(uint8_t),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
        &quantization_params);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = xnn_q31_requantize(acc[m_index * n() + n_index], scalar_requantization_params);
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
          ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
          ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
              << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
              << " (accumulator = " << acc[i * n() + j]
              << "), optimized = " << uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
              << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
              << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
        }
      }
    }
  }

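  // Tests a Q8 IGEMM (indirect GEMM) micro-kernel: A is addressed through an
  // im2col-style indirection buffer of ks() pointers per row, with a_offset()
  // applied inside the kernel, junk pointers for rows m()..mr()-1, and
  // zero_index() rows redirected to the zero pointer.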
  void Test(xnn_q8_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto s32rng = std::bind(std::uniform_int_distribution<int32_t>(-10000, 10000), rng);
    auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);

    std::vector<uint8_t> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint8_t));
    std::vector<uint8_t> b(n() * ks() * k());
    std::vector<uint8_t, AlignedAllocator<uint8_t, 64>> packed_w(ks() * packed_n() * packed_k() + bias_n() * sizeof(uint32_t) / sizeof(uint8_t));
    std::vector<int32_t> bias(n());
    std::vector<uint8_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<int32_t> acc(m() * n());
    std::vector<uint8_t> c_ref(m() * n());
    std::vector<uint8_t> junk(k() + 8);
    std::vector<const uint8_t*> im2col(mr() * ks());

    std::fill(junk.begin(), junk.end(), 0xA5);

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      do {
        std::generate(a.begin(), a.end(), std::ref(u8rng));
      } while (a.size() > 1 && *std::max_element(a.cbegin(), a.cend()) == *std::min_element(a.cbegin(), a.cend()));
      do {
        std::generate(b.begin(), b.end(), std::ref(u8rng));
      } while (b.size() > 1 && *std::max_element(b.cbegin(), b.cend()) == *std::min_element(b.cbegin(), b.cend()));
      std::generate(bias.begin(), bias.end(), std::ref(s32rng));
      std::fill(c.begin(), c.end(), 0xA5);

      std::fill(packed_w.begin(), packed_w.end(), b_zero_point());
      xnn_pack_q8_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(),
        a_zero_point(), b_zero_point(),
        b.data(), bias.data(), packed_w.data());

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      // Compute 32-bit results and output quantization arguments.
      std::fill(acc.begin(), acc.end(), 0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
              for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
                if (im2col[ks_index * mr() + m_index] == a.data()) {
                  acc[m_index * n() + n_index] +=
                    (int32_t(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset]) - int32_t(a_zero_point())) *
                    (int32_t(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]) - int32_t(b_zero_point()));
                } else {
                  acc[m_index * n() + n_index] +=
                    (int32_t(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset + a_offset()]) - int32_t(a_zero_point())) *
                    (int32_t(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]) - int32_t(b_zero_point()));
                }
              }
            }
          }
          acc[m_index * n() + n_index] += bias[n_index];
        }
      }

      const int32_t accumulated_min = *std::min_element(acc.cbegin(), acc.cend());
      const int32_t accumulated_max = *std::max_element(acc.cbegin(), acc.cend());
      const double c_scale = uint32_t(accumulated_max - accumulated_min) >= 256 ? double(uint32_t(accumulated_max - accumulated_min)) / 255.0 : 1.00001;
      const uint8_t c_zero_point = uint8_t(std::max(std::min(
        lrint(127.5 - 0.5 * double(accumulated_min + accumulated_max) / c_scale),
        long(std::numeric_limits<uint8_t>::max())), long(std::numeric_limits<uint8_t>::min())));

      const float requantization_scale = 1.0f / float(c_scale);
      union xnn_q8_gemm_params quantization_params = { };
      switch (variant) {
        case Variant::Native:
          quantization_params = xnn_init_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
        case Variant::Scalar:
          quantization_params = xnn_init_scalar_q8_gemm_params(
            a_zero_point(), b_zero_point(),
            requantization_scale, c_zero_point, qmin(), qmax());
          break;
      }
      const union xnn_q31_requantization_params scalar_requantization_params =
        xnn_init_scalar_requantization_params(
          requantization_scale, c_zero_point, qmin(), qmax());

      const uint8_t* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

      igemm(
        m(), n(), k(), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(uint8_t), cn_stride() * sizeof(uint8_t),
        a_offset() * sizeof(uint8_t), zero_pointer,
        &quantization_params);

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = xnn_q31_requantize(acc[m_index * n() + n_index], scalar_requantization_params);
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmax()));
          ASSERT_GE(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(qmin()));
          ASSERT_EQ(uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]), uint32_t(c_ref[i * n() + j]))
              << "at " << i << ", " << j << ": reference = " << uint32_t(c_ref[i * n() + j])
              << " (accumulator = " << acc[i * n() + j]
              << "), optimized = " << uint32_t(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x "
              << nr() << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k()
              << ", requantization scale = " << requantization_scale << ", output zero point = " << int32_t(c_zero_point);
        }
      }
    }
  }

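  // Tests an F16 GEMM micro-kernel: inputs are generated as IEEE half-precision
  // values, the reference result is accumulated in float, and the output clamp
  // range is derived from qmin()/qmax() after round-tripping through fp16.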
  void Test(xnn_f16_gemm_ukernel_function gemm, Variant variant = Variant::Native) const
  {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

    std::vector<uint16_t> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(uint16_t));
    std::vector<uint16_t> b(n() * k());
    std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<uint16_t> bias(n());
    std::vector<uint16_t> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    xnn_f16_output_params output_params;
    output_params.scale = UINT16_C(0x3C00) /* 1.0 */;

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f16rng));
      std::generate(b.begin(), b.end(), std::ref(f16rng));
      std::generate(bias.begin(), bias.end(), std::ref(f16rng));
      std::fill(c.begin(), c.end(), UINT16_C(0x7E00) /* NaN */);
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0);
      xnn_pack_f16_gemm_goi_w(1, n(), k(), nr(), kr(), b.data(), bias.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
            for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
              ASSERT_LE(n(), packed_n());
              ASSERT_LT(m_index * n() + n_index, c_ref.size());
              ASSERT_LT(m_index * k() + k_block_start + k_block_offset, a.size());

              c_ref[m_index * n() + n_index] +=
                fp16_ieee_to_fp32_value(a[m_index * a_stride() + k_block_start + k_block_offset]) *
                fp16_ieee_to_fp32_value(b[n_index * k() + k_block_start + k_block_offset]);
            }
          }
          c_ref[m_index * n() + n_index] += fp16_ieee_to_fp32_value(bias[n_index]);
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin())));
      const float c_max = fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax())));
      output_params.max = fp16_ieee_from_fp32_value(c_max);
      output_params.min = fp16_ieee_from_fp32_value(c_min);

      for (float& c_value : c_ref) {
        c_value = std::max(std::min(c_value, c_max), c_min);
      }

      gemm(m(), n(), k() * sizeof(uint16_t),
        a.data(), a_stride() * sizeof(uint16_t),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(uint16_t), cn_stride() * sizeof(uint16_t),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]),
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-2f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << fp16_ieee_to_fp32_value(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()]) << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

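  // Tests an F32 PPMM (pre-packed matrix-matrix multiply) micro-kernel: the A
  // panel is laid out as k() rows of mr() elements, with rows m()..mr()-1
  // replicated from the last valid row.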
  void Test(xnn_f32_ppmm_ukernel_function ppmm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a(packed_k() * mr());
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data());

      for (size_t i = m(); i < mr(); i++) {
        for (size_t l = 0; l < k(); l++) {
          a[l * mr() + i] = a[l * mr() + m() - 1];
        }
      }

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          for (size_t l = 0; l < k(); l++) {
            c_ref[i * n() + j] +=
              a[l * mr() + i] *
              b[j * k() + l];
          }
          c_ref[i * n() + j] += bias[j];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (float& c_value : c_ref) {
        c_value = std::max(std::min(c_value, c_max), c_min);
      }

      ppmm(m(), n(), k() * sizeof(float),
        a.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

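  // Tests an F32 GEMM micro-kernel against a float reference, clamping the
  // reference to the [c_min, c_max] range derived from qmin()/qmax() and the
  // observed output range.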
  void Test(xnn_f32_gemm_ukernel_function gemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k() + bias_n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemm_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), bias.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemm(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

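  // Tests an F32 GEMMINC micro-kernel: like the F32 GEMM test, but the kernel
  // starts from externally supplied accumulators (acc) stored as mr()-by-nr()
  // tiles laid out consecutively along N, instead of a packed bias.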
  void Test(xnn_f32_gemminc_ukernel_function gemminc, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());
    ASSERT_GE(a_stride(), k());
    ASSERT_GE(cm_stride(), n());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((m() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * k());
    std::vector<float> bias(n());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(packed_n() * packed_k());  // no bias_n()
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float, AlignedAllocator<float, 64>> acc(mr() * packed_n());

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);
      std::generate(acc.begin(), acc.end(), std::ref(f32rng));

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_gemminc_goi_w(1, n(), k(), nr(), kr(), sr(), b.data(), packed_w.data());

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t k_index = 0; k_index < k(); k_index++) {
            ASSERT_LE(n(), packed_n());
            ASSERT_LT(m_index * n() + n_index, c_ref.size());
            c_ref[m_index * n() + n_index] +=
              a[m_index * a_stride() + k_index] *
              b[n_index * k() + k_index];
          }
          c_ref[m_index * n() + n_index] += acc[n_index / nr() * nr() * mr() + m_index % mr() * nr() + n_index % nr()];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::max(std::min(c_ref[m_index * n() + n_index], c_max), c_min);
        }
      }

      gemminc(m(), n(), k() * sizeof(float),
        a.data(), a_stride() * sizeof(float),
        packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        acc.data(),
        &output_params);

      // Validate micro-kernel outputs.
      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x K = " << m() << " x " << n() << " x " << k();
        }
      }
    }
  }

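  // Tests an F32 IGEMM micro-kernel: A is addressed through an indirection
  // buffer of ks() pointers per output row, with junk pointers for rows
  // m()..mr()-1 and optional zero-pointer substitution controlled by
  // zero_index().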
  void Test(xnn_f32_igemm_ukernel_function igemm, Variant variant = Variant::Native) const {
    ASSERT_LE(m(), mr());

    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);

    std::vector<float> a((mr() - 1) * a_stride() + k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> b(n() * ks() * k());
    std::vector<float, AlignedAllocator<float, 64>> packed_w(ks() * packed_k() * packed_n() + bias_n());
    std::vector<float> bias(n());
    std::vector<float> c((mr() - 1) * cm_stride() + ((n() - 1) / nr()) * cn_stride() + (n() - 1) % nr() + 1);
    std::vector<float> c_ref(m() * n());
    std::vector<float> junk(k() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<const float*> im2col(mr() * ks());
    std::fill(junk.begin(), junk.end(), nanf(""));

    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(a.begin(), a.end(), std::ref(f32rng));
      std::generate(b.begin(), b.end(), std::ref(f32rng));
      std::generate(bias.begin(), bias.end(), std::ref(f32rng));
      std::fill(c.begin(), c.end(), nanf(""));
      std::fill(c_ref.begin(), c_ref.end(), 0.0f);

      std::fill(packed_w.begin(), packed_w.end(), 0.0f);
      xnn_pack_f32_conv_goki_w(
        1, n(), ks(), k(), nr(), kr(), sr(),
        b.data(), bias.data(), packed_w.data());

      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = 0; m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = a.data() + a_stride() * m_index - a_offset();
        }
      }
      std::shuffle(im2col.begin(), im2col.end(), rng);
      if (zero_index() != SIZE_MAX) {
        for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
          im2col[ks_index * mr() + zero_index()] = a.data();
        }
      }
      for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
        for (size_t m_index = m(); m_index < mr(); m_index++) {
          im2col[ks_index * mr() + m_index] = junk.data();
        }
      }

      std::fill(c_ref.begin(), c_ref.end(), 0.0);
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          for (size_t ks_index = 0; ks_index < ks(); ks_index++) {
            for (size_t k_block_start = 0; k_block_start < k(); k_block_start += kr()) {
              for (size_t k_block_offset = 0; k_block_offset < std::min(k() - k_block_start, kr()); k_block_offset++) {
                ASSERT_LT(ks_index * mr() + m_index, im2col.size());
                ASSERT_LT(k_block_start + k_block_offset, k());
                ASSERT_LT(k_block_start + k_block_offset, a_stride());
                if (im2col[ks_index * mr() + m_index] == a.data()) {
                  c_ref[m_index * n() + n_index] +=
                    double(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset]) *
                    double(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]);
                } else {
                  c_ref[m_index * n() + n_index] +=
                    double(im2col[ks_index * mr() + m_index][k_block_start + k_block_offset + a_offset()]) *
                    double(b[(n_index * ks() + ks_index) * k() + k_block_start + k_block_offset]);
                }
              }
            }
          }
          c_ref[m_index * n() + n_index] += bias[n_index];
        }
      }

      const float accumulated_min = *std::min_element(c_ref.cbegin(), c_ref.cend());
      const float accumulated_max = *std::max_element(c_ref.cbegin(), c_ref.cend());
      const float c_min = accumulated_min + (accumulated_max - accumulated_min) / 255.0f * float(qmin());
      const float c_max = accumulated_max - (accumulated_max - accumulated_min) / 255.0f * float(255 - qmax());
      for (size_t m_index = 0; m_index < m(); m_index++) {
        for (size_t n_index = 0; n_index < n(); n_index++) {
          c_ref[m_index * n() + n_index] = std::min(c_ref[m_index * n() + n_index], c_max);
          c_ref[m_index * n() + n_index] = std::max(c_ref[m_index * n() + n_index], c_min);
        }
      }

      // Prepare output parameters.
      xnn_f32_output_params output_params = { };
      switch (variant) {
        case Variant::Native:
          output_params = xnn_init_f32_output_params(c_min, c_max);
          break;
        case Variant::Scalar:
          output_params = xnn_init_scalar_f32_output_params(c_min, c_max);
          break;
      }

      const float* zero_pointer = (zero_index() != SIZE_MAX) ? a.data() : NULL;

      igemm(
        m(), n(), k() * sizeof(float), ks() * mr() * sizeof(void*),
        im2col.data(), packed_w.data(),
        c.data(), cm_stride() * sizeof(float), cn_stride() * sizeof(float),
        a_offset() * sizeof(float), zero_pointer,
        &output_params);

      for (size_t i = 0; i < m(); i++) {
        for (size_t j = 0; j < n(); j++) {
          ASSERT_LE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_max)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_GE(c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()], c_min)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
          ASSERT_NEAR(
              c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()],
              c_ref[i * n() + j],
              std::abs(c_ref[i * n() + j]) * 1.0e-6f)
              << "at " << i << ", " << j << ": reference = " << c_ref[i * n() + j]
              << ", optimized = " << c[i * cm_stride() + (j / nr()) * cn_stride() + j % nr()] << ", Mr x Nr x Kr = " << mr() << " x " << nr()
              << " x " << kr() << ", M x N x KC x KS = " << m() << " x " << n() << " x " << k() << " x " << ks();
        }
      }
    }
  }

 private:
  size_t mr_{1};
  size_t nr_{1};
  size_t kr_{1};
  size_t sr_{1};
  size_t m_{1};
  size_t n_{1};
  size_t k_{1};
  size_t ks_{1};
  size_t a_stride_{0};
  size_t cm_stride_{0};
  size_t cn_stride_{0};
  uint8_t a_zero_point_{127};
  uint8_t b_zero_point_{127};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t a_offset_{0};
  size_t zero_index_{SIZE_MAX};
  size_t iterations_{15};
};