blob: fe62153d0f169bea77fd0cc0d97447a5cf394ea6 [file] [log] [blame]
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
5
6#pragma once
7
#include <gtest/gtest.h>

#include <algorithm>
#include <cassert>
#include <cmath>
#include <cstddef>
#include <cstdlib>
#include <functional>
#include <memory>
#include <random>
#include <vector>

#include <xnnpack.h>
20
21
22class TruncationOperatorTester {
23 public:
24 inline TruncationOperatorTester& channels(size_t channels) {
25 assert(channels != 0);
26 this->channels_ = channels;
27 return *this;
28 }
29
30 inline size_t channels() const {
31 return this->channels_;
32 }
33
34 inline TruncationOperatorTester& input_stride(size_t input_stride) {
35 assert(input_stride != 0);
36 this->input_stride_ = input_stride;
37 return *this;
38 }
39
40 inline size_t input_stride() const {
41 if (this->input_stride_ == 0) {
42 return this->channels_;
43 } else {
44 assert(this->input_stride_ >= this->channels_);
45 return this->input_stride_;
46 }
47 }
48
49 inline TruncationOperatorTester& output_stride(size_t output_stride) {
50 assert(output_stride != 0);
51 this->output_stride_ = output_stride;
52 return *this;
53 }
54
55 inline size_t output_stride() const {
56 if (this->output_stride_ == 0) {
57 return this->channels_;
58 } else {
59 assert(this->output_stride_ >= this->channels_);
60 return this->output_stride_;
61 }
62 }
63
64 inline TruncationOperatorTester& batch_size(size_t batch_size) {
65 assert(batch_size != 0);
66 this->batch_size_ = batch_size;
67 return *this;
68 }
69
70 inline size_t batch_size() const {
71 return this->batch_size_;
72 }
73
74 inline TruncationOperatorTester& iterations(size_t iterations) {
75 this->iterations_ = iterations;
76 return *this;
77 }
78
79 inline size_t iterations() const {
80 return this->iterations_;
81 }
82
83 void TestF32() const {
84 std::random_device random_device;
85 auto rng = std::mt19937(random_device());
86 auto f32rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f), rng);
87
88 std::vector<float> input(XNN_EXTRA_BYTES / sizeof(float) +
89 (batch_size() - 1) * input_stride() + channels());
90 std::vector<float> output((batch_size() - 1) * output_stride() + channels());
91 std::vector<float> output_ref(batch_size() * channels());
92 for (size_t iteration = 0; iteration < iterations(); iteration++) {
93 std::generate(input.begin(), input.end(), std::ref(f32rng));
94 std::fill(output.begin(), output.end(), std::nanf(""));
95
96 // Compute reference results.
97 for (size_t i = 0; i < batch_size(); i++) {
98 for (size_t c = 0; c < channels(); c++) {
99 output_ref[i * channels() + c] = std::trunc(input[i * input_stride() + c]);
100 }
101 }
102
103 // Create, setup, run, and destroy Truncation operator.
104 ASSERT_EQ(xnn_status_success, xnn_initialize(nullptr /* allocator */));
105 xnn_operator_t truncation_op = nullptr;
106
107 ASSERT_EQ(xnn_status_success,
108 xnn_create_truncation_nc_f32(
109 channels(), input_stride(), output_stride(),
110 0, &truncation_op));
111 ASSERT_NE(nullptr, truncation_op);
112
113 // Smart pointer to automatically delete truncation_op.
114 std::unique_ptr<xnn_operator, decltype(&xnn_delete_operator)> auto_truncation_op(truncation_op, xnn_delete_operator);
115
116 ASSERT_EQ(xnn_status_success,
117 xnn_setup_truncation_nc_f32(
118 truncation_op,
119 batch_size(),
120 input.data(), output.data(),
121 nullptr /* thread pool */));
122
123 ASSERT_EQ(xnn_status_success,
124 xnn_run_operator(truncation_op, nullptr /* thread pool */));
125
126 // Verify results.
127 for (size_t i = 0; i < batch_size(); i++) {
128 for (size_t c = 0; c < channels(); c++) {
129 ASSERT_EQ(output_ref[i * channels() + c], output[i * output_stride() + c])
130 << "at batch " << i << " / " << batch_size() << ", channel " << c << " / " << channels();
131 }
132 }
133 }
134 }
135
136 private:
137 size_t batch_size_{1};
138 size_t channels_{1};
139 size_t input_stride_{0};
140 size_t output_stride_{0};
141 size_t iterations_{15};
142};