blob: e40df662c910ef2429157d003c8ef77559fb0675 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <iomanip>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/params.h>
20
21
22class FillMicrokernelTester {
23 public:
24 inline FillMicrokernelTester& rows(size_t rows) {
25 assert(rows != 0);
26 this->rows_ = rows;
27 return *this;
28 }
29
30 inline size_t rows() const {
31 return this->rows_;
32 }
33
34 inline FillMicrokernelTester& channels(size_t channels) {
35 assert(channels != 0);
36 this->channels_ = channels;
37 return *this;
38 }
39
40 inline size_t channels() const {
41 return this->channels_;
42 }
43
44 inline FillMicrokernelTester& output_stride(size_t output_stride) {
45 assert(output_stride != 0);
46 this->output_stride_ = output_stride;
47 return *this;
48 }
49
50 inline size_t output_stride() const {
51 if (this->output_stride_ == 0) {
52 return channels();
53 } else {
54 return this->output_stride_;
55 }
56 }
57
58 inline FillMicrokernelTester& iterations(size_t iterations) {
59 this->iterations_ = iterations;
60 return *this;
61 }
62
63 inline size_t iterations() const {
64 return this->iterations_;
65 }
66
67 void Test(xnn_x32_fill_ukernel_function fill) const {
68 ASSERT_GE(output_stride(), channels());
69
70 std::random_device random_device;
71 auto rng = std::mt19937(random_device());
72 auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), rng);
73
74 std::vector<uint32_t> output((rows() - 1) * output_stride() + channels());
75 std::vector<uint32_t> output_copy(output.size());
76 for (size_t iteration = 0; iteration < iterations(); iteration++) {
77 std::generate(output.begin(), output.end(), std::ref(u32rng));
78 std::copy(output.cbegin(), output.cend(), output_copy.begin());
79 const uint32_t fill_value = u32rng();
80
81 // Call optimized micro-kernel.
82 fill(
83 rows(),
84 channels() * sizeof(uint32_t),
85 output.data(),
86 output_stride() * sizeof(uint32_t),
87 &fill_value);
88
89 // Verify results.
90 for (size_t i = 0; i < rows(); i++) {
91 for (size_t c = 0; c < channels(); c++) {
92 ASSERT_EQ(output[i * output_stride() + c], fill_value)
93 << "at row " << i << " / " << rows()
94 << ", channel " << c << " / " << channels()
95 << ", fill value 0x" << std::hex << std::setw(8) << std::setfill('0') << fill_value
96 << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
97 }
98 }
99 for (size_t i = 0; i + 1 < rows(); i++) {
100 for (size_t c = channels(); c < output_stride(); c++) {
101 ASSERT_EQ(output[i * output_stride() + c], output_copy[i * output_stride() + c])
102 << "at row " << i << " / " << rows()
103 << ", channel " << c << " / " << channels()
104 << ", original value 0x" << std::hex << std::setw(8) << std::setfill('0')
105 << output_copy[i * output_stride() + c]
106 << ", output value 0x" << std::hex << std::setw(8) << std::setfill('0') << output[i * output_stride() + c];
107 }
108 }
109 }
110 }
111
112 private:
113 size_t rows_{1};
114 size_t channels_{1};
115 size_t output_stride_{0};
116 size_t iterations_{15};
117};