blob: f405764fa84e845c22f51726fa4daef380876804 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001// Copyright 2019 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
8#include <gtest/gtest.h>
9
10#include <algorithm>
11#include <cassert>
Marat Dukhan629a33e2019-10-01 10:39:14 -070012#include <cmath>
XNNPACK Teamb455b122019-09-27 18:10:33 -070013#include <cstddef>
14#include <cstdlib>
15#include <functional>
16#include <random>
17#include <vector>
18
Frank Barchardb1966592020-05-12 13:47:06 -070019#include <fp16.h>
20
XNNPACK Teamb455b122019-09-27 18:10:33 -070021#include <xnnpack.h>
22#include <xnnpack/AlignedAllocator.h>
Frank Barcharde0601b52019-10-25 17:43:34 -070023#include <xnnpack/params.h>
XNNPACK Teamb455b122019-09-27 18:10:33 -070024
25
26class PReLUMicrokernelTester {
27 public:
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080028 inline PReLUMicrokernelTester& rows(size_t rows) {
29 assert(rows != 0);
30 this->rows_ = rows;
XNNPACK Teamb455b122019-09-27 18:10:33 -070031 return *this;
32 }
33
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080034 inline size_t rows() const {
35 return this->rows_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070036 }
37
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080038 inline PReLUMicrokernelTester& channels(size_t channels) {
39 assert(channels != 0);
40 this->channels_ = channels;
XNNPACK Teamb455b122019-09-27 18:10:33 -070041 return *this;
42 }
43
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080044 inline size_t channels() const {
45 return this->channels_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070046 }
47
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080048 inline PReLUMicrokernelTester& input_stride(size_t input_stride) {
49 assert(input_stride != 0);
50 this->input_stride_ = input_stride;
XNNPACK Teamb455b122019-09-27 18:10:33 -070051 return *this;
52 }
53
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080054 inline size_t input_stride() const {
55 if (this->input_stride_ == 0) {
56 return channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -070057 } else {
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080058 assert(this->input_stride_ >= channels());
59 return this->input_stride_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070060 }
61 }
62
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080063 inline PReLUMicrokernelTester& output_stride(size_t output_stride) {
64 assert(output_stride != 0);
65 this->output_stride_ = output_stride;
XNNPACK Teamb455b122019-09-27 18:10:33 -070066 return *this;
67 }
68
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080069 inline size_t output_stride() const {
70 if (this->output_stride_ == 0) {
71 return channels();
XNNPACK Teamb455b122019-09-27 18:10:33 -070072 } else {
Marat Dukhan69c3f2c2019-11-06 12:30:01 -080073 assert(this->output_stride_ >= channels());
74 return this->output_stride_;
XNNPACK Teamb455b122019-09-27 18:10:33 -070075 }
76 }
77
78 inline PReLUMicrokernelTester& inplace(bool inplace) {
79 this->inplace_ = inplace;
80 return *this;
81 }
82
83 inline bool inplace() const {
84 return this->inplace_;
85 }
86
XNNPACK Teamb455b122019-09-27 18:10:33 -070087 inline PReLUMicrokernelTester& iterations(size_t iterations) {
88 this->iterations_ = iterations;
89 return *this;
90 }
91
92 inline size_t iterations() const {
93 return this->iterations_;
94 }
95
Frank Barchardb1966592020-05-12 13:47:06 -070096 void Test(xnn_f16_prelu_ukernel_function prelu) const {
97 std::random_device random_device;
98 auto rng = std::mt19937(random_device());
99 auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
100 auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
101 auto f16irng = std::bind(fp16_ieee_from_fp32_value, f32irng);
102 auto f16wrng = std::bind(fp16_ieee_from_fp32_value, f32wrng);
103
104 std::vector<uint16_t> x(channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(uint16_t));
105 std::vector<uint16_t, AlignedAllocator<uint16_t, 64>> w(channels() + XNN_EXTRA_BYTES / sizeof(uint16_t));
106 std::vector<uint16_t> y(channels() + (rows() - 1) * output_stride() + XNN_EXTRA_BYTES / sizeof(uint16_t));
107 std::vector<float> y_ref(channels() * rows());
108 for (size_t iteration = 0; iteration < iterations(); iteration++) {
109 std::generate(x.begin(), x.end(), std::ref(f16irng));
110 std::generate(w.begin(), w.end(), std::ref(f16wrng));
111 if (inplace()) {
112 std::generate(y.begin(), y.end(), std::ref(f16irng));
113 } else {
114 std::fill(y.begin(), y.end(), UINT16_C(0x7E00) /* NaN */);
115 }
116 const uint16_t* x_data = inplace() ? y.data() : x.data();
117
118 // Compute reference results, without clamping.
119 for (size_t n = 0; n < rows(); n++) {
120 for (size_t c = 0; c < channels(); c++) {
121 const float x_value = fp16_ieee_to_fp32_value(x_data[n * input_stride() + c]);
Marat Dukhanf870d042020-06-10 09:29:52 -0700122 y_ref[n * channels() + c] = std::signbit(x_value) ?
123 fp16_ieee_to_fp32_value(fp16_ieee_from_fp32_value(x_value * fp16_ieee_to_fp32_value(w[c]))) : x_value;
Frank Barchardb1966592020-05-12 13:47:06 -0700124 }
125 }
126
127 // Call optimized micro-kernel.
128 prelu(rows(), channels() * sizeof(uint16_t),
129 x_data, input_stride() * sizeof(uint16_t),
130 w.data(),
131 y.data(), output_stride() * sizeof(uint16_t));
132
133 // Verify results.
134 for (size_t n = 0; n < rows(); n++) {
135 for (size_t c = 0; c < channels(); c++) {
Marat Dukhanf870d042020-06-10 09:29:52 -0700136 ASSERT_EQ(fp16_ieee_to_fp32_value(y[n * output_stride() + c]), y_ref[n * channels() + c])
Frank Barchardb1966592020-05-12 13:47:06 -0700137 << "at row " << n << " / " << rows()
138 << ", channel " << c << " / " << channels();
139 }
140 }
141 }
142 }
143
Marat Dukhanc8230a42020-02-24 00:00:35 -0800144 void Test(xnn_f32_prelu_ukernel_function prelu) const {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700145 std::random_device random_device;
146 auto rng = std::mt19937(random_device());
147 auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
148 auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
149
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800150 std::vector<float> x(channels() + (rows() - 1) * input_stride() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhan9594db02019-12-05 14:32:37 -0800151 std::vector<float, AlignedAllocator<float, 64>> w(channels() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800152 std::vector<float> y(channels() + (rows() - 1) * output_stride() + XNN_EXTRA_BYTES / sizeof(float));
Marat Dukhand9e92eb2020-03-11 04:10:57 -0700153 std::vector<float> y_ref(channels() * rows());
XNNPACK Teamb455b122019-09-27 18:10:33 -0700154 for (size_t iteration = 0; iteration < iterations(); iteration++) {
155 std::generate(x.begin(), x.end(), std::ref(f32irng));
156 std::generate(w.begin(), w.end(), std::ref(f32wrng));
157 if (inplace()) {
158 std::generate(y.begin(), y.end(), std::ref(f32irng));
159 } else {
160 std::fill(y.begin(), y.end(), nanf(""));
161 }
162 const float* x_data = inplace() ? y.data() : x.data();
163
164 // Compute reference results, without clamping.
Marat Dukhand9e92eb2020-03-11 04:10:57 -0700165 for (size_t n = 0; n < rows(); n++) {
166 for (size_t c = 0; c < channels(); c++) {
167 const float x_value = x_data[n * input_stride() + c];
168 y_ref[n * channels() + c] = std::signbit(x_value) ? x_value * w[c] : x_value;
169 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700170 }
171
XNNPACK Teamb455b122019-09-27 18:10:33 -0700172 // Call optimized micro-kernel.
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800173 prelu(rows(), channels() * sizeof(float),
174 x_data, input_stride() * sizeof(float),
XNNPACK Teamb455b122019-09-27 18:10:33 -0700175 w.data(),
Marat Dukhanc8230a42020-02-24 00:00:35 -0800176 y.data(), output_stride() * sizeof(float));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700177
178 // Verify results.
Marat Dukhand9e92eb2020-03-11 04:10:57 -0700179 for (size_t n = 0; n < rows(); n++) {
180 for (size_t c = 0; c < channels(); c++) {
Marat Dukhanf870d042020-06-10 09:29:52 -0700181 ASSERT_EQ(y[n * output_stride() + c], y_ref[n * channels() + c])
Marat Dukhand9e92eb2020-03-11 04:10:57 -0700182 << "at row " << n << " / " << rows()
183 << ", channel " << c << " / " << channels();
184 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700185 }
186 }
187 }
188
189 private:
Marat Dukhan69c3f2c2019-11-06 12:30:01 -0800190 size_t rows_{1};
191 size_t channels_{1};
192 size_t input_stride_{0};
193 size_t output_stride_{0};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700194 bool inplace_{false};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700195 size_t iterations_{15};
196};