#!/usr/bin/env python
# Copyright 2020 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

import argparse
import codecs
import math
import os
import re
import sys
import yaml

sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
from primes import next_prime
import xngen
import xnncommon


def _make_argument_parser():
  """Builds the command-line parser for this generator.

  Options:
    -s/--spec:   required path of the YAML specification file.
    -o/--output: required path of the C++ source file to write.
  The parsed namespace also carries a `defines` attribute defaulting to an
  empty list.
  """
  arg_parser = argparse.ArgumentParser(
      description='GAvgPool microkernel test generator')
  arg_parser.add_argument(
      "-s", "--spec", metavar="FILE", required=True,
      help="Specification (YAML) file")
  arg_parser.add_argument(
      "-o", "--output", metavar="FILE", required=True,
      help='Output (C++ source) file')
  arg_parser.set_defaults(defines=list())
  return arg_parser


parser = _make_argument_parser()


def split_ukernel_name(name):
  """Extracts tiling parameters and the build target from a micro-kernel name.

  Accepted names look like
    xnn_<qs8|qu8|f16|f32>_[p]gavgpool[_minmax]_ukernel_[<A>p]<B>x__<target>_c<C>[_acc<N>]
  where the optional "<A>p" prefix marks a multi-pass kernel.

  Args:
    name: C name of the micro-kernel.

  Returns:
    Tuple (primary_tile, incremental_tile, channel_tile, arch, isa):
      - multi-pass kernels: primary_tile = A, incremental_tile = B;
      - single-pass kernels: primary_tile = B, incremental_tile = 0;
      - channel_tile = C;
      - arch and isa as returned by xnncommon.parse_target_name(<target>).

  Raises:
    ValueError: if `name` does not match the pattern above.
  """
  match = re.match(r"^xnn_(qs8|qu8|f16|f32)_[p]?gavgpool(_(minmax))?_ukernel_((\d+)p)?(\d+)x__(.+)_c(\d+)(_acc(\d+))?$", name)
  if not match:
    raise ValueError("Unexpected microkernel name: " + name)

  # Capture groups 4..8: the optional "<A>p" prefix, its digits A, the digits
  # B in front of "x", the target string, and the channel tile C.
  _, _, _, multipass, first_rows, second_rows, target, channels, _, _ = \
      match.groups()

  if multipass:
    primary_tile, incremental_tile = int(first_rows), int(second_rows)
  else:
    primary_tile, incremental_tile = int(second_rows), 0

  arch, isa = xnncommon.parse_target_name(target_name=target)
  return primary_tile, incremental_tile, int(channels), arch, isa


# xngen template for one micro-kernel's gtest cases. Lines starting with `$`
# are xngen directives, and `${...}` are expressions over the variables that
# generate_test_cases() passes in (TEST_NAME, TEST_ARGS, PRIMARY_TILE,
# INCREMENTAL_TILE, CHANNEL_TILE, ISA_CHECK, next_prime). The text nested under
# a `$if`/`$else` directive is indented two spaces deeper than the directive.
#
# INCREMENTAL_TILE == 0 selects single-pass kernels, tested with rows in
# [1, PRIMARY_TILE]. Otherwise the kernel is multi-pass and is tested with
# "2pass" row counts in (PRIMARY_TILE, PRIMARY_TILE+INCREMENTAL_TILE] and
# "multipass" row counts from PRIMARY_TILE+INCREMENTAL_TILE up to
# INCREMENTAL_TILE*5 in steps of INCREMENTAL_TILE.
AVGPOOL_TEST_TEMPLATE = """\
$if INCREMENTAL_TILE == 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .input_stride(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE})
      .channels(${CHANNEL_TILE})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_fulltile_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = 1; rows < ${PRIMARY_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
$else:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .input_stride(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .qmax(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GAvgPoolMicrokernelTester()
      .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
      .channels(${CHANNEL_TILE})
      .qmin(128)
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_2pass_subtile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .input_stride(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(rows)
        .channels(${CHANNEL_TILE})
        .input_stride(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .input_stride(${next_prime(CHANNEL_TILE*16+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .qmax(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        GAvgPoolMicrokernelTester()
          .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
          .channels(channels)
          .qmin(128)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_2pass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
          GAvgPoolMicrokernelTester()
            .rows(rows)
            .channels(channels)
            .input_stride(${next_prime(CHANNEL_TILE+1)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile_with_qmax) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_fulltile_with_qmin) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      GAvgPoolMicrokernelTester()
        .rows(${PRIMARY_TILE+INCREMENTAL_TILE})
        .channels(channels)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_2pass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+1}; rows < ${PRIMARY_TILE+INCREMENTAL_TILE}; rows++) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_fulltile_with_input_stride) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      for (size_t rows = ${PRIMARY_TILE+INCREMENTAL_TILE}; rows <= ${INCREMENTAL_TILE*5}; rows += ${INCREMENTAL_TILE}) {
        GAvgPoolMicrokernelTester()
          .rows(rows)
          .channels(channels)
          .input_stride(${next_prime(CHANNEL_TILE*2+11)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

"""


def generate_test_cases(ukernel, primary_tile, incremental_tile, channel_tile,
                        isa):
  """Generates all test cases for a GAVGPOOL micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    primary_tile: Number of rows (pixels) processed per one iteration of the
      primary outer loop of the micro-kernel.
    incremental_tile: Number of rows (pixels) processed per one iteration of
      the incremental outer loop of the micro-kernel. 0 selects the
      single-pass branch of AVGPOOL_TEST_TEMPLATE.
    channel_tile: Number of channels processed per one iteration of the inner
      loops of the micro-kernel.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
      When falsy, GAvgPoolMicrokernelTester::Variant::Scalar is passed as an
      extra argument to every Test() call.

  Returns:
    Code for the test cases, as produced by xngen.preprocess.
  """
  # "xnn_f32_gavgpool_..." -> "f32_gavgpool_..."; it is upper-cased and
  # stripped of "UKERNEL_" below to form the gtest suite name.
  _, test_name = ukernel.split("_", 1)
  # Only the datatype token is used. Unpacking into four fields (rather than
  # indexing) keeps the ValueError for names with fewer than three underscores.
  _, datatype, _, _ = ukernel.split("_", 3)
  test_args = [ukernel]
  if not isa:
    test_args.append("GAvgPoolMicrokernelTester::Variant::Scalar")
  return xngen.preprocess(AVGPOOL_TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "DATATYPE": datatype,
      "PRIMARY_TILE": primary_tile,
      "INCREMENTAL_TILE": incremental_tile,
      "CHANNEL_TILE": channel_tile,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })


def main(args):
  """Writes the GAvgPool micro-kernel unit-test source described by a spec.

  Args:
    args: command-line arguments without the program name; parsed by the
      module-level `parser` (-s/--spec and -o/--output).

  Raises:
    ValueError: if the spec is not a YAML list, or if a micro-kernel name does
      not match the pattern accepted by split_ukernel_name.
  """
  options = parser.parse_args(args)

  # Load the whole spec up front so the file is closed before generation.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2020 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gavgpool.h>
#include "gavgpool-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the header and every test case, then join once with blank-line
  # separators instead of growing one string inside the loop.
  parts = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    primary_tile, incremental_tile, channel_tile, arch, isa = \
        split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, primary_tile, incremental_tile,
                                    channel_tile, isa)
    parts.append(xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "\n\n".join(parts)

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write(tests)


if __name__ == "__main__":
  main(sys.argv[1:])