blob: 7c2635362a056c7f28b479af67559b97759152b5 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/params.h>

23class FillMicrokernelTester {
24 public:
25 inline FillMicrokernelTester& rows(size_t rows) {
26 assert(rows != 0);
27 this->rows_ = rows;
28 return *this;
29 }
30
31 inline size_t rows() const {
32 return this->rows_;
33 }
34
35 inline FillMicrokernelTester& channels(size_t channels) {
36 assert(channels != 0);
37 this->channels_ = channels;
38 return *this;
39 }
40
41 inline size_t channels() const {
42 return this->channels_;
43 }
44
45 inline FillMicrokernelTester& output_stride(size_t output_stride) {
46 assert(output_stride != 0);
47 this->output_stride_ = output_stride;
48 return *this;
49 }
50
51 inline size_t output_stride() const {
52 if (this->output_stride_ == 0) {
53 return channels();
54 } else {
55 return this->output_stride_;
56 }
57 }
58
59 inline FillMicrokernelTester& iterations(size_t iterations) {
60 this->iterations_ = iterations;
61 return *this;
62 }
63
64 inline size_t iterations() const {
65 return this->iterations_;
66 }
67
Marat Dukhan933051b2021-08-07 16:26:15 -070068 void Test(xnn_fill_ukernel_function fill) const {
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070069 ASSERT_GE(output_stride(), channels());
70
71 std::random_device random_device;
72 auto rng = std::mt19937(random_device());
Marat Dukhan933051b2021-08-07 16:26:15 -070073 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070074
Marat Dukhan933051b2021-08-07 16:26:15 -070075 std::vector<uint8_t> output((rows() - 1) * output_stride() + channels());
76 std::vector<uint8_t> output_copy(output.size());
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070077 for (size_t iteration = 0; iteration < iterations(); iteration++) {
Marat Dukhan933051b2021-08-07 16:26:15 -070078 std::generate(output.begin(), output.end(), std::ref(u8rng));
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070079 std::copy(output.cbegin(), output.cend(), output_copy.begin());
Marat Dukhan933051b2021-08-07 16:26:15 -070080 std::array<uint8_t, 4> fill_pattern;
81 std::generate(fill_pattern.begin(), fill_pattern.end(), std::ref(u8rng));
82 uint32_t fill_value = 0;
83 memcpy(&fill_value, fill_pattern.data(), sizeof(fill_value));
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070084
85 // Call optimized micro-kernel.
86 fill(
87 rows(),
Marat Dukhan933051b2021-08-07 16:26:15 -070088 channels() * sizeof(uint8_t),
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070089 output.data(),
Marat Dukhan933051b2021-08-07 16:26:15 -070090 output_stride() * sizeof(uint8_t),
91 fill_value);
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070092
93 // Verify results.
94 for (size_t i = 0; i < rows(); i++) {
95 for (size_t c = 0; c < channels(); c++) {
Marat Dukhan933051b2021-08-07 16:26:15 -070096 ASSERT_EQ(uint32_t(output[i * output_stride() + c]), uint32_t(fill_pattern[c % fill_pattern.size()]))
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070097 << "at row " << i << " / " << rows()
98 << ", channel " << c << " / " << channels()
99 << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
100 << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
101 }
102 }
103 for (size_t i = 0; i + 1 < rows(); i++) {
104 for (size_t c = channels(); c < output_stride(); c++) {
Marat Dukhan933051b2021-08-07 16:26:15 -0700105 ASSERT_EQ(uint32_t(output[i * output_stride() + c]), uint32_t(output_copy[i * output_stride() + c]))
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -0700106 << "at row " << i << " / " << rows()
107 << ", channel " << c << " / " << channels()
108 << ", original value 0x" << std::hex << std::setw(8) << std::setfill('0')
109 << output_copy[i * output_stride() + c]
110 << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
111 }
112 }
113 }
114 }
115
116 private:
117 size_t rows_{1};
118 size_t channels_{1};
119 size_t output_stride_{0};
120 size_t iterations_{15};
121};