blob: ad31c3f46ffb1b88e9e2c8ba979677008741f4c0 [file] [log] [blame]
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
8
#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
22
23
24class ChannelShuffleOperatorTester {
25 public:
26 inline ChannelShuffleOperatorTester& groups(size_t groups) {
27 assert(groups != 0);
28 this->groups_ = groups;
29 return *this;
30 }
31
32 inline size_t groups() const {
33 return this->groups_;
34 }
35
36 inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) {
37 assert(group_channels != 0);
38 this->group_channels_ = group_channels;
39 return *this;
40 }
41
42 inline size_t group_channels() const {
43 return this->group_channels_;
44 }
45
46 inline size_t channels() const {
47 return groups() * group_channels();
48 }
49
50 inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) {
51 assert(input_stride != 0);
52 this->input_stride_ = input_stride;
53 return *this;
54 }
55
56 inline size_t input_stride() const {
57 if (this->input_stride_ == 0) {
58 return channels();
59 } else {
60 assert(this->input_stride_ >= channels());
61 return this->input_stride_;
62 }
63 }
64
65 inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) {
66 assert(output_stride != 0);
67 this->output_stride_ = output_stride;
68 return *this;
69 }
70
71 inline size_t output_stride() const {
72 if (this->output_stride_ == 0) {
73 return channels();
74 } else {
75 assert(this->output_stride_ >= channels());
76 return this->output_stride_;
77 }
78 }
79
80 inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) {
81 assert(batch_size != 0);
82 this->batch_size_ = batch_size;
83 return *this;
84 }
85
86 inline size_t batch_size() const {
87 return this->batch_size_;
88 }
89
90 inline ChannelShuffleOperatorTester& iterations(size_t iterations) {
91 this->iterations_ = iterations;
92 return *this;
93 }
94
95 inline size_t iterations() const {
96 return this->iterations_;
97 }
98
99 void TestX8() const {
100 std::random_device random_device;
101 auto rng = std::mt19937(random_device());
102 auto u8rng = std::bind(std::uniform_int_distribution<uint8_t>(), rng);
103
104 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels());
105 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
106 for (size_t iteration = 0; iteration < iterations(); iteration++) {
107 std::generate(input.begin(), input.end(), std::ref(u8rng));
108 std::fill(output.begin(), output.end(), 0xA5);
109
110 // Create, setup, run, and destroy Channel Shuffle operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800111 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700112 xnn_operator_t channel_shuffle_op = nullptr;
113
114 ASSERT_EQ(xnn_status_success,
115 xnn_create_channel_shuffle_nc_x8(
116 groups(), group_channels(),
117 input_stride(), output_stride(),
118 0, &channel_shuffle_op));
119 ASSERT_NE(nullptr, channel_shuffle_op);
120
121 // Smart pointer to automatically delete channel_shuffle_op.
122 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
123
124 ASSERT_EQ(xnn_status_success,
125 xnn_setup_channel_shuffle_nc_x8(
126 channel_shuffle_op,
127 batch_size(),
128 input.data(), output.data(),
129 nullptr /* thread pool */));
130
131 ASSERT_EQ(xnn_status_success,
132 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
133
134 // Verify results.
135 for (size_t i = 0; i < batch_size(); i++) {
136 for (size_t g = 0; g < groups(); g++) {
137 for (size_t c = 0; c < group_channels(); c++) {
138 ASSERT_EQ(uint32_t(input[i * input_stride() + g * group_channels() + c]),
139 uint32_t(output[i * output_stride() + c * groups() + g]))
140 << "batch index " << i << ", group " << g << ", channel " << c;
141 }
142 }
143 }
144 }
145 }
146
147 void TestX32() const {
148 std::random_device random_device;
149 auto rng = std::mt19937(random_device());
150 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
151
152 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
153 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
154 for (size_t iteration = 0; iteration < iterations(); iteration++) {
155 std::generate(input.begin(), input.end(), std::ref(f32rng));
156 std::fill(output.begin(), output.end(), std::nanf(""));
157
158 // Create, setup, run, and destroy Channel Shuffle operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800159 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700160 xnn_operator_t channel_shuffle_op = nullptr;
161
162 ASSERT_EQ(xnn_status_success,
163 xnn_create_channel_shuffle_nc_x32(
164 groups(), group_channels(),
165 input_stride(), output_stride(),
166 0, &channel_shuffle_op));
167 ASSERT_NE(nullptr, channel_shuffle_op);
168
169 // Smart pointer to automatically delete channel_shuffle_op.
170 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
171
172 ASSERT_EQ(xnn_status_success,
173 xnn_setup_channel_shuffle_nc_x32(
174 channel_shuffle_op,
175 batch_size(),
176 input.data(), output.data(),
177 nullptr /* thread pool */));
178
179 ASSERT_EQ(xnn_status_success,
180 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
181
182 // Verify results.
183 for (size_t i = 0; i < batch_size(); i++) {
184 for (size_t g = 0; g < groups(); g++) {
185 for (size_t c = 0; c < group_channels(); c++) {
186 ASSERT_EQ(input[i * input_stride() + g * group_channels() + c],
187 output[i * output_stride() + c * groups() + g])
188 << "batch index " << i << ", group " << g << ", channel " << c;
189 }
190 }
191 }
192 }
193 }
194
195 private:
196 size_t groups_{1};
197 size_t group_channels_{1};
198 size_t batch_size_{1};
199 size_t input_stride_{0};
200 size_t output_stride_{0};
201 size_t iterations_{15};
202};