blob: 850c1951d75b224b0726a9c507acc64925827611 [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import os
11import sys
12import yaml
13
14sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15from primes import next_prime
16import xngen
Marat Dukhan918a4a62019-10-27 19:49:49 -070017import xnncommon
XNNPACK Teamb455b122019-09-27 18:10:33 -070018
19
# Command-line interface of the generator: both the input spec and the output
# path are mandatory. `defines` defaults to an empty list.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec",
    metavar="FILE",
    required=True,
    help="Spec (YAML) file")
parser.add_argument(
    "-o", "--output",
    metavar="FILE",
    required=True,
    help="Output (C++ source) file")
parser.set_defaults(defines=[])
26
27
def split_ukernel_name(name):
  """Decodes the tile parameters and target encoded in a micro-kernel name.

  The part of `name` before the first "__" must end (after its last "_") with
  a tile spec of the form "<mr>x<nr>[c<kr>][s<sr>]". The part after the first
  "__" is the target name, decoded by xnncommon.parse_target_name.

  Args:
    name: micro-kernel name.

  Returns:
    Tuple (mr, nr, kr, sr, arch, isa); kr and sr are 1 when absent.
  """
  common_name, target_name = name.split("__", 1)
  tile_spec = common_name.split("_")[-1]

  # Everything after the first "s" is SR, so it is peeled off before "c<kr>".
  sr = 1
  if "s" in tile_spec:
    tile_spec, sr_digits = tile_spec.split("s", 1)
    sr = int(sr_digits)

  kr = 1
  if "c" in tile_spec:
    tile_spec, kr_digits = tile_spec.split("c", 1)
    kr = int(kr_digits)

  mr, nr = (int(extent) for extent in tile_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_name)
  return mr, nr, kr, sr, arch, isa
45
46
# xngen template for the unit tests of one GEMM-family micro-kernel.
#
# generate_test_cases() renders it with TEST_NAME, TEST_ARGS, UKERNEL_TYPE,
# DATATYPE, MR, NR, KR, SR, KBLOCK, ADJKBLOCK, IS_PIPELINED, ISA_CHECK and the
# next_prime() helper. "$if" lines are template directives. Their bodies are
# delimited by indentation (as xngen expects), so keep indentation consistent
# when editing.
#
# ADJKBLOCK equals 2 * KBLOCK for software-pipelined kernels and KBLOCK
# otherwise. The k_lt / k_gt tests are all expressed in terms of ADJKBLOCK.
# Computing the k_gt bounds from KBLOCK instead would make their loops empty
# for pipelined kernels.
#
# Every test that loops over `n` passes the loop variable to .n(), so N values
# larger than NR are actually exercised.
GEMM_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cn_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK})
      .a_stride(${next_prime(KBLOCK + 1)})
      .Test(${", ".join(TEST_ARGS)});
  }

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    for (uint32_t n = 1; n <= ${NR}; n++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(m)
        .n(n)
        .k(${KBLOCK})
        .iterations(1)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m <= ${MR}; m++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(m)
      .n(${NR})
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n <= ${NR}; n++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(n)
      .k(${KBLOCK})
      .iterations(1)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .Test(${", ".join(TEST_ARGS)});
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(${KBLOCK * 2})
        .a_stride(${next_prime(KBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }

  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(${KBLOCK * 2})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE != "IGEMM":
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(ADJKBLOCK + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    GemmMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .kr(${KR})
      .sr(${SR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if UKERNEL_TYPE.startswith("GEMM"):
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if UKERNEL_TYPE.startswith("GEMM"):
    TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .a_stride(${next_prime(KBLOCK * 10 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(n)
        .k(k)
        .cn_stride(${next_prime(NR + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if UKERNEL_TYPE != "IGEMM":
  TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .a_stride(${next_prime(KBLOCK * 5 + 1)})
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, n_div_${NR}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, small_kernel_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      for (uint32_t m = 1; m <= ${MR}; m++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          GemmMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .kr(${KR})
            .sr(${SR})
            .m(m)
            .n(n)
            .k(k)
            .ks(3)
            .iterations(1)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }
  }

  TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

  TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(n)
          .k(k)
          .ks(3)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, strided_cm_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    for (uint32_t m = 1; m <= ${MR}; m++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(m)
          .n(n)
          .k(k)
          .cm_stride(${next_prime(NR + 1)})
          .iterations(1)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

$if UKERNEL_TYPE.startswith("IGEMM"):
  TEST(${TEST_NAME}, a_offset) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .ks(3)
        .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, zero) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t mz = 0; mz < ${MR}; mz++) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        GemmMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .kr(${KR})
          .sr(${SR})
          .m(${MR})
          .n(${NR})
          .k(k)
          .ks(3)
          .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
          .zero_index(mz)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .qmin(128)
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .qmax(128)
    .Test(${", ".join(TEST_ARGS)});
}

TEST(${TEST_NAME}, strided_cm) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  GemmMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .kr(${KR})
    .sr(${SR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .cm_stride(${next_prime(NR + 1)})
    .Test(${", ".join(TEST_ARGS)});
}

$if DATATYPE == "q8":
  TEST(${TEST_NAME}, no_a_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_b_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  TEST(${TEST_NAME}, no_zero_point) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      GemmMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .kr(${KR})
        .sr(${SR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .a_zero_point(0)
        .b_zero_point(0)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
"""
758
759
def generate_test_cases(ukernel, mr, nr, kr, sr,
                        k_block, is_pipelined, isa):
  """Renders the unit-test cases for one GEMM-family micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. Its second and third
      "_"-separated fields are used as the datatype and the kernel type.
    mr: MR tile parameter of the micro-kernel.
    nr: NR tile parameter of the micro-kernel.
    kr: KR tile parameter of the micro-kernel.
    sr: SR tile parameter of the micro-kernel.
    k_block: number of K values processed per iteration of the micro-kernel's
      main loop.
    is_pipelined: whether the micro-kernel is software-pipelined. Pipelined
      kernels get extra test cases that exercise the prologue + epilogue of
      the pipelined loop separately from an iteration of it, and their K
      thresholds (ADJKBLOCK) are doubled.
    isa: instruction set the micro-kernel requires. The generated tests skip
      execution when the host processor does not support it.

  Returns:
    The generated C++ test code.
  """
  # A name with fewer than three "_" separators raises ValueError here.
  _prefix, datatype, ukernel_type, _rest = ukernel.split("_", 3)
  test_name = ukernel.split("_", 1)[1]

  # Kernels without an ISA requirement, and psimd kernels, are run through the
  # tester's Scalar variant.
  use_scalar_variant = not isa or isa == "psimd"
  if use_scalar_variant:
    test_args = [ukernel, "GemmMicrokernelTester::Variant::Scalar"]
  else:
    test_args = [ukernel]

  template_vars = dict(
      TEST_NAME=test_name.upper().replace("UKERNEL_", ""),
      TEST_ARGS=test_args,
      UKERNEL_TYPE=ukernel_type.upper(),
      DATATYPE=datatype,
      MR=mr,
      NR=nr,
      KR=kr,
      SR=sr,
      KBLOCK=k_block,
      ADJKBLOCK=k_block * (2 if is_pipelined else 1),
      IS_PIPELINED=is_pipelined,
      ISA_CHECK=xnncommon.generate_isa_check_macro(isa),
      next_prime=next_prime,
  )
  return xngen.preprocess(GEMM_TEST_CODE, template_vars)
802
803
def main(args):
  """Writes the GEMM micro-kernel unit-test source described by a YAML spec.

  Args:
    args: command-line arguments without the program name; parsed with the
      module-level `parser` (requires --spec and --output).

  Raises:
    ValueError: if the top-level YAML node of the spec is not a list.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # The header and every generated test case are separated by a blank line.
  sections = [header]
  for ukernel_spec in spec_yaml:
    ukernel_name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    is_pipelined = bool(ukernel_spec.get("pipelined", False))
    is_assembly = bool(ukernel_spec.get("assembly", False))
    mr, nr, kr, sr, arch, isa = split_ukernel_name(ukernel_name)

    # An explicit "arch" entry in the spec overrides the one from the name.
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(
        ukernel_name, mr, nr, kr, sr, k_block, is_pipelined, isa)
    sections.append(
        xnncommon.postprocess_test_case(test_case, arch, isa, is_assembly))
  tests = "\n\n".join(sections)

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write(tests)
853
854
if __name__ == "__main__":
  # argparse expects only the arguments, so drop argv[0] (the script path).
  main(args=sys.argv[1:])