// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.

#pragma once

#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
19
20
21class NegateOperatorTester {
22 public:
23 inline NegateOperatorTester& channels(size_t channels) {
24 assert(channels != 0);
25 this->channels_ = channels;
26 return *this;
27 }
28
29 inline size_t channels() const {
30 return this->channels_;
31 }
32
33 inline NegateOperatorTester& input_stride(size_t input_stride) {
34 assert(input_stride != 0);
35 this->input_stride_ = input_stride;
36 return *this;
37 }
38
39 inline size_t input_stride() const {
40 if (this->input_stride_ == 0) {
41 return this->channels_;
42 } else {
43 assert(this->input_stride_ >= this->channels_);
44 return this->input_stride_;
45 }
46 }
47
48 inline NegateOperatorTester& output_stride(size_t output_stride) {
49 assert(output_stride != 0);
50 this->output_stride_ = output_stride;
51 return *this;
52 }
53
54 inline size_t output_stride() const {
55 if (this->output_stride_ == 0) {
56 return this->channels_;
57 } else {
58 assert(this->output_stride_ >= this->channels_);
59 return this->output_stride_;
60 }
61 }
62
63 inline NegateOperatorTester& batch_size(size_t batch_size) {
64 assert(batch_size != 0);
65 this->batch_size_ = batch_size;
66 return *this;
67 }
68
69 inline size_t batch_size() const {
70 return this->batch_size_;
71 }
72
73 inline NegateOperatorTester& iterations(size_t iterations) {
74 this->iterations_ = iterations;
75 return *this;
76 }
77
78 inline size_t iterations() const {
79 return this->iterations_;
80 }
81
82 void TestF32() const {
83 std::random_device random_device;
84 auto rng = std::mt19937(random_device());
85 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
86
87 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
88 (batch_size() - 1) * input_stride() + channels());
89 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
90 std::vector<float> output_ref(batch_size() * channels());
91 for (size_t iteration = 0; iteration < iterations(); iteration++) {
92 std::generate(input.begin(), input.end(), std::ref(f32rng));
93 std::fill(output.begin(), output.end(), std::nanf(""));
94
95 // Compute reference results.
96 for (size_t i = 0; i < batch_size(); i++) {
97 for (size_t c = 0; c < channels(); c++) {
98 output_ref[i * channels() + c] = -input[i * input_stride() + c];
99 }
100 }
101
102 // Create, setup, run, and destroy Negate operator.
103 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
104 xnn_operator_t negate_op = nullptr;
105
106 ASSERT_EQ(xnn_status_success,
107 xnn_create_negate_nc_f32(
108 channels(), input_stride(), output_stride(),
109 0, &negate_op));
110 ASSERT_NE(nullptr, negate_op);
111
112 // Smart pointer to automatically delete negate_op.
113 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_negate_op(negate_op, xnn_delete_operator);
114
115 ASSERT_EQ(xnn_status_success,
116 xnn_setup_negate_nc_f32(
117 negate_op,
118 batch_size(),
119 input.data(), output.data(),
120 nullptr /* thread pool */));
121
122 ASSERT_EQ(xnn_status_success,
123 xnn_run_operator(negate_op, nullptr /* thread pool */));
124
125 // Verify results.
126 for (size_t i = 0; i < batch_size(); i++) {
127 for (size_t c = 0; c < channels(); c++) {
128 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
129 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
130 }
131 }
132 }
133 }
134
135 private:
136 size_t batch_size_{1};
137 size_t channels_{1};
138 size_t input_stride_{0};
139 size_t output_stride_{0};
140 size_t iterations_{15};
141};