#!/usr/bin/env python
# Copyright 2020 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16from primes import next_prime
17import xngen
18import xnncommon
19
20
# Command-line interface: both the YAML specification and the C++ output path
# are mandatory.
parser = argparse.ArgumentParser(
  description='GAvgPool microkernel test generator')
parser.add_argument("-s", "--spec", metavar="FILE", required=True,
                    help="Specification (YAML) file")
parser.add_argument("-o", "--output", metavar="FILE", required=True,
                    help='Output (C++ source) file')
# `defines` has no command-line flag; it always defaults to an empty list.
parser.set_defaults(defines=list())
27
28
def split_ukernel_name(name):
  """Splits a GAvgPool micro-kernel name into its components.

  The name must match
  xnn_<datatype>_[p]gavgpool[_minmax][_<requantization>]_ukernel_
  [<primary>p]<rows>x__<target>_c<channels>[_acc<n>].

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A tuple (requantization_type, primary_tile, incremental_tile,
    channel_tile, arch, isa). requantization_type is "fp32", "rndnu" or None
    when the name has no requantization suffix. incremental_tile is 0 when
    the name has no "<primary>p" prefix before the row count. arch and isa are
    whatever xnncommon.parse_target_name returns for the target part.

  Raises:
    ValueError: if the name does not match the expected pattern.
  """
  match = re.match(r"^xnn_(qs8|qu8|f16|f32)_[p]?gavgpool(_(minmax))?(_(fp32|rndnu))?_ukernel_((\d+)p)?(\d+)x__(.+)_c(\d+)(_acc(\d+))?$", name)
  if match is None:
    raise ValueError("Unexpected microkernel name: " + name)

  requantization_type = match.group(5)
  if match.group(6):
    # "<primary>p<incremental>x": the kernel has primary and incremental tiles.
    primary_tile = int(match.group(7))
    incremental_tile = int(match.group(8))
  else:
    # "<primary>x": only a primary tile is encoded in the name.
    primary_tile = int(match.group(8))
    incremental_tile = 0
  channel_tile = int(match.group(10))

  arch, isa = xnncommon.parse_target_name(target_name=match.group(9))
  return requantization_type, primary_tile, incremental_tile, channel_tile, arch, isa
Frank Barchard6c74cd12020-05-20 21:26:47 -070045
46
# Template for the generated C++ unit tests. The first branch
# (INCREMENTAL_TILE == 0) emits tests that use at most PRIMARY_TILE rows; the
# $else branch emits the "2pass"/"multipass" tests that use more rows.
# NOTE(review): the nesting below was reconstructed; confirm that each
# `$if CHANNEL_TILE > 1:` is meant to cover exactly the tests indented under it.
AVGPOOL_TEST_TEMPLATE = """\
$if INCREMENTAL_TILE == 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .input_stride(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
$else:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .input_stride(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_subtile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .input_stride(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .input_stride(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .input_stride(${next_prime(CHANNEL_TILE*16+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .input_stride(${next_prime(CHANNEL_TILE+1)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows < ${INCREMENTAL_TILE*5}; rows += ${PRIMARY_TILE+INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows < ${INCREMENTAL_TILE*5}; rows += ${PRIMARY_TILE+INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .input_stride(${next_prime(CHANNEL_TILE*2+11)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

"""
508
509
def generate_test_cases(ukernel, init_fn, requantization_type, primary_tile,
                        incremental_tile, channel_tile, isa):
  """Generates all tests cases for a GAVGPOOL micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    init_fn: C name of the function to initialize microkernel parameters.
    requantization_type: Requantization type (FP32/RNDNU), or a false value
                         when the micro-kernel has none.
    primary_tile: Number of rows (pixels) processed per one iteration of the
                  primary outer loop of the micro-kernel.
    incremental_tile: Number of rows (pixels) processed per one iteration of
                      the incremental outer loop of the micro-kernel.
    channel_tile: Number of channels processed per one iteration of the inner
                  loops of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.

  Raises:
    ValueError: if `ukernel` contains fewer than three underscores, so it
                cannot be unpacked into its name components.
  """
  # Everything after the first underscore becomes the test-suite name.
  _, test_name = ukernel.split("_", 1)
  # The second underscore-separated component is the datatype.
  _, datatype, _, _ = ukernel.split("_", 3)

  test_args = [ukernel, init_fn]
  if requantization_type:
    # Quantized kernels take a third argument: the reference requantization
    # function for this datatype and requantization scheme.
    test_args.append("xnn_%s_requantize_%s" % \
      (datatype.lower(), requantization_type.lower()))

  return xngen.preprocess(AVGPOOL_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "PRIMARY_TILE": primary_tile,
      "INCREMENTAL_TILE": incremental_tile,
      "CHANNEL_TILE": channel_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
    })
546
547
def main(args):
  """Generates the C++ test source for every micro-kernel in the spec file.

  Reads the YAML specification, renders the test cases for each listed
  micro-kernel, and writes the result to the output file. The output file is
  rewritten only when it does not exist yet or its content would change.

  Args:
    args: Command-line arguments, excluding the program name.

  Raises:
    ValueError: if the specification is not a list, or if a micro-kernel name
                in it does not match the expected pattern.
  """
  options = parser.parse_args(args)

  # The spec file is only needed while it is being parsed.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gavgpool.h>
#include "gavgpool-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the pieces and join once instead of growing a string in the loop.
  pieces = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    init_fn = ukernel_spec.get("init")
    requantization_type, primary_tile, incremental_tile, channel_tile, arch, \
      isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, init_fn, requantization_type,
                                    primary_tile, incremental_tile,
                                    channel_tile, isa)
    pieces.append("\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "".join(pieces)

  # Skip the write when the existing output already matches.
  txt_changed = True
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      txt_changed = output_file.read() != tests

  if txt_changed:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)
Frank Barchard6c74cd12020-05-20 21:26:47 -0700601
602
# Script entry point: forward the command-line arguments, minus the program
# name, to main().
if __name__ == "__main__":
  main(sys.argv[1:])