blob: d05d89603c7fb02cc5eedd16b6336a77ea48bc59 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdint>
#include <cstdlib>
#include <functional>
#include <limits>
#include <random>
#include <vector>

#include <xnnpack.h>
#include <xnnpack/AlignedAllocator.h>
#include <xnnpack/params-init.h>
#include <xnnpack/params.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070023
24
25class PReLUMicrokernelTester {
26 public:
27 enum class Variant {
28 Native,
29 Scalar,
30 };
31
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080032 inline PReLUMicrokernelTester& rows(size_t rows) {
33 assert(rows != 0);
34 this->rows_ = rows;
XNNPACK Teamb455b122019-09-27 18:10:33 -070035 return *this;
36 }
37
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080038 inline size_t rows() const {
39 return this->rows_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070040 }
41
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080042 inline PReLUMicrokernelTester& channels(size_t channels) {
43 assert(channels != 0);
44 this->channels_ = channels;
XNNPACK Teamb455b122019-09-27 18:10:33 -070045 return *this;
46 }
47
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080048 inline size_t channels() const {
49 return this->channels_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070050 }
51
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080052 inline PReLUMicrokernelTester& input_stride(size_t input_stride) {
53 assert(input_stride != 0);
54 this->input_stride_ = input_stride;
XNNPACK Teamb455b122019-09-27 18:10:33 -070055 return *this;
56 }
57
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080058 inline size_t input_stride() const {
59 if (this->input_stride_ == 0) {
60 return channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -070061 } else {
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080062 assert(this->input_stride_ >= channels());
63 return this->input_stride_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070064 }
65 }
66
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080067 inline PReLUMicrokernelTester& output_stride(size_t output_stride) {
68 assert(output_stride != 0);
69 this->output_stride_ = output_stride;
XNNPACK Teamb455b122019-09-27 18:10:33 -070070 return *this;
71 }
72
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080073 inline size_t output_stride() const {
74 if (this->output_stride_ == 0) {
75 return channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -070076 } else {
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080077 assert(this->output_stride_ >= channels());
78 return this->output_stride_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070079 }
80 }
81
82 inline PReLUMicrokernelTester& inplace(bool inplace) {
83 this->inplace_ = inplace;
84 return *this;
85 }
86
87 inline bool inplace() const {
88 return this->inplace_;
89 }
90
91 inline PReLUMicrokernelTester& qmin(uint8_t qmin) {
92 this->qmin_ = qmin;
93 return *this;
94 }
95
96 inline uint8_t qmin() const {
97 return this->qmin_;
98 }
99
100 inline PReLUMicrokernelTester& qmax(uint8_t qmax) {
101 this->qmax_ = qmax;
102 return *this;
103 }
104
105 inline uint8_t qmax() const {
106 return this->qmax_;
107 }
108
109 inline PReLUMicrokernelTester& iterations(size_t iterations) {
110 this->iterations_ = iterations;
111 return *this;
112 }
113
114 inline size_t iterations() const {
115 return this->iterations_;
116 }
117
118 void Test(xnn_f32_prelu_ukernel_function prelu, Variant variant = Variant::Native) const {
119 std::random_device random_device;
120 auto rng = std::mt19937(random_device());
121 auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
122 auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
123
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800124 std::vector<float> x(channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhan9594db02019-12-05 14:32:37 -0800125 std::vector<float, AlignedAllocator<float, 64>> w(channels() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800126 std::vector<float> y(channels() + (rows() - 1) * output_stride() + XNN_EXTRA_BYTES / sizeof(float));
127 std::vector<float> y_ref(channels());
XNNPACK Teamb455b122019-09-27 18:10:33 -0700128 for (size_t iteration = 0; iteration < iterations(); iteration++) {
129 std::generate(x.begin(), x.end(), std::ref(f32irng));
130 std::generate(w.begin(), w.end(), std::ref(f32wrng));
131 if (inplace()) {
132 std::generate(y.begin(), y.end(), std::ref(f32irng));
133 } else {
134 std::fill(y.begin(), y.end(), nanf(""));
135 }
136 const float* x_data = inplace() ? y.data() : x.data();
137
138 // Compute reference results, without clamping.
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800139 for (size_t i = 0; i < channels(); i++) {
Marat Dukhan629a33e2019-10-01 10:39:14 -0700140 y_ref[i] = std::signbit(x_data[i]) ? x_data[i] * w[i] : x_data[i];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700141 }
142
143 // Compute clamping parameters.
144 const float accumulated_min = *std::min_element(y_ref.cbegin(), y_ref.cend());
145 const float accumulated_max = *std::max_element(y_ref.cbegin(), y_ref.cend());
146 const float accumulated_range = accumulated_max - accumulated_min;
147 const float y_min = accumulated_range == 0.0f ?
148 -std::numeric_limits<float>::infinity() : accumulated_min + accumulated_range / 255.0f * float(qmin());
149 const float y_max = accumulated_range == 0.0f ?
150 +std::numeric_limits<float>::infinity() : accumulated_max - accumulated_range / 255.0f * float(255 - qmax());
151
152 // Prepare output parameters.
153 xnn_f32_output_params output_params = { };
154 switch (variant) {
155 case Variant::Native:
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700156 output_params = xnn_init_f32_output_params(y_min, y_max);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700157 break;
158 case Variant::Scalar:
Marat Dukhaneeaa7bd2019-10-25 17:31:25 -0700159 output_params = xnn_init_scalar_f32_output_params(y_min, y_max);
XNNPACK Teamb455b122019-09-27 18:10:33 -0700160 break;
161 }
162
163 // Clamp reference results.
164 for (float& value : y_ref) {
165 value = std::min(std::max(value, y_min), y_max);
166 }
167
168 // Call optimized micro-kernel.
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800169 prelu(rows(), channels() * sizeof(float),
170 x_data, input_stride() * sizeof(float),
XNNPACK Teamb455b122019-09-27 18:10:33 -0700171 w.data(),
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800172 y.data(), output_stride() * sizeof(float),
XNNPACK Teamb455b122019-09-27 18:10:33 -0700173 &output_params);
174
175 // Verify results.
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800176 for (size_t i = 0; i < channels(); i++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700177 ASSERT_LE(y[i], y_max)
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800178 << "at " << i << ", channels = " << channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700179 ASSERT_GE(y[i], y_min)
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800180 << "at " << i << ", channels = " << channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700181 ASSERT_NEAR(y[i], y_ref[i], 1.0e-6f * std::abs(y_ref[i]))
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800182 << "at " << i << ", channels = " << channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -0700183 }
184 }
185 }
186
187 private:
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800188 size_t rows_{1};
189 size_t channels_{1};
190 size_t input_stride_{0};
191 size_t output_stride_{0};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700192 bool inplace_{false};
193 uint8_t qmin_{0};
194 uint8_t qmax_{255};
195 size_t iterations_{15};
196};