blob: a673b62de7921c41eda61585c02073cad31bb9c4 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
20
21
22class ELUOperatorTester {
23 public:
24 inline ELUOperatorTester& channels(size_t channels) {
25 assert(channels != 0);
26 this->channels_ = channels;
27 return *this;
28 }
29
30 inline size_t channels() const {
31 return this->channels_;
32 }
33
34 inline ELUOperatorTester& input_stride(size_t input_stride) {
35 assert(input_stride != 0);
36 this->input_stride_ = input_stride;
37 return *this;
38 }
39
40 inline size_t input_stride() const {
41 if (this->input_stride_ == 0) {
42 return this->channels_;
43 } else {
44 assert(this->input_stride_ >= this->channels_);
45 return this->input_stride_;
46 }
47 }
48
49 inline ELUOperatorTester& output_stride(size_t output_stride) {
50 assert(output_stride != 0);
51 this->output_stride_ = output_stride;
52 return *this;
53 }
54
55 inline size_t output_stride() const {
56 if (this->output_stride_ == 0) {
57 return this->channels_;
58 } else {
59 assert(this->output_stride_ >= this->channels_);
60 return this->output_stride_;
61 }
62 }
63
64 inline ELUOperatorTester& batch_size(size_t batch_size) {
65 assert(batch_size != 0);
66 this->batch_size_ = batch_size;
67 return *this;
68 }
69
70 inline size_t batch_size() const {
71 return this->batch_size_;
72 }
73
74 inline ELUOperatorTester& alpha(float alpha) {
75 assert(alpha > 0.0f);
76 assert(alpha < 1.0f);
77 this->alpha_ = alpha;
78 return *this;
79 }
80
81 inline float alpha() const {
82 return this->alpha_;
83 }
84
85 inline ELUOperatorTester& iterations(size_t iterations) {
86 this->iterations_ = iterations;
87 return *this;
88 }
89
90 inline size_t iterations() const {
91 return this->iterations_;
92 }
93
94 void TestF32() const {
95 std::random_device random_device;
96 auto rng = std::mt19937(random_device());
97 auto f32rng = std::bind(std::uniform_real_distribution<float>(-20.0f, 20.0f), std::ref(rng));
98
99 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) + (batch_size() - 1) * input_stride() + channels());
100 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
101 std::vector<double> output_ref(batch_size() * channels());
102 for (size_t iteration = 0; iteration < iterations(); iteration++) {
103 std::generate(input.begin(), input.end(), std::ref(f32rng));
104 std::fill(output.begin(), output.end(), std::nanf(""));
105
106 // Compute reference results.
107 for (size_t i = 0; i < batch_size(); i++) {
108 for (size_t c = 0; c < channels(); c++) {
109 const double x = double(input[i * input_stride() + c]);
110 output_ref[i * channels() + c] = std::signbit(x) ? std::expm1(x) * alpha() : x;
111 }
112 }
113
114 // Create, setup, run, and destroy ELU operator.
115 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
116 xnn_operator_t elu_op = nullptr;
117
118 ASSERT_EQ(xnn_status_success,
119 xnn_create_elu_nc_f32(
120 channels(), input_stride(), output_stride(),
121 alpha(),
122 0, &elu_op));
123 ASSERT_NE(nullptr, elu_op);
124
125 // Smart pointer to automatically delete elu_op.
126 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_elu_op(elu_op, xnn_delete_operator);
127
128 ASSERT_EQ(xnn_status_success,
129 xnn_setup_elu_nc_f32(
130 elu_op,
131 batch_size(),
132 input.data(), output.data(),
133 nullptr /* thread pool */));
134
135 ASSERT_EQ(xnn_status_success,
136 xnn_run_operator(elu_op, nullptr /* thread pool */));
137
138 // Verify results.
139 for (size_t i = 0; i < batch_size(); i++) {
140 for (size_t c = 0; c < channels(); c++) {
141 ASSERT_NEAR(output[i * output_stride() + c],
142 output_ref[i * channels() + c],
143 std::abs(output_ref[i * channels() + c]) * 1.0e-5)
144 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels()
145 << ", input " << input[i * input_stride() + c] << ", alpha " << alpha();
146 }
147 }
148 }
149 }
150
151 private:
152 size_t batch_size_{1};
153 size_t channels_{1};
154 size_t input_stride_{0};
155 size_t output_stride_{0};
156 float alpha_{0.5f};
157 size_t iterations_{15};
158};