#!/usr/bin/env python
# Copyright 2021 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
import argparse
import codecs
import math
import os
import re
import sys

import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
import xngen
import xnncommon
18
# Command-line interface: both the YAML specification (-s/--spec) and the
# generated C++ output path (-o/--output) are mandatory.
parser = argparse.ArgumentParser(
    description="Matrix transpose microkernel test generator")
parser.add_argument(
    "-s",
    "--spec",
    metavar="FILE",
    required=True,
    help="Specification (YAML) file")
parser.add_argument(
    "-o",
    "--output",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file")
parser.set_defaults(defines=list())
34
35
def split_ukernel_name(name):
  """Splits a transpose microkernel name into its tile shape and target.

  The name must look like
  `xnn_x<digits>_transpose_ukernel__<height>x<width>_<target>`.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A `(tile_height, tile_width, arch, isa)` tuple. The tile dimensions are
    integers; `arch` and `isa` are whatever `xnncommon.parse_target_name`
    returns for the trailing target part of the name.

  Raises:
    ValueError: if `name` does not match the expected pattern.
  """
  # Group 1 (the `x<digits>` part) is matched but not used here.
  match = re.match(r"^xnn_(x\d+)_transpose_ukernel__(\d+)x(\d+)_(.+)$", name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)

  tile_height = int(match.group(2))
  tile_width = int(match.group(3))
  arch, isa = xnncommon.parse_target_name(target_name=match.group(4))
  return tile_height, tile_width, arch, isa
45
46
# Template passed to xngen.preprocess() by generate_test_cases(); it expects
# TEST_NAME, KERNEL, TILE_HEIGHT, TILE_WIDTH and ISA_CHECK to be defined.
# It emits twelve gtest cases per microkernel:
#   - the exact tile (bh = TILE_HEIGHT, bw = TILE_WIDTH);
#   - a full sweep of every height in 1..2*TILE_HEIGHT and every width in
#     1..2*TILE_WIDTH;
#   - single and double tile sizes combined with the in-between sizes
#     (TILE + 1 up to, but excluding, 2 * TILE) in each dimension;
#   - the exact tile with a doubled input stride, a doubled output stride,
#     and both doubled.
# NOTE(review): the body of each `$if ISA_CHECK:` must be indented deeper than
# the `$if` line itself; presumably xngen relies on that nesting -- confirm
# against xngen before changing the indentation.
TRANSPOSE_TEST_TEMPLATE = """\
TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH})
    .output_stride(${TILE_HEIGHT})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_1_${TILE_HEIGHT * 2}_bw_1_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = 1; i <= ${TILE_HEIGHT * 2}; ++i){
    for(size_t j = 1; j <= ${TILE_WIDTH * 2}; ++j){
      TransposeMicrokernelTester()
        .input_stride(j)
        .output_stride(i)
        .block_width(j)
        .block_height(i)
        .iterations(1)
        .Test(${KERNEL});
    }
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT})
    .block_width(${TILE_WIDTH * 2})
    .block_height(${TILE_HEIGHT})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_WIDTH + 1}; i < ${TILE_WIDTH * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(i)
      .output_stride(${TILE_HEIGHT})
      .block_width(i)
      .block_height(${TILE_HEIGHT})
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_WIDTH + 1}; i < ${TILE_WIDTH * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(i)
      .output_stride(${TILE_HEIGHT * 2})
      .block_width(i)
      .block_height(${TILE_HEIGHT * 2})
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH})
    .output_stride(${TILE_HEIGHT * 2})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT * 2})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH}){
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_HEIGHT + 1}; i < ${TILE_HEIGHT * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(${TILE_WIDTH})
      .output_stride(i)
      .block_width(${TILE_WIDTH})
      .block_height(i)
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH * 2}){
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_HEIGHT + 1}; i < ${TILE_HEIGHT * 2}; ++i){
    TransposeMicrokernelTester()
      .input_stride(${TILE_WIDTH * 2})
      .output_stride(i)
      .block_width(${TILE_WIDTH * 2})
      .block_height(i)
      .iterations(1)
      .Test(${KERNEL});
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT + 1}_${TILE_HEIGHT * 2}_bw_${TILE_WIDTH + 1}_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for(size_t i = ${TILE_HEIGHT + 1}; i < ${TILE_HEIGHT * 2}; ++i){
    for(size_t j = ${TILE_WIDTH + 1}; j < ${TILE_WIDTH * 2}; ++j){
      TransposeMicrokernelTester()
        .input_stride(j)
        .output_stride(i)
        .block_width(j)
        .block_height(i)
        .iterations(1)
        .Test(${KERNEL});
    }
  }
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}_is_${TILE_WIDTH * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}_os_${TILE_HEIGHT * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH})
    .output_stride(${TILE_HEIGHT * 2})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .iterations(1)
    .Test(${KERNEL});
}

TEST(${TEST_NAME}, bh_${TILE_HEIGHT}_bw_${TILE_WIDTH}_is_${TILE_WIDTH * 2}_os_${TILE_HEIGHT * 2}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  TransposeMicrokernelTester()
    .input_stride(${TILE_WIDTH * 2})
    .output_stride(${TILE_HEIGHT * 2})
    .block_width(${TILE_WIDTH})
    .block_height(${TILE_HEIGHT})
    .iterations(1)
    .Test(${KERNEL});
}
"""
208
209
def generate_test_cases(ukernel, tile_height, tile_width, isa):
  """Generates all test cases for a matrix transpose micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    tile_height: Number of vertical elements processed by the ukernel.
    tile_width: Number of horizontal elements processed by the ukernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  # Drop everything up to and including the first underscore; the remainder,
  # upper-cased and without "UKERNEL_", becomes the gtest suite name.
  _, test_name = ukernel.split("_", 1)
  return xngen.preprocess(
      TRANSPOSE_TEST_TEMPLATE, {
          "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
          "KERNEL": ukernel,
          "TILE_HEIGHT": tile_height,
          "TILE_WIDTH": tile_width,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      })
233
234
def main(args):
  """Generates the transpose microkernel test file described by a YAML spec.

  Reads the specification file, emits the test cases for every listed
  micro-kernel, and writes the result to the output path. The output file is
  rewritten only when its current content differs from the generated text.

  Args:
    args: command-line arguments (without the program name), parsed with the
      module-level `parser`.

  Raises:
    ValueError: if the top-level YAML value is not a list, or if a kernel name
      is rejected by `split_ukernel_name`.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2021 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/transpose.h>
#include "transpose-microkernel-tester.h"
""".format(
      specification=options.spec, generator=sys.argv[0])

  # Collect the pieces and join once instead of growing a string in the loop.
  pieces = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    tile_height, tile_width, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, tile_height, tile_width, isa)
    pieces.append("\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "".join(pieces)

  # Skip the write when the existing file already holds identical content.
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)
282
283
# Run the generator only when executed as a script, forwarding the
# command-line arguments without the program name.
if __name__ == "__main__":
  main(sys.argv[1:])