// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <fp16.h>

#include <xnnpack.h>

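// Fluent tester for XNNPACK's HardSwish NC (channels-last) operators: each
// Test*() call generates random inputs, computes a floating-point reference
// result, runs the operator under test, and compares the two.
//
// A minimal usage sketch (the parameter values here are illustrative only):
//
//   HardSwishOperatorTester()
//     .batch_size(3)
//     .channels(19)
//     .iterations(5)
//     .TestF32();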
class HardSwishOperatorTester {
 public:
  inline HardSwishOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline HardSwishOperatorTester& input_stride(size_t input_stride) {
    assert(input_stride != 0);
    this->input_stride_ = input_stride;
    return *this;
  }

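  // Effective input stride in elements: defaults to channels() unless
  // explicitly overridden via the setter above.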
  inline size_t input_stride() const {
    if (this->input_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->input_stride_ >= this->channels_);
      return this->input_stride_;
    }
  }

  inline HardSwishOperatorTester& output_stride(size_t output_stride) {
    assert(output_stride != 0);
    this->output_stride_ = output_stride;
    return *this;
  }

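  // Effective output stride in elements: defaults to channels() unless
  // explicitly overridden via the setter above.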
  inline size_t output_stride() const {
    if (this->output_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->output_stride_ >= this->channels_);
      return this->output_stride_;
    }
  }

  inline HardSwishOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  inline HardSwishOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

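  // Tests the f16 (half-precision) operator against an f32 reference,
  // skipping the test on hardware without native f16 support.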
  void TestF16() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), rng);
    auto f16rng = std::bind(fp16_ieee_from_fp32_value, f32rng);

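    // XNNPACK kernels are allowed to over-read up to XNN_EXTRA_BYTES past the
    // last input element, so pad the input allocation accordingly.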
    std::vector<uint16_t> input(XNN_EXTRA_BYTES / sizeof(uint16_t) +
      (batch_size() - 1) * input_stride() + channels());
    std::vector<uint16_t> output((batch_size() - 1) * output_stride() + channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f16rng));
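      // Pre-fill the output with NaN canaries so that any element the
      // operator fails to write is caught by the comparisons below.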
      std::fill(output.begin(), output.end(), UINT16_C(0x7E00) /* NaN */);

      // Compute reference results.
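      // HardSwish is defined as y = x * min(max(x + 3, 0), 6) / 6;
      // the reference is evaluated in fp32 even for the f16 path.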
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x = fp16_ieee_to_fp32_value(input[i * input_stride() + c]);
          const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
          output_ref[i * channels() + c] = y;
        }
      }

      // Create, setup, run, and destroy HardSwish operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t hardswish_op = nullptr;
      xnn_status status = xnn_create_hardswish_nc_f16(
        channels(), input_stride(), output_stride(),
        0, &hardswish_op);
      if (status == xnn_status_unsupported_hardware) {
        GTEST_SKIP();
      }
      ASSERT_EQ(xnn_status_success, status);
      ASSERT_NE(nullptr, hardswish_op);

      // Smart pointer to automatically delete hardswish_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_hardswish_nc_f16(
          hardswish_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(hardswish_op, nullptr /* thread pool */));

      // Verify results.
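      // Mixed tolerance: absolute 1e-3 for small outputs, relative 1% for
      // larger ones, to absorb half-precision rounding error.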
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(fp16_ieee_to_fp32_value(output[i * output_stride() + c]), output_ref[i * channels() + c], std::max(1.0e-3f, std::abs(output_ref[i * channels() + c]) * 1.0e-2f))
            << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
        }
      }
    }
  }

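  // Tests the f32 operator against a scalar reference computed inline.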
  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32rng = std::bind(std::uniform_real_distribution<float>(-4.0f, 4.0f), rng);

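    // As in TestF16(), pad the input by XNN_EXTRA_BYTES to allow for
    // XNNPACK's permitted over-read past the last element.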
    std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
      (batch_size() - 1) * input_stride() + channels());
    std::vector<float> output((batch_size() - 1) * output_stride() + channels());
    std::vector<float> output_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(input.begin(), input.end(), std::ref(f32rng));
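      // NaN canaries again: any output element left unwritten by the
      // operator will fail the comparisons below.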
      std::fill(output.begin(), output.end(), std::nanf(""));

      // Compute reference results.
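      // Same HardSwish definition as above: y = x * min(max(x + 3, 0), 6) / 6.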
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          const float x = input[i * input_stride() + c];
          const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
          output_ref[i * channels() + c] = y;
        }
      }

      // Create, setup, run, and destroy HardSwish operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t hardswish_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_hardswish_nc_f32(
          channels(), input_stride(), output_stride(),
          0, &hardswish_op));
      ASSERT_NE(nullptr, hardswish_op);

      // Smart pointer to automatically delete hardswish_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_hardswish_nc_f32(
          hardswish_op,
          batch_size(),
          input.data(), output.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(hardswish_op, nullptr /* thread pool */));

      // Verify results.
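      // Tighter tolerance than the f16 path: absolute 1e-7, relative 1e-6,
      // reflecting single-precision round-off.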
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_NEAR(output_ref[i * channels() + c], output[i * output_stride() + c], std::max(1.0e-7f, std::abs(output[i * output_stride() + c]) * 1.0e-6f))
            << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t input_stride_{0};
  size_t output_stride_{0};
  size_t iterations_{15};
};