blob: 92aea366bb4aebbb768cb2ad579e8b4e04dae450 [file] [log] [blame]
Marat Dukhan346a9e52019-11-15 09:06:30 -08001#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16import xngen
17import xnncommon
18
19
# Command-line interface of the generator. Both flags are mandatory: the YAML
# specification listing the micro-kernels, and the C++ test file to produce.
parser = argparse.ArgumentParser(
    description="Vector unary operation microkernel test generator")
parser.add_argument(
    "-s", "--spec",
    metavar="FILE",
    required=True,
    help="Specification (YAML) file")
parser.add_argument(
    "-o", "--output",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file")
# `defines` defaults to an empty list; main() below never reads it.
parser.set_defaults(defines=[])
27
28
def split_ukernel_name(name):
  """Decomposes a vector unary micro-kernel name into its components.

  The accepted form is
  xnn_<datatype>_v<op>_[fact_]ukernel__<target>_x<batch tile>, where
  <datatype> is one of s8/u8/f16/f32 and <op> is one of the abbreviations in
  the table below. The optional "fact_" marker is matched but ignored.

  Args:
    name: C name of the micro-kernel.

  Returns:
    Tuple (op_type, batch_tile, arch, isa): the human-readable operation name,
    the integer batch tile, and the architecture/ISA pair that
    xnncommon.parse_target_name derives from <target>.

  Raises:
    ValueError: if `name` does not follow the naming scheme above.
  """
  parts = re.match(r"^xnn_(s8|u8|f16|f32)_v(abs|clamp|elu|hswish|lrelu|neg|relu|rndd|rndne|rndu|rndz|sigmoid|sqr|sqrt)_(fact_)?ukernel__(.+)_x(\d+)$", name)
  if not parts:
    raise ValueError("Unexpected microkernel name: " + name)

  _, op_abbrev, _, target_name, tile_digits = parts.groups()

  # Abbreviation used in kernel names -> operation name used by the tests.
  op_names = {
      "abs": "Abs",
      "clamp": "Clamp",
      "elu": "ELU",
      "hswish": "HardSwish",
      "lrelu": "LeakyReLU",
      "neg": "Negate",
      "relu": "ReLU",
      "rndd": "RoundDown",
      "rndne": "RoundToNearestEven",
      "rndz": "RoundTowardsZero",
      "rndu": "RoundUp",
      "sigmoid": "Sigmoid",
      "sqr": "Square",
      "sqrt": "SquareRoot",
  }
  op_type = op_names[op_abbrev]
  batch_tile = int(tile_digits)

  arch, isa = xnncommon.parse_target_name(target_name=target_name)
  return op_type, batch_tile, arch, isa
53
54
55BINOP_TEST_TEMPLATE = """\
56TEST(${TEST_NAME}, batch_eq_${BATCH_TILE}) {
57 $if ISA_CHECK:
58 ${ISA_CHECK};
Marat Dukhan87ed45c2021-05-13 12:25:22 -070059 VUnaryMicrokernelTester()
Marat Dukhan346a9e52019-11-15 09:06:30 -080060 .batch_size(${BATCH_TILE})
61 .Test(${", ".join(TEST_ARGS)});
62}
63
64$if BATCH_TILE > 1:
65 TEST(${TEST_NAME}, batch_div_${BATCH_TILE}) {
66 $if ISA_CHECK:
67 ${ISA_CHECK};
68 for (size_t batch_size = ${BATCH_TILE*2}; batch_size < ${BATCH_TILE*10}; batch_size += ${BATCH_TILE}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -070069 VUnaryMicrokernelTester()
Marat Dukhan346a9e52019-11-15 09:06:30 -080070 .batch_size(batch_size)
71 .Test(${", ".join(TEST_ARGS)});
72 }
73 }
74
75 TEST(${TEST_NAME}, batch_lt_${BATCH_TILE}) {
76 $if ISA_CHECK:
77 ${ISA_CHECK};
78 for (size_t batch_size = 1; batch_size < ${BATCH_TILE}; batch_size++) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -070079 VUnaryMicrokernelTester()
Marat Dukhan346a9e52019-11-15 09:06:30 -080080 .batch_size(batch_size)
81 .Test(${", ".join(TEST_ARGS)});
82 }
83 }
84
85TEST(${TEST_NAME}, batch_gt_${BATCH_TILE}) {
86 $if ISA_CHECK:
87 ${ISA_CHECK};
88 for (size_t batch_size = ${BATCH_TILE+1}; batch_size < ${10 if BATCH_TILE == 1 else BATCH_TILE*2}; batch_size++) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -070089 VUnaryMicrokernelTester()
Marat Dukhan346a9e52019-11-15 09:06:30 -080090 .batch_size(batch_size)
91 .Test(${", ".join(TEST_ARGS)});
92 }
93}
94
95TEST(${TEST_NAME}, inplace) {
96 $if ISA_CHECK:
97 ${ISA_CHECK};
98 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -070099 VUnaryMicrokernelTester()
Marat Dukhan346a9e52019-11-15 09:06:30 -0800100 .batch_size(batch_size)
101 .inplace(true)
102 .Test(${", ".join(TEST_ARGS)});
103 }
104}
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700105
Marat Dukhan60d3f242021-05-13 11:59:02 -0700106$if OP_TYPE == "Clamp":
107 TEST(${TEST_NAME}, qmin) {
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700108 $if ISA_CHECK:
109 ${ISA_CHECK};
Marat Dukhan60d3f242021-05-13 11:59:02 -0700110 for (uint8_t qmin = 1; qmin < 255; qmin++) {
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700111 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -0700112 VUnaryMicrokernelTester()
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700113 .batch_size(batch_size)
Marat Dukhan60d3f242021-05-13 11:59:02 -0700114 .qmin(qmin)
115 .Test(${", ".join(TEST_ARGS)});
116 }
117 }
118 }
119
120 TEST(${TEST_NAME}, qmax) {
121 $if ISA_CHECK:
122 ${ISA_CHECK};
123 for (uint8_t qmax = 1; qmax < 255; qmax++) {
124 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -0700125 VUnaryMicrokernelTester()
Marat Dukhan60d3f242021-05-13 11:59:02 -0700126 .batch_size(batch_size)
127 .qmax(qmax)
Marat Dukhan8cc7efe2020-06-10 16:24:27 -0700128 .Test(${", ".join(TEST_ARGS)});
129 }
130 }
131 }
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800132
133$if OP_TYPE == "ELU":
134 TEST(${TEST_NAME}, prescale) {
135 $if ISA_CHECK:
136 ${ISA_CHECK};
137 for (float prescale : std::vector<float>({0.1f, 10.0f})) {
138 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -0700139 VUnaryMicrokernelTester()
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800140 .batch_size(batch_size)
141 .prescale(prescale)
142 .Test(${", ".join(TEST_ARGS)});
143 }
144 }
145 }
146
147 TEST(${TEST_NAME}, alpha) {
148 $if ISA_CHECK:
149 ${ISA_CHECK};
150 for (float alpha : std::vector<float>({0.3f, 3.0f})) {
151 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -0700152 VUnaryMicrokernelTester()
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800153 .batch_size(batch_size)
154 .alpha(alpha)
155 .Test(${", ".join(TEST_ARGS)});
156 }
157 }
158 }
159
160 TEST(${TEST_NAME}, beta) {
161 $if ISA_CHECK:
162 ${ISA_CHECK};
163 for (float beta : std::vector<float>({0.3f, 3.0f})) {
164 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -0700165 VUnaryMicrokernelTester()
Marat Dukhaned6baaf2020-12-01 15:07:08 -0800166 .batch_size(batch_size)
167 .beta(beta)
168 .Test(${", ".join(TEST_ARGS)});
169 }
170 }
171 }
Marat Dukhan60d3f242021-05-13 11:59:02 -0700172
173$if OP_TYPE == "LeakyReLU":
174 TEST(${TEST_NAME}, slope) {
175 $if ISA_CHECK:
176 ${ISA_CHECK};
177 for (float slope : std::vector<float>({-0.7f, 0.3f, 1.3f})) {
178 for (size_t batch_size = 1; batch_size <= ${BATCH_TILE*5}; batch_size += ${max(1, BATCH_TILE-1)}) {
Marat Dukhan87ed45c2021-05-13 12:25:22 -0700179 VUnaryMicrokernelTester()
Marat Dukhan60d3f242021-05-13 11:59:02 -0700180 .batch_size(batch_size)
181 .slope(slope)
182 .Test(${", ".join(TEST_ARGS)});
183 }
184 }
185 }
Marat Dukhan346a9e52019-11-15 09:06:30 -0800186"""
187
188
def generate_test_cases(ukernel, op_type, init_fn, batch_tile, isa):
  """Generates all test cases for a Vector Unary Operation micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    op_type: Operation type, as produced by split_ukernel_name (e.g. "ReLU",
      "RoundDown").
    init_fn: C name of the function to initialize microkernel parameters, or
      None (an empty string is treated the same way) when the micro-kernel has
      no such function.
    batch_tile: Number of batch elements processed per one iteration of the
      inner loop of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  _, test_name = ukernel.split("_", 1)
  _, datatype, _ = ukernel.split("_", 2)

  op_type_arg = "VUnaryMicrokernelTester::OpType::" + op_type
  is_rounding = op_type.startswith("Round")

  # Arguments passed to VUnaryMicrokernelTester::Test in every generated test.
  test_args = [ukernel]
  if init_fn or is_rounding:
    if is_rounding:
      # Rounding micro-kernels always tell the tester which rounding op to check.
      test_args.append(op_type_arg)
    # Test init_fn by truthiness, matching the enclosing condition: an empty
    # init name must not be emitted as an empty argument in the C++ call.
    if init_fn:
      test_args.append(init_fn)
  elif op_type not in ("Abs", "Negate", "Square", "SquareRoot"):
    # Without an init function the tester is told which operation to check;
    # kernels with no ISA requirement are also flagged as the scalar variant.
    test_args.append(op_type_arg)
    if not isa:
      test_args.append("VUnaryMicrokernelTester::Variant::Scalar")

  return xngen.preprocess(BINOP_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "BATCH_TILE": batch_tile,
      "OP_TYPE": op_type,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
  })
224
225
def main(args):
  """Generates the C++ unit tests for every micro-kernel listed in a spec.

  Parses the command line and loads the YAML spec, which must be a list of
  entries with a "name" key and optional "init" and "arch" keys. It renders one
  group of tests per micro-kernel and writes the result to the output file,
  but only when the result differs from the file's current contents.

  Args:
    args: Command-line arguments, excluding the program name.

  Raises:
    ValueError: if the spec is not a YAML list, or if a micro-kernel name does
      not follow the expected naming scheme.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  file_header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/vunary.h>
#include "vunary-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  pieces = [file_header]
  for entry in spec_yaml:
    kernel_name = entry["name"]
    init_fn = entry.get("init")
    op_type, batch_tile, arch, isa = split_ukernel_name(kernel_name)
    # An explicit "arch" in the spec takes precedence over the architecture
    # derived from the micro-kernel name.
    arch = entry.get("arch", arch)
    test_case = generate_test_cases(kernel_name, op_type, init_fn, batch_tile, isa)
    pieces.append("\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "".join(pieces)

  # Rewrite the output only when its contents change. An up-to-date file is
  # left untouched, presumably so build tools do not treat it as modified.
  existing = None
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      existing = output_file.read()
  if existing != tests:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)
Marat Dukhan346a9e52019-11-15 09:06:30 -0800273
274
if __name__ == "__main__":
  # Drop argv[0] (the script path); main() expects only the options.
  main(args=sys.argv[1:])