blob: d155c18943a66f5f3ece2bf52b21e93906e5e936 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
19
20
21class HardSwishOperatorTester {
22 public:
23 inline HardSwishOperatorTester& channels(size_t channels) {
24 assert(channels != 0);
25 this->channels_ = channels;
26 return *this;
27 }
28
29 inline size_t channels() const {
30 return this->channels_;
31 }
32
33 inline HardSwishOperatorTester& input_stride(size_t input_stride) {
34 assert(input_stride != 0);
35 this->input_stride_ = input_stride;
36 return *this;
37 }
38
39 inline size_t input_stride() const {
40 if (this->input_stride_ == 0) {
41 return this->channels_;
42 } else {
43 assert(this->input_stride_ >= this->channels_);
44 return this->input_stride_;
45 }
46 }
47
48 inline HardSwishOperatorTester& output_stride(size_t output_stride) {
49 assert(output_stride != 0);
50 this->output_stride_ = output_stride;
51 return *this;
52 }
53
54 inline size_t output_stride() const {
55 if (this->output_stride_ == 0) {
56 return this->channels_;
57 } else {
58 assert(this->output_stride_ >= this->channels_);
59 return this->output_stride_;
60 }
61 }
62
63 inline HardSwishOperatorTester& batch_size(size_t batch_size) {
64 assert(batch_size != 0);
65 this->batch_size_ = batch_size;
66 return *this;
67 }
68
69 inline size_t batch_size() const {
70 return this->batch_size_;
71 }
72
73 inline HardSwishOperatorTester& iterations(size_t iterations) {
74 this->iterations_ = iterations;
75 return *this;
76 }
77
78 inline size_t iterations() const {
79 return this->iterations_;
80 }
81
82 void TestF32() const {
83 std::random_device random_device;
84 auto rng = std::mt19937(random_device());
85 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
86
87 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
88 (batch_size() - 1) * input_stride() + channels());
89 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
90 std::vector<float> output_ref(batch_size() * channels());
91 for (size_t iteration = 0; iteration < iterations(); iteration++) {
92 std::generate(input.begin(), input.end(), std::ref(f32rng));
93 std::fill(output.begin(), output.end(), std::nanf(""));
94
95 // Compute reference results.
96 for (size_t i = 0; i < batch_size(); i++) {
97 for (size_t c = 0; c < channels(); c++) {
98 const float x = input[i * input_stride() + c];
99 const float y = x * std::min(std::max(x + 3.0f, 0.0f), 6.0f) / 6.0f;
100 output_ref[i * channels() + c] = y;
101 }
102 }
103
104 // Create, setup, run, and destroy HardSwish operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800105 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700106 xnn_operator_t hardswish_op = nullptr;
107
108 ASSERT_EQ(xnn_status_success,
109 xnn_create_hardswish_nc_f32(
110 channels(), input_stride(), output_stride(),
111 0, &hardswish_op));
112 ASSERT_NE(nullptr, hardswish_op);
113
114 // Smart pointer to automatically delete hardswish_op.
115 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_hardswish_op(hardswish_op, xnn_delete_operator);
116
117 ASSERT_EQ(xnn_status_success,
118 xnn_setup_hardswish_nc_f32(
119 hardswish_op,
120 batch_size(),
121 input.data(), output.data(),
122 nullptr /* thread pool */));
123
124 ASSERT_EQ(xnn_status_success,
125 xnn_run_operator(hardswish_op, nullptr /* thread pool */));
126
127 // Verify results.
128 for (size_t i = 0; i < batch_size(); i++) {
129 for (size_t c = 0; c < channels(); c++) {
130 ASSERT_NEAR(output_ref[i * channels() + c], output[i * output_stride() + c], std::abs(output[i * output_stride() + c]) * 1.0e-6f)
131 << "at position " << i << ", batch size = " << batch_size() << ", channels = " << channels();
132 }
133 }
134 }
135 }
136
137 private:
138 size_t batch_size_{1};
139 size_t channels_{1};
140 size_t input_stride_{0};
141 size_t output_stride_{0};
142 size_t iterations_{15};
143};