blob: 251f5a805732815a404452bf9106d940d6896782 [file] [log] [blame]
Kaizenbf8b01d2017-10-12 14:26:51 +01001/*
2 * Copyright (c) 2017 ARM Limited.
3 *
4 * SPDX-License-Identifier: MIT
5 *
6 * Permission is hereby granted, free of charge, to any person obtaining a copy
7 * of this software and associated documentation files (the "Software"), to
8 * deal in the Software without restriction, including without limitation the
9 * rights to use, copy, modify, merge, publish, distribute, sublicense, and/or
10 * sell copies of the Software, and to permit persons to whom the Software is
11 * furnished to do so, subject to the following conditions:
12 *
13 * The above copyright notice and this permission notice shall be included in all
14 * copies or substantial portions of the Software.
15 *
16 * THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
17 * IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
18 * FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
19 * AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
20 * LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
21 * OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
22 * SOFTWARE.
23 */
#include "tests/validation/FixedPoint.h"

#include "tests/Globals.h"
#include "tests/framework/Asserts.h"
#include "tests/framework/Macros.h"
#include "tests/framework/datasets/Datasets.h"
#include "tests/validation/Validation.h"

#include <algorithm>
#include <cctype>
#include <cmath>
#include <cstdint>
#include <cstring>
#include <string>
#include <vector>
31
32namespace arm_compute
33{
34namespace test
35{
36namespace validation
37{
38namespace
39{
40const auto FuncNamesDataset = framework::dataset::make("FunctionNames", { FixedPointOp::ADD,
41 FixedPointOp::SUB,
42 FixedPointOp::MUL,
43 FixedPointOp::EXP,
44 FixedPointOp::LOG,
45 FixedPointOp::INV_SQRT
46 });
47
48template <typename T>
49void load_array_from_numpy(const std::string &file, std::vector<unsigned long> &shape, std::vector<T> &data) // NOLINT
50{
51 try
52 {
53 npy::LoadArrayFromNumpy(file, shape, data);
54 }
55 catch(const std::runtime_error &e)
56 {
57 throw framework::FileNotFound("Could not load npy file: " + file + " (" + e.what() + ")");
58 }
59}
60} // namespace
61
62TEST_SUITE(UNIT)
63TEST_SUITE(FixedPoint)
64
65// *INDENT-OFF*
66// clang-format off
67DATA_TEST_CASE(FixedPointQS8Inputs, framework::DatasetMode::ALL, combine(
68 FuncNamesDataset,
69 framework::dataset::make("FractionalBits", 1, 7)),
70 func_name, frac_bits)
71// clang-format on
72// *INDENT-ON*
73{
74 std::vector<double> data;
75 std::vector<unsigned long> shape; //NOLINT
76
77 std::string func_name_lower = to_string(func_name);
78 std::transform(func_name_lower.begin(), func_name_lower.end(), func_name_lower.begin(), ::tolower);
79
80 const std::string inputs_file = library->path()
81 + "fixed_point/"
82 + func_name_lower
83 + "_Q8."
84 + support::cpp11::to_string(frac_bits)
85 + ".in.npy";
86
87 load_array_from_numpy(inputs_file, shape, data);
88
89 // Values stored as doubles so reinterpret as floats
90 const auto *float_val = reinterpret_cast<float *>(&data[0]);
91 const size_t num_elements = data.size() * sizeof(double) / sizeof(float);
92
93 for(unsigned int i = 0; i < num_elements; ++i)
94 {
95 // Convert to fixed point
96 fixed_point_arithmetic::fixed_point<int8_t> in_val(float_val[i], frac_bits);
97
98 // Check that the value didn't change
99 ARM_COMPUTE_EXPECT(static_cast<float>(in_val) == float_val[i], framework::LogLevel::ERRORS);
100 }
101}
102
103// The last input argument specifies the expected number of failures for a
104// given combination of (function name, number of fractional bits) as defined
105// by the first two arguments.
106
107// *INDENT-OFF*
108// clang-format off
109DATA_TEST_CASE(FixedPointQS8Outputs, framework::DatasetMode::ALL, zip(combine(
110 FuncNamesDataset,
111 framework::dataset::make("FractionalBits", 1, 7)),
112 framework::dataset::make("ExpectedFailures", { 0, 0, 0, 0, 0, 0,
113 0, 0, 0, 0, 0, 0,
114 0, 0, 0, 0, 0, 0,
115 7, 8, 13, 2, 0, 0,
116 0, 0, 0, 0, 0, 0,
117 0, 0, 0, 5, 33, 96 })),
118 func_name, frac_bits, expected_failures)
119// clang-format on
120// *INDENT-ON*
121{
122 std::vector<double> in_data;
123 std::vector<unsigned long> in_shape; //NOLINT
124
125 std::vector<double> out_data;
126 std::vector<unsigned long> out_shape; //NOLINT
127
128 std::string func_name_lower = to_string(func_name);
129 std::transform(func_name_lower.begin(), func_name_lower.end(), func_name_lower.begin(), ::tolower);
130
131 const std::string base_file_name = library->path()
132 + "fixed_point/"
133 + func_name_lower
134 + "_Q8."
135 + support::cpp11::to_string(frac_bits);
136
137 const std::string inputs_file = base_file_name + ".in.npy";
138 const std::string reference_file = base_file_name + ".out.npy";
139
140 load_array_from_numpy(inputs_file, in_shape, in_data);
141 load_array_from_numpy(reference_file, out_shape, out_data);
142
143 ARM_COMPUTE_EXPECT(in_shape.front() == out_shape.front(), framework::LogLevel::ERRORS);
144
145 const float step_size = std::pow(2.f, -frac_bits);
146 int64_t num_mismatches = 0;
147
148 // Values stored as doubles so reinterpret as floats
149 const auto *float_val = reinterpret_cast<float *>(&in_data[0]);
150 const auto *ref_val = reinterpret_cast<float *>(&out_data[0]);
151
152 const size_t num_elements = in_data.size() * sizeof(double) / sizeof(float);
153
154 for(unsigned int i = 0; i < num_elements; ++i)
155 {
156 fixed_point_arithmetic::fixed_point<int8_t> in_val(float_val[i], frac_bits);
157 fixed_point_arithmetic::fixed_point<int8_t> out_val(0.f, frac_bits);
158
159 float tolerance = 0.f;
160
161 if(func_name == FixedPointOp::ADD)
162 {
163 out_val = in_val + in_val;
164 }
165 else if(func_name == FixedPointOp::SUB)
166 {
167 out_val = in_val - in_val; //NOLINT
168 }
169 else if(func_name == FixedPointOp::MUL)
170 {
171 tolerance = 1.f * step_size;
172 out_val = in_val * in_val;
173 }
174 else if(func_name == FixedPointOp::EXP)
175 {
176 tolerance = 2.f * step_size;
177 out_val = fixed_point_arithmetic::exp(in_val);
178 }
179 else if(func_name == FixedPointOp::LOG)
180 {
181 tolerance = 4.f * step_size;
182 out_val = fixed_point_arithmetic::log(in_val);
183 }
184 else if(func_name == FixedPointOp::INV_SQRT)
185 {
186 tolerance = 5.f * step_size;
187 out_val = fixed_point_arithmetic::inv_sqrt(in_val);
188 }
189
190 if(std::abs(static_cast<float>(out_val) - ref_val[i]) > tolerance)
191 {
192 ARM_COMPUTE_TEST_INFO("input = " << in_val);
193 ARM_COMPUTE_TEST_INFO("output = " << out_val);
194 ARM_COMPUTE_TEST_INFO("reference = " << ref_val[i]);
195 ARM_COMPUTE_TEST_INFO("tolerance = " << tolerance);
196
197 ARM_COMPUTE_TEST_INFO((std::abs(static_cast<float>(out_val) - ref_val[i]) <= tolerance));
198
199 ++num_mismatches;
200 }
201 }
202
203 ARM_COMPUTE_EXPECT(num_mismatches == expected_failures, framework::LogLevel::ERRORS);
204}
205
206TEST_SUITE_END()
207TEST_SUITE_END()
208} // namespace validation
209} // namespace test
210} // namespace arm_compute