#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import argparse
8import codecs
9import math
10import os
11import re
12import sys
13import yaml
14
15sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
16from primes import next_prime
17import xngen
18import xnncommon
19
20
# Command-line interface: both the YAML specification and the C++ output path
# are mandatory. `defines` is not exposed as a flag; it is only given an empty
# list default on the parsed namespace.
parser = argparse.ArgumentParser(
  description='ArgMaxPool microkernel test generator',
)
parser.add_argument(
  "-s", "--spec",
  metavar="FILE",
  required=True,
  help="Specification (YAML) file",
)
parser.add_argument(
  "-o", "--output",
  metavar="FILE",
  required=True,
  help='Output (C++ source) file',
)
parser.set_defaults(defines=[])
28
29
def split_ukernel_name(name):
  """Decompose an ArgMaxPool micro-kernel name into its parameters.

  The name must look like
  ``xnn_<f16|f32>_argmaxpool_ukernel_[<P>p]<N>x__<target>_c<C>``.
  When the optional ``<P>p`` prefix is present, ``P`` is the primary tile and
  ``N`` the incremental tile; otherwise ``N`` is the primary tile and the
  incremental tile is 0.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A tuple ``(primary_tile, incremental_tile, channel_tile, arch, isa)``,
    where ``arch`` and ``isa`` are whatever ``xnncommon.parse_target_name``
    returns for the ``<target>`` part of the name.

  Raises:
    ValueError: if the name does not match the expected pattern.
  """
  parsed = re.match(r"^xnn_(f16|f32)_argmaxpool_ukernel_((\d+)p)?(\d+)x__(.+)_c(\d+)$", name)
  if parsed is None:
    raise ValueError("Unexpected microkernel name: " + name)

  pass_prefix, leading_tile, trailing_tile, target, channels = \
    parsed.group(2, 3, 4, 5, 6)

  if pass_prefix:
    # "<P>p<N>x": both a primary and an incremental tile are encoded.
    primary_tile, incremental_tile = int(leading_tile), int(trailing_tile)
  else:
    # "<N>x": only a primary tile is encoded.
    primary_tile, incremental_tile = int(trailing_tile), 0

  channel_tile = int(channels)

  arch, isa = xnncommon.parse_target_name(target_name=target)
  return primary_tile, incremental_tile, channel_tile, arch, isa
46
47
# xngen template for the gtest cases of one ArgMaxPool micro-kernel. It is
# expanded by generate_test_cases() with TEST_NAME, TEST_ARGS, DATATYPE,
# PRIMARY_TILE, INCREMENTAL_TILE, CHANNEL_TILE, ISA_CHECK and next_prime.
#
# Layout:
#   * INCREMENTAL_TILE == 0: "unipass" tests (pooling_elements <= PRIMARY_TILE).
#   * INCREMENTAL_TILE != 0: "twopass" tests
#     (PRIMARY_TILE < pooling_elements <= PRIMARY_TILE + INCREMENTAL_TILE) and
#     "multipass" tests (pooling_elements beyond that).
#   * "few_output_pixels" tests are emitted for every kernel.
#
# Every generated loop over `pooling_elements` passes the loop variable to
# .pooling_elements(); passing a constant there would make each iteration
# re-run the same full-tile configuration.
ARGMAXPOOL_TEST_TEMPLATE = """\
$if INCREMENTAL_TILE == 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .input_offset(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*8)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .input_offset(${next_prime(CHANNEL_TILE*2)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_unipass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = 2; pooling_elements < ${PRIMARY_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INCREMENTAL_TILE != 0:
  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    ArgMaxPoolMicrokernelTester()
      .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
      .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
      .channels(${CHANNEL_TILE})
      .input_offset(${next_prime(CHANNEL_TILE+1)})
      .Test(${", ".join(TEST_ARGS)});
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*5)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_fulltile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_fulltile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_fulltile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(${PRIMARY_TILE+INCREMENTAL_TILE})
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(channels)
        .input_offset(${next_prime(CHANNEL_TILE*2)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_twopass_subtile_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+1}; pooling_elements < ${PRIMARY_TILE+INCREMENTAL_TILE}; pooling_elements++) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, channels_eq_${CHANNEL_TILE}_multipass_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      ArgMaxPoolMicrokernelTester()
        .pooling_elements(pooling_elements)
        .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
        .channels(${CHANNEL_TILE})
        .input_offset(${next_prime(CHANNEL_TILE+1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if CHANNEL_TILE > 1:
    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_div_${CHANNEL_TILE}_multipass_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = ${CHANNEL_TILE*2}; channels < ${CHANNEL_TILE*8}; channels += ${CHANNEL_TILE}) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE*8)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

    TEST(${TEST_NAME}, channels_lt_${CHANNEL_TILE}_multipass_with_input_offset) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
        for (size_t channels = 1; channels < ${CHANNEL_TILE}; channels++) {
          ArgMaxPoolMicrokernelTester()
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .channels(channels)
            .input_offset(${next_prime(CHANNEL_TILE)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, channels_gt_${CHANNEL_TILE}_multipass_with_input_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t pooling_elements = ${PRIMARY_TILE+INCREMENTAL_TILE+1}; pooling_elements <= ${PRIMARY_TILE+INCREMENTAL_TILE*3}; pooling_elements += 3) {
      for (size_t channels = ${CHANNEL_TILE+1}; channels < ${10 if CHANNEL_TILE == 1 else CHANNEL_TILE*2}; channels++) {
        ArgMaxPoolMicrokernelTester()
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*2)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if INCREMENTAL_TILE == 0:
  $MIN_POOLING, MAX_POOLING = 2, PRIMARY_TILE
$else:
  $MIN_POOLING, MAX_POOLING = PRIMARY_TILE + 1, PRIMARY_TILE + INCREMENTAL_TILE

TEST(${TEST_NAME}, few_output_pixels) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_input_offset) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .input_offset(${next_prime(CHANNEL_TILE*5+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_output_stride) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        ArgMaxPoolMicrokernelTester()
          .output_pixels(output_pixels)
          .pooling_elements(pooling_elements)
          .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
          .channels(channels)
          .output_stride(${next_prime(CHANNEL_TILE*5+1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, few_output_pixels_with_step) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t output_pixels = 2; output_pixels <= 5; output_pixels++) {
    for (size_t pooling_elements = ${MIN_POOLING}; pooling_elements <= ${MAX_POOLING}; pooling_elements++) {
      for (size_t channels = 1; channels <= ${CHANNEL_TILE*5}; channels += ${max(1, CHANNEL_TILE-1)}) {
        for (size_t step = 2; step <= pooling_elements; step++) {
          ArgMaxPoolMicrokernelTester()
            .output_pixels(output_pixels)
            .pooling_elements(pooling_elements)
            .pooling_tile(${", ".join(map(str, filter(bool, [PRIMARY_TILE, INCREMENTAL_TILE])))})
            .step(step)
            .channels(channels)
            .output_stride(${next_prime(CHANNEL_TILE*5+1)})
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }
}
"""
661
662
def generate_test_cases(ukernel, primary_tile, incremental_tile, channel_tile,
                        isa):
  """Renders the unit tests for a single ArgMaxPool micro-kernel.

  The test name is the micro-kernel name without its first "_"-separated
  component, upper-cased and with "UKERNEL_" removed. When `isa` is falsy the
  tester is additionally passed the Scalar variant.

  Args:
    ukernel: C name of the micro-kernel function.
    primary_tile: Number of rows (pixels) processed per one iteration of the
                  primary outer loop of the micro-kernel.
    incremental_tile: Number of rows (pixels) processed per one iteration of
                      the incremental outer loop of the micro-kernel
                      (0 if the micro-kernel has no incremental loop).
    channel_tile: Number of channels processed per one iteration of the inner
                  loops of the micro-kernel.
    isa: Instruction set required to run the micro-kernel. It is forwarded to
         xnncommon.generate_isa_check_macro to produce the ISA_CHECK value
         used by the template.

  Returns:
    The expanded ARGMAXPOOL_TEST_TEMPLATE for this micro-kernel.

  Raises:
    ValueError: if `ukernel` has fewer than four "_"-separated components.
  """
  _, test_name = ukernel.split("_", 1)
  # Only the data type component is needed; the 4-way unpack also validates
  # that the name has enough components.
  _, datatype, _, _ = ukernel.split("_", 3)

  if isa:
    test_args = [ukernel]
  else:
    test_args = [ukernel, "ArgMaxPoolMicrokernelTester::Variant::Scalar"]

  template_variables = {
    "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
    "TEST_ARGS": test_args,
    "DATATYPE": datatype,
    "PRIMARY_TILE": primary_tile,
    "INCREMENTAL_TILE": incremental_tile,
    "CHANNEL_TILE": channel_tile,
    "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
    "next_prime": next_prime,
  }
  return xngen.preprocess(ARGMAXPOOL_TEST_TEMPLATE, template_variables)
696
697
def main(args):
  """Generates the C++ test source described by a YAML specification.

  Reads the list of micro-kernels from the `--spec` file, renders the test
  cases of every micro-kernel, and writes the result to the `--output` file.
  The output file is only rewritten when its content would change.

  Args:
    args: Command-line arguments, without the program name.

  Raises:
    ValueError: if the specification is not a list, or if a micro-kernel name
      is rejected by split_ukernel_name.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/argmaxpool.h>
#include "argmaxpool-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  sections = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    primary_tile, incremental_tile, channel_tile, arch, isa = \
      split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, primary_tile, incremental_tile,
                                    channel_tile, isa)
    sections.append("\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))
  tests = "".join(sections)

  # Skip the write when the file already holds exactly this content.
  previous = None
  if os.path.exists(options.output):
    with codecs.open(options.output, "r", encoding="utf-8") as output_file:
      previous = output_file.read()

  if previous != tests:
    with codecs.open(options.output, "w", encoding="utf-8") as output_file:
      output_file.write(tests)
Marat Dukhan329da642019-11-19 21:44:39 -0800746
747
# Script entry point: forward the command-line arguments (without the program
# name) to main().
if __name__ == "__main__":
  cli_arguments = sys.argv[1:]
  main(cli_arguments)