// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>

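// Test harness for the XNNPACK PReLU (parametric ReLU) operator in NC layout.
// Builder-style setters configure the test case; TestF32() runs the operator
// on random inputs and checks its output against a scalar reference.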
class PReLUOperatorTester {
 public:
  // Number of rows (N dimension); must be nonzero.
  inline PReLUOperatorTester& batch_size(size_t batch_size) {
    assert(batch_size != 0);
    this->batch_size_ = batch_size;
    return *this;
  }

  inline size_t batch_size() const {
    return this->batch_size_;
  }

  // Number of channels (C dimension); must be nonzero.
  inline PReLUOperatorTester& channels(size_t channels) {
    assert(channels != 0);
    this->channels_ = channels;
    return *this;
  }

  inline size_t channels() const {
    return this->channels_;
  }

  inline PReLUOperatorTester& x_stride(size_t x_stride) {
    assert(x_stride != 0);
    this->x_stride_ = x_stride;
    return *this;
  }

  // Input row stride; defaults to channels() when not set explicitly.
  inline size_t x_stride() const {
    if (this->x_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->x_stride_ >= this->channels_);
      return this->x_stride_;
    }
  }

  inline PReLUOperatorTester& y_stride(size_t y_stride) {
    assert(y_stride != 0);
    this->y_stride_ = y_stride;
    return *this;
  }

  // Output row stride; defaults to channels() when not set explicitly.
  inline size_t y_stride() const {
    if (this->y_stride_ == 0) {
      return this->channels_;
    } else {
      assert(this->y_stride_ >= this->channels_);
      return this->y_stride_;
    }
  }

  inline PReLUOperatorTester& qmin(uint8_t qmin) {
    this->qmin_ = qmin;
    return *this;
  }

  inline uint8_t qmin() const {
    return this->qmin_;
  }

  inline PReLUOperatorTester& qmax(uint8_t qmax) {
    this->qmax_ = qmax;
    return *this;
  }

  inline uint8_t qmax() const {
    return this->qmax_;
  }

  inline PReLUOperatorTester& iterations(size_t iterations) {
    this->iterations_ = iterations;
    return *this;
  }

  inline size_t iterations() const {
    return this->iterations_;
  }

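  // Runs the F32 PReLU operator over freshly randomized inputs for
  // iterations() rounds and compares each run against a scalar reference,
  // with a relative tolerance of 1.0e-6.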
  void TestF32() const {
    std::random_device random_device;
    auto rng = std::mt19937(random_device());
    auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
    auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);

    // Buffers are padded by XNN_EXTRA_BYTES because XNNPACK kernels may read
    // (but not use) slightly past the end of their input arrays.
    std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> w(channels());
    std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
    std::vector<float> y_ref(batch_size() * channels());
    for (size_t iteration = 0; iteration < iterations(); iteration++) {
      std::generate(x.begin(), x.end(), std::ref(f32irng));
      std::generate(w.begin(), w.end(), std::ref(f32wrng));
      std::fill(y.begin(), y.end(), nanf(""));

      // Compute reference results, without clamping.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c];
        }
      }

      // Compute clamping parameters: map [qmin, qmax] quantization-style
      // bounds onto the observed range of the reference results.
      const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend());
      const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend());
      const float accumulated_range = accumulated_max - accumulated_min;
      const float y_min = accumulated_range == 0.0f ?
        -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
      const float y_max = accumulated_range == 0.0f ?
        +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());

      // Clamp reference results.
      for (float& value : y_ref) {
        value = std::min(std::max(value, y_min), y_max);
      }

      // Create, setup, run, and destroy PReLU operator.
      ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
      xnn_operator_t prelu_op = nullptr;

      ASSERT_EQ(xnn_status_success,
        xnn_create_prelu_nc_f32(
          channels(), x_stride(), y_stride(),
          w.data(),
          y_min, y_max,
          0, &prelu_op));
      ASSERT_NE(nullptr, prelu_op);

      // Smart pointer to automatically delete prelu_op.
      std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);

      ASSERT_EQ(xnn_status_success,
        xnn_setup_prelu_nc_f32(
          prelu_op,
          batch_size(),
          x.data(), y.data(),
          nullptr /* thread pool */));

      ASSERT_EQ(xnn_status_success,
        xnn_run_operator(prelu_op, nullptr /* thread pool */));

      // Verify results.
      for (size_t i = 0; i < batch_size(); i++) {
        for (size_t c = 0; c < channels(); c++) {
          ASSERT_LE(y[i * y_stride() + c], y_max)
            << "i = " << i << ", c = " << c;
          ASSERT_GE(y[i * y_stride() + c], y_min)
            << "i = " << i << ", c = " << c;
          ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c]))
            << "i = " << i << ", c = " << c;
        }
      }
    }
  }

 private:
  size_t batch_size_{1};
  size_t channels_{1};
  size_t x_stride_{0};
  size_t y_stride_{0};
  uint8_t qmin_{0};
  uint8_t qmax_{255};
  size_t iterations_{15};
};
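
// A minimal usage sketch. The test suite/case names and parameter values
// below are illustrative assumptions, not taken from XNNPACK's test files:
//
//   TEST(PRELU_NC_F32, unit_batch) {
//     PReLUOperatorTester()
//       .batch_size(1)
//       .channels(17)
//       .iterations(3)
//       .TestF32();
//   }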