blob: 5c39b208f5b84b89ea1bc900fed0b6a265eb512b [file] [log] [blame]
Kaizenbf8b01d2017-10-12 14:26:51 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "tests/validation/FixedPoint.h"

#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstring>
#include <string>
#include <vector>
31
32namespace arm_compute
33{
34namespace test
35{
36namespace validation
37{
38namespace
39{
40const auto FuncNamesDataset = framework::dataset::make("FunctionNames", { FixedPointOp::ADD,
41 FixedPointOp::SUB,
42 FixedPointOp::MUL,
43 FixedPointOp::EXP,
44 FixedPointOp::LOG,
45 FixedPointOp::INV_SQRT
46 });
47
48template <typename T>
49void load_array_from_numpy(const std::string &file, std::vector<unsigned long> &shape, std::vector<T> &data) // NOLINT
50{
51 try
52 {
53 npy::LoadArrayFromNumpy(file, shape, data);
54 }
55 catch(const std::runtime_error &e)
56 {
57 throw framework::FileNotFound("Could not load npy file: " + file + " (" + e.what() + ")");
58 }
59}
60} // namespace
61
62TEST_SUITE(UNIT)
63TEST_SUITE(FixedPoint)
64
65// *INDENT-OFF*
66// clang-format off
67DATA_TEST_CASE(FixedPointQS8Inputs, framework::DatasetMode::ALL, combine(
68 FuncNamesDataset,
69 framework::dataset::make("FractionalBits", 1, 7)),
70 func_name, frac_bits)
71// clang-format on
72// *INDENT-ON*
73{
74 std::vector<double> data;
75 std::vector<unsigned long> shape; //NOLINT
76
77 std::string func_name_lower = to_string(func_name);
78 std::transform(func_name_lower.begin(), func_name_lower.end(), func_name_lower.begin(), ::tolower);
79
80 const std::string inputs_file = library->path()
81 + "fixed_point/"
82 + func_name_lower
83 + "_Q8."
84 + support::cpp11::to_string(frac_bits)
85 + ".in.npy";
86
87 load_array_from_numpy(inputs_file, shape, data);
88
89 // Values stored as doubles so reinterpret as floats
90 const auto *float_val = reinterpret_cast<float *>(&data[0]);
91 const size_t num_elements = data.size() * sizeof(double) / sizeof(float);
92
93 for(unsigned int i = 0; i < num_elements; ++i)
94 {
95 // Convert to fixed point
96 fixed_point_arithmetic::fixed_point<int8_t> in_val(float_val[i], frac_bits);
97
98 // Check that the value didn't change
99 ARM_COMPUTE_EXPECT(static_cast<float>(in_val) == float_val[i], framework::LogLevel::ERRORS);
100 }
101}
102
Kaizenbf8b01d2017-10-12 14:26:51 +0100103// *INDENT-OFF*
104// clang-format off
105DATA_TEST_CASE(FixedPointQS8Outputs, framework::DatasetMode::ALL, zip(combine(
106 FuncNamesDataset,
107 framework::dataset::make("FractionalBits", 1, 7)),
108 framework::dataset::make("ExpectedFailures", { 0, 0, 0, 0, 0, 0,
109 0, 0, 0, 0, 0, 0,
110 0, 0, 0, 0, 0, 0,
111 7, 8, 13, 2, 0, 0,
112 0, 0, 0, 0, 0, 0,
113 0, 0, 0, 5, 33, 96 })),
114 func_name, frac_bits, expected_failures)
115// clang-format on
116// *INDENT-ON*
117{
118 std::vector<double> in_data;
119 std::vector<unsigned long> in_shape; //NOLINT
120
121 std::vector<double> out_data;
122 std::vector<unsigned long> out_shape; //NOLINT
123
124 std::string func_name_lower = to_string(func_name);
125 std::transform(func_name_lower.begin(), func_name_lower.end(), func_name_lower.begin(), ::tolower);
126
127 const std::string base_file_name = library->path()
128 + "fixed_point/"
129 + func_name_lower
130 + "_Q8."
131 + support::cpp11::to_string(frac_bits);
132
133 const std::string inputs_file = base_file_name + ".in.npy";
134 const std::string reference_file = base_file_name + ".out.npy";
135
136 load_array_from_numpy(inputs_file, in_shape, in_data);
137 load_array_from_numpy(reference_file, out_shape, out_data);
138
139 ARM_COMPUTE_EXPECT(in_shape.front() == out_shape.front(), framework::LogLevel::ERRORS);
140
141 const float step_size = std::pow(2.f, -frac_bits);
142 int64_t num_mismatches = 0;
143
144 // Values stored as doubles so reinterpret as floats
145 const auto *float_val = reinterpret_cast<float *>(&in_data[0]);
146 const auto *ref_val = reinterpret_cast<float *>(&out_data[0]);
147
148 const size_t num_elements = in_data.size() * sizeof(double) / sizeof(float);
149
150 for(unsigned int i = 0; i < num_elements; ++i)
151 {
152 fixed_point_arithmetic::fixed_point<int8_t> in_val(float_val[i], frac_bits);
153 fixed_point_arithmetic::fixed_point<int8_t> out_val(0.f, frac_bits);
154
155 float tolerance = 0.f;
156
157 if(func_name == FixedPointOp::ADD)
158 {
159 out_val = in_val + in_val;
160 }
161 else if(func_name == FixedPointOp::SUB)
162 {
163 out_val = in_val - in_val; //NOLINT
164 }
165 else if(func_name == FixedPointOp::MUL)
166 {
167 tolerance = 1.f * step_size;
168 out_val = in_val * in_val;
169 }
170 else if(func_name == FixedPointOp::EXP)
171 {
172 tolerance = 2.f * step_size;
173 out_val = fixed_point_arithmetic::exp(in_val);
174 }
175 else if(func_name == FixedPointOp::LOG)
176 {
177 tolerance = 4.f * step_size;
178 out_val = fixed_point_arithmetic::log(in_val);
179 }
180 else if(func_name == FixedPointOp::INV_SQRT)
181 {
182 tolerance = 5.f * step_size;
183 out_val = fixed_point_arithmetic::inv_sqrt(in_val);
184 }
185
186 if(std::abs(static_cast<float>(out_val) - ref_val[i]) > tolerance)
187 {
188 ARM_COMPUTE_TEST_INFO("input = " << in_val);
189 ARM_COMPUTE_TEST_INFO("output = " << out_val);
190 ARM_COMPUTE_TEST_INFO("reference = " << ref_val[i]);
191 ARM_COMPUTE_TEST_INFO("tolerance = " << tolerance);
192
193 ARM_COMPUTE_TEST_INFO((std::abs(static_cast<float>(out_val) - ref_val[i]) <= tolerance));
194
195 ++num_mismatches;
196 }
197 }
198
199 ARM_COMPUTE_EXPECT(num_mismatches == expected_failures, framework::LogLevel::ERRORS);
200}
201
202TEST_SUITE_END()
203TEST_SUITE_END()
204} // namespace validation
205} // namespace test
206} // namespace arm_compute