// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>


25class ChannelShuffleOperatorTester {
26 public:
27 inline ChannelShuffleOperatorTester& groups(size_t groups) {
28 assert(groups != 0);
29 this->groups_ = groups;
30 return *this;
31 }
32
33 inline size_t groups() const {
34 return this->groups_;
35 }
36
37 inline ChannelShuffleOperatorTester& group_channels(size_t group_channels) {
38 assert(group_channels != 0);
39 this->group_channels_ = group_channels;
40 return *this;
41 }
42
43 inline size_t group_channels() const {
44 return this->group_channels_;
45 }
46
47 inline size_t channels() const {
48 return groups() * group_channels();
49 }
50
51 inline ChannelShuffleOperatorTester& input_stride(size_t input_stride) {
52 assert(input_stride != 0);
53 this->input_stride_ = input_stride;
54 return *this;
55 }
56
57 inline size_t input_stride() const {
58 if (this->input_stride_ == 0) {
59 return channels();
60 } else {
61 assert(this->input_stride_ >= channels());
62 return this->input_stride_;
63 }
64 }
65
66 inline ChannelShuffleOperatorTester& output_stride(size_t output_stride) {
67 assert(output_stride != 0);
68 this->output_stride_ = output_stride;
69 return *this;
70 }
71
72 inline size_t output_stride() const {
73 if (this->output_stride_ == 0) {
74 return channels();
75 } else {
76 assert(this->output_stride_ >= channels());
77 return this->output_stride_;
78 }
79 }
80
81 inline ChannelShuffleOperatorTester& batch_size(size_t batch_size) {
82 assert(batch_size != 0);
83 this->batch_size_ = batch_size;
84 return *this;
85 }
86
87 inline size_t batch_size() const {
88 return this->batch_size_;
89 }
90
91 inline ChannelShuffleOperatorTester& iterations(size_t iterations) {
92 this->iterations_ = iterations;
93 return *this;
94 }
95
96 inline size_t iterations() const {
97 return this->iterations_;
98 }
99
100 void TestX8() const {
101 std::random_device random_device;
102 auto rng = std::mt19937(random_device());
Marat Dukhan5ce30d92020-04-14 03:31:26 -0700103 auto u8rng = std::bind(std::uniform_int_distribution<uint32_t>(0, std::numeric_limits<uint8_t>::max()), rng);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700104
105 std::vector<uint8_t> input(XNN_EXTRA_BYTES / sizeof(uint8_t) + (batch_size() - 1) * input_stride() + channels());
106 std::vector<uint8_t> output((batch_size() - 1) * output_stride() + channels());
107 for (size_t iteration = 0; iteration < iterations(); iteration++) {
108 std::generate(input.begin(), input.end(), std::ref(u8rng));
109 std::fill(output.begin(), output.end(), 0xA5);
110
111 // Create, setup, run, and destroy Channel Shuffle operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800112 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700113 xnn_operator_t channel_shuffle_op = nullptr;
114
115 ASSERT_EQ(xnn_status_success,
116 xnn_create_channel_shuffle_nc_x8(
117 groups(), group_channels(),
118 input_stride(), output_stride(),
119 0, &channel_shuffle_op));
120 ASSERT_NE(nullptr, channel_shuffle_op);
121
122 // Smart pointer to automatically delete channel_shuffle_op.
123 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
124
125 ASSERT_EQ(xnn_status_success,
126 xnn_setup_channel_shuffle_nc_x8(
127 channel_shuffle_op,
128 batch_size(),
129 input.data(), output.data(),
130 nullptr /* thread pool */));
131
132 ASSERT_EQ(xnn_status_success,
133 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
134
135 // Verify results.
136 for (size_t i = 0; i < batch_size(); i++) {
137 for (size_t g = 0; g < groups(); g++) {
138 for (size_t c = 0; c < group_channels(); c++) {
139 ASSERT_EQ(uint32_t(input[i * input_stride() + g * group_channels() + c]),
140 uint32_t(output[i * output_stride() + c * groups() + g]))
141 << "batch index " << i << ", group " << g << ", channel " << c;
142 }
143 }
144 }
145 }
146 }
147
148 void TestX32() const {
149 std::random_device random_device;
150 auto rng = std::mt19937(random_device());
151 auto f32rng = std::bind(std::uniform_real_distribution<float>(), rng);
152
153 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
154 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
155 for (size_t iteration = 0; iteration < iterations(); iteration++) {
156 std::generate(input.begin(), input.end(), std::ref(f32rng));
157 std::fill(output.begin(), output.end(), std::nanf(""));
158
159 // Create, setup, run, and destroy Channel Shuffle operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800160 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700161 xnn_operator_t channel_shuffle_op = nullptr;
162
163 ASSERT_EQ(xnn_status_success,
164 xnn_create_channel_shuffle_nc_x32(
165 groups(), group_channels(),
166 input_stride(), output_stride(),
167 0, &channel_shuffle_op));
168 ASSERT_NE(nullptr, channel_shuffle_op);
169
170 // Smart pointer to automatically delete channel_shuffle_op.
171 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_channel_shuffle_op(channel_shuffle_op, xnn_delete_operator);
172
173 ASSERT_EQ(xnn_status_success,
174 xnn_setup_channel_shuffle_nc_x32(
175 channel_shuffle_op,
176 batch_size(),
177 input.data(), output.data(),
178 nullptr /* thread pool */));
179
180 ASSERT_EQ(xnn_status_success,
181 xnn_run_operator(channel_shuffle_op, nullptr /* thread pool */));
182
183 // Verify results.
184 for (size_t i = 0; i < batch_size(); i++) {
185 for (size_t g = 0; g < groups(); g++) {
186 for (size_t c = 0; c < group_channels(); c++) {
187 ASSERT_EQ(input[i * input_stride() + g * group_channels() + c],
188 output[i * output_stride() + c * groups() + g])
189 << "batch index " << i << ", group " << g << ", channel " << c;
190 }
191 }
192 }
193 }
194 }
195
196 private:
197 size_t groups_{1};
198 size_t group_channels_{1};
199 size_t batch_size_{1};
200 size_t input_stride_{0};
201 size_t output_stride_{0};
202 size_t iterations_{15};
203};