blob: 44042bab529badfc61119069670c80a8d15ab7af [file] [log] [blame]
Yury Kartynnike7841862020-11-04 18:22:18 -08001// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
8#include <gtest/gtest.h>
9
10#include <algorithm>
11#include <cassert>
12#include <cstddef>
13#include <cstdint>
14#include <cstdlib>
15#include <functional>
16#include <random>
17#include <vector>
18
19#include <xnnpack.h>
20#include <xnnpack/params.h>
21
22
23class DepthToSpaceMicrokernelTester {
24 public:
25 inline DepthToSpaceMicrokernelTester& output_channels(size_t output_channels) {
26 assert(output_channels != 0);
27 this->output_channels_ = output_channels;
28 return *this;
29 }
30
31 inline size_t output_channels() const {
32 return this->output_channels_;
33 }
34
35 inline size_t input_channels() const {
36 return this->output_channels() * this->block_size() * this->block_size();
37 }
38
39 inline DepthToSpaceMicrokernelTester& input_height(size_t input_height) {
40 assert(input_height != 0);
41 this->input_height_ = input_height;
42 return *this;
43 }
44
45 inline size_t input_height() const {
46 return this->input_height_;
47 }
48
49 inline size_t output_height() const {
50 return this->input_height() * this->block_size();
51 }
52
53 inline DepthToSpaceMicrokernelTester& input_width(size_t input_width) {
54 assert(input_width != 0);
55 this->input_width_ = input_width;
56 return *this;
57 }
58
59 inline size_t input_width() const {
60 return this->input_width_;
61 }
62
63 inline size_t output_width() const {
64 return this->input_width() * this->block_size();
65 }
66
67 inline DepthToSpaceMicrokernelTester& block_size(size_t block_size) {
68 assert(block_size != 0);
69 this->block_size_ = block_size;
70 return *this;
71 }
72
73 inline size_t block_size() const {
74 return this->block_size_;
75 }
76
Marat Dukhan77e93a22021-02-26 11:13:55 -080077 inline DepthToSpaceMicrokernelTester& output_channel_stride(size_t output_channel_stride) {
78 assert(output_channel_stride != 0);
79 this->output_channel_stride_ = output_channel_stride;
Yury Kartynnike7841862020-11-04 18:22:18 -080080 return *this;
81 }
82
Marat Dukhan77e93a22021-02-26 11:13:55 -080083 inline size_t output_channel_stride() const {
84 if (this->output_channel_stride_ != 0) {
85 return this->output_channel_stride_;
Yury Kartynnike7841862020-11-04 18:22:18 -080086 } else {
Marat Dukhan77e93a22021-02-26 11:13:55 -080087 return this->output_channels();
Yury Kartynnike7841862020-11-04 18:22:18 -080088 }
89 }
90
91 inline DepthToSpaceMicrokernelTester& iterations(size_t iterations) {
92 this->iterations_ = iterations;
93 return *this;
94 }
95
96 inline size_t iterations() const {
97 return this->iterations_;
98 }
99
Marat Dukhanad71b9a2020-11-20 00:01:51 -0800100 void Test(xnn_x32_depthtospace2d_chw2hwc_ukernel_function depthtospace2d) const {
Yury Kartynnike7841862020-11-04 18:22:18 -0800101 ASSERT_GE(block_size(), 2);
Yury Kartynnike7841862020-11-04 18:22:18 -0800102
103 std::random_device random_device;
104 auto rng = std::mt19937(random_device());
105 auto u32rng = std::bind(std::uniform_int_distribution<uint32_t>(), rng);
106
Marat Dukhan77e93a22021-02-26 11:13:55 -0800107 std::vector<uint32_t> input(input_channels() * input_height() * input_width());
108 std::vector<uint32_t> output((output_height() * output_width() - 1) * output_channel_stride() + output_channels());
Yury Kartynnike7841862020-11-04 18:22:18 -0800109
110 for (size_t iteration = 0; iteration < iterations(); iteration++) {
111 std::generate(input.begin(), input.end(), std::ref(u32rng));
112
113 // Call optimized micro-kernel.
Marat Dukhanad71b9a2020-11-20 00:01:51 -0800114 depthtospace2d(
Yury Kartynnike7841862020-11-04 18:22:18 -0800115 output_channels(),
116 input_height(),
117 input_width(),
118 block_size(),
119 input.data(),
120 output.data(),
Marat Dukhan77e93a22021-02-26 11:13:55 -0800121 output_channel_stride());
Yury Kartynnike7841862020-11-04 18:22:18 -0800122
123 // Verify results.
Marat Dukhan77e93a22021-02-26 11:13:55 -0800124 for (size_t iy = 0; iy < input_height(); iy++) {
125 for (size_t by = 0; by < block_size(); by++) {
126 for (size_t ix = 0; ix < input_width(); ix++) {
127 for (size_t bx = 0; bx < block_size(); bx++) {
128 for (size_t oc = 0; oc < output_channels(); oc++) {
129 const size_t input_index =
130 (((by * block_size() + bx) * output_channels() + oc) * input_height() + iy) * input_width() + ix;
131 const size_t output_index =
132 ((iy * block_size() + by) * output_width() + ix * block_size() + bx) * output_channel_stride() + oc;
133 ASSERT_EQ(output[output_index], input[input_index])
134 << "input x: " << ix << " / " << input_width()
135 << ", input y: " << iy << " / " << input_height()
136 << ", block x: " << bx << " / " << block_size()
137 << ", block y: " << by << " / " << block_size()
138 << ", output channel: " << oc << " / " << output_channels()
139 << ", output stride: " << output_channel_stride();
Yury Kartynnike7841862020-11-04 18:22:18 -0800140 }
141 }
142 }
143 }
144 }
145 }
146 }
147
148 private:
149 size_t output_channels_{1};
150 size_t input_height_{1};
151 size_t input_width_{1};
152 size_t block_size_{2};
Marat Dukhan77e93a22021-02-26 11:13:55 -0800153 size_t output_channel_stride_{0};
Yury Kartynnike7841862020-11-04 18:22:18 -0800154 size_t iterations_{3};
155};