blob: 47090a622100c07f3105c8f8796b452917126b69 [file] [log] [blame]
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
19
20
21class PReLUOperatorTester {
22 public:
23 inline PReLUOperatorTester& batch_size(size_t batch_size) {
24 assert(batch_size != 0);
25 this->batch_size_ = batch_size;
26 return *this;
27 }
28
29 inline size_t batch_size() const {
30 return this->batch_size_;
31 }
32
33 inline PReLUOperatorTester& channels(size_t channels) {
34 assert(channels != 0);
35 this->channels_ = channels;
36 return *this;
37 }
38
39 inline size_t channels() const {
40 return this->channels_;
41 }
42
43 inline PReLUOperatorTester& x_stride(size_t x_stride) {
44 assert(x_stride != 0);
45 this->x_stride_ = x_stride;
46 return *this;
47 }
48
49 inline size_t x_stride() const {
50 if (this->x_stride_ == 0) {
51 return this->channels_;
52 } else {
53 assert(this->x_stride_ >= this->channels_);
54 return this->x_stride_;
55 }
56 }
57
58 inline PReLUOperatorTester& y_stride(size_t y_stride) {
59 assert(y_stride != 0);
60 this->y_stride_ = y_stride;
61 return *this;
62 }
63
64 inline size_t y_stride() const {
65 if (this->y_stride_ == 0) {
66 return this->channels_;
67 } else {
68 assert(this->y_stride_ >= this->channels_);
69 return this->y_stride_;
70 }
71 }
72
XNNPACK Teamb455b122019-09-27 18:10:33 -070073 inline PReLUOperatorTester& iterations(size_t iterations) {
74 this->iterations_ = iterations;
75 return *this;
76 }
77
78 inline size_t iterations() const {
79 return this->iterations_;
80 }
81
82 void TestF32() const {
83 std::random_device random_device;
84 auto rng = std::mt19937(random_device());
85 auto f32irng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
86 auto f32wrng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.75f), rng);
87
88 std::vector<float> x((batch_size() - 1) * x_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
89 std::vector<float> w(channels());
90 std::vector<float> y((batch_size() - 1) * y_stride() + channels() + XNN_EXTRA_BYTES / sizeof(float));
91 std::vector<float> y_ref(batch_size() * channels());
92 for (size_t iteration = 0; iteration < iterations(); iteration++) {
93 std::generate(x.begin(), x.end(), std::ref(f32irng));
94 std::generate(w.begin(), w.end(), std::ref(f32wrng));
95 std::fill(y.begin(), y.end(), nanf(""));
96
97 // Compute reference results, without clamping.
98 for (size_t i = 0; i < batch_size(); i++) {
99 for (size_t c = 0; c < channels(); c++) {
Marat Dukhan629a33e2019-10-01 10:39:14 -0700100 y_ref[i * channels() + c] = std::signbit(x[i * x_stride() + c]) ? x[i * x_stride() + c] * w[c] : x[i * x_stride() + c];
XNNPACK Teamb455b122019-09-27 18:10:33 -0700101 }
102 }
103
XNNPACK Teamb455b122019-09-27 18:10:33 -0700104 // Create, setup, run, and destroy PReLU operator.
Marat Dukhan04f03be2019-11-19 12:36:47 -0800105 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
XNNPACK Teamb455b122019-09-27 18:10:33 -0700106 xnn_operator_t prelu_op = nullptr;
107
108 ASSERT_EQ(xnn_status_success,
109 xnn_create_prelu_nc_f32(
110 channels(), x_stride(), y_stride(),
111 w.data(),
XNNPACK Teamb455b122019-09-27 18:10:33 -0700112 0, &prelu_op));
113 ASSERT_NE(nullptr, prelu_op);
114
115 // Smart pointer to automatically delete prelu_op.
116 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_prelu_op(prelu_op, xnn_delete_operator);
117
118 ASSERT_EQ(xnn_status_success,
119 xnn_setup_prelu_nc_f32(
120 prelu_op,
121 batch_size(),
122 x.data(), y.data(),
123 nullptr /* thread pool */));
124
125 ASSERT_EQ(xnn_status_success,
126 xnn_run_operator(prelu_op, nullptr /* thread pool */));
127
128 // Verify results.
129 for (size_t i = 0; i < batch_size(); i++) {
130 for (size_t c = 0; c < channels(); c++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700131 ASSERT_NEAR(y[i * y_stride() + c], y_ref[i * channels() + c], 1.0e-6f * std::abs(y_ref[i * channels() + c]))
132 << "i = " << i << ", c = " << c;
133 }
134 }
135 }
136 }
137
138 private:
139 size_t batch_size_{1};
140 size_t channels_{1};
141 size_t x_stride_{0};
142 size_t y_stride_{0};
XNNPACK Teamb455b122019-09-27 18:10:33 -0700143 size_t iterations_{15};
144};