blob: 47466b6a67de86e6bd8061deef47e940dfd8e268 [file] [log] [blame]
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -07001// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <array>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <cstring>
#include <functional>
#include <iomanip>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/params.h>
20
21
22class FillMicrokernelTester {
23 public:
24 inline FillMicrokernelTester& rows(size_t rows) {
25 assert(rows != 0);
26 this->rows_ = rows;
27 return *this;
28 }
29
30 inline size_t rows() const {
31 return this->rows_;
32 }
33
34 inline FillMicrokernelTester& channels(size_t channels) {
35 assert(channels != 0);
36 this->channels_ = channels;
37 return *this;
38 }
39
40 inline size_t channels() const {
41 return this->channels_;
42 }
43
44 inline FillMicrokernelTester& output_stride(size_t output_stride) {
45 assert(output_stride != 0);
46 this->output_stride_ = output_stride;
47 return *this;
48 }
49
50 inline size_t output_stride() const {
51 if (this->output_stride_ == 0) {
52 return channels();
53 } else {
54 return this->output_stride_;
55 }
56 }
57
58 inline FillMicrokernelTester& iterations(size_t iterations) {
59 this->iterations_ = iterations;
60 return *this;
61 }
62
63 inline size_t iterations() const {
64 return this->iterations_;
65 }
66
Marat Dukhan933051b2021-08-07 16:26:15 -070067 void Test(xnn_fill_ukernel_function fill) const {
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070068 ASSERT_GE(output_stride(), channels());
69
70 std::random_device random_device;
71 auto rng = std::mt19937(random_device());
Marat Dukhan933051b2021-08-07 16:26:15 -070072 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070073
Marat Dukhan933051b2021-08-07 16:26:15 -070074 std::vector<uint8_t> output((rows() - 1) * output_stride() + channels());
75 std::vector<uint8_t> output_copy(output.size());
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070076 for (size_t iteration = 0; iteration < iterations(); iteration++) {
Marat Dukhan933051b2021-08-07 16:26:15 -070077 std::generate(output.begin(), output.end(), std::ref(u8rng));
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070078 std::copy(output.cbegin(), output.cend(), output_copy.begin());
Marat Dukhan933051b2021-08-07 16:26:15 -070079 std::array<uint8_t, 4> fill_pattern;
80 std::generate(fill_pattern.begin(), fill_pattern.end(), std::ref(u8rng));
81 uint32_t fill_value = 0;
82 memcpy(&fill_value, fill_pattern.data(), sizeof(fill_value));
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070083
84 // Call optimized micro-kernel.
85 fill(
86 rows(),
Marat Dukhan933051b2021-08-07 16:26:15 -070087 channels() * sizeof(uint8_t),
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070088 output.data(),
Marat Dukhan933051b2021-08-07 16:26:15 -070089 output_stride() * sizeof(uint8_t),
90 fill_value);
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070091
92 // Verify results.
93 for (size_t i = 0; i < rows(); i++) {
94 for (size_t c = 0; c < channels(); c++) {
Marat Dukhan933051b2021-08-07 16:26:15 -070095 ASSERT_EQ(uint32_t(output[i * output_stride() + c]), uint32_t(fill_pattern[c % fill_pattern.size()]))
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -070096 << "at row " << i << " / " << rows()
97 << ", channel " << c << " / " << channels()
98 << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
99 << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
100 }
101 }
102 for (size_t i = 0; i + 1 < rows(); i++) {
103 for (size_t c = channels(); c < output_stride(); c++) {
Marat Dukhan933051b2021-08-07 16:26:15 -0700104 ASSERT_EQ(uint32_t(output[i * output_stride() + c]), uint32_t(output_copy[i * output_stride() + c]))
Marat Dukhan3bb3bfc2020-05-19 17:42:46 -0700105 << "at row " << i << " / " << rows()
106 << ", channel " << c << " / " << channels()
107 << ", original value 0x" << std::hex << std::setw(8) << std::setfill('0')
108 << output_copy[i * output_stride() + c]
109 << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
110 }
111 }
112 }
113 }
114
115 private:
116 size_t rows_{1};
117 size_t channels_{1};
118 size_t output_stride_{0};
119 size_t iterations_{15};
120};