blob: 293ca2667228d279268d2d8a5b82439463550c8d [file] [log] [blame]
Marat Dukhan5020b962020-06-08 13:30:10 -07001// Copyright 2020 Google LLC
2//
3// This source code is licensed under the BSD-style license found in the
4// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
19
20
21class SquareOperatorTester {
22 public:
23 inline SquareOperatorTester& channels(size_t channels) {
24 assert(channels != 0);
25 this->channels_ = channels;
26 return *this;
27 }
28
29 inline size_t channels() const {
30 return this->channels_;
31 }
32
33 inline SquareOperatorTester& input_stride(size_t input_stride) {
34 assert(input_stride != 0);
35 this->input_stride_ = input_stride;
36 return *this;
37 }
38
39 inline size_t input_stride() const {
40 if (this->input_stride_ == 0) {
41 return this->channels_;
42 } else {
43 assert(this->input_stride_ >= this->channels_);
44 return this->input_stride_;
45 }
46 }
47
48 inline SquareOperatorTester& output_stride(size_t output_stride) {
49 assert(output_stride != 0);
50 this->output_stride_ = output_stride;
51 return *this;
52 }
53
54 inline size_t output_stride() const {
55 if (this->output_stride_ == 0) {
56 return this->channels_;
57 } else {
58 assert(this->output_stride_ >= this->channels_);
59 return this->output_stride_;
60 }
61 }
62
63 inline SquareOperatorTester& batch_size(size_t batch_size) {
64 assert(batch_size != 0);
65 this->batch_size_ = batch_size;
66 return *this;
67 }
68
69 inline size_t batch_size() const {
70 return this->batch_size_;
71 }
72
73 inline SquareOperatorTester& iterations(size_t iterations) {
74 this->iterations_ = iterations;
75 return *this;
76 }
77
78 inline size_t iterations() const {
79 return this->iterations_;
80 }
81
82 void TestF32() const {
83 std::random_device random_device;
84 auto rng = std::mt19937(random_device());
85 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
86
87 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
88 (batch_size() - 1) * input_stride() + channels());
89 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
90 std::vector<float> output_ref(batch_size() * channels());
91 for (size_t iteration = 0; iteration < iterations(); iteration++) {
92 std::generate(input.begin(), input.end(), std::ref(f32rng));
93 std::fill(output.begin(), output.end(), std::nanf(""));
94
95 // Compute reference results.
96 for (size_t i = 0; i < batch_size(); i++) {
97 for (size_t c = 0; c < channels(); c++) {
98 const float value = input[i * input_stride() + c];
99 output_ref[i * channels() + c] = value * value;
100 }
101 }
102
103 // Create, setup, run, and destroy Square operator.
104 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
105 xnn_operator_t square_op = nullptr;
106
107 ASSERT_EQ(xnn_status_success,
108 xnn_create_square_nc_f32(
109 channels(), input_stride(), output_stride(),
110 0, &square_op));
111 ASSERT_NE(nullptr, square_op);
112
113 // Smart pointer to automatically delete square_op.
114 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_square_op(square_op, xnn_delete_operator);
115
116 ASSERT_EQ(xnn_status_success,
117 xnn_setup_square_nc_f32(
118 square_op,
119 batch_size(),
120 input.data(), output.data(),
121 nullptr /* thread pool */));
122
123 ASSERT_EQ(xnn_status_success,
124 xnn_run_operator(square_op, nullptr /* thread pool */));
125
126 // Verify results.
127 for (size_t i = 0; i < batch_size(); i++) {
128 for (size_t c = 0; c < channels(); c++) {
129 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
130 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
131 }
132 }
133 }
134 }
135
136 private:
137 size_t batch_size_{1};
138 size_t channels_{1};
139 size_t input_stride_{0};
140 size_t output_stride_{0};
141 size_t iterations_{15};
142};