blob: bc400013230c94cf4e0396b22883596a732321ed [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
Zhi An Ng4c1fd6f2022-01-10 19:35:06 -080010import collections
XNNPACK Teamb455b122019-09-27 18:10:33 -070011import os
12import sys
13import yaml
Zhi An Ng4c1fd6f2022-01-10 19:35:06 -080014import zlib
XNNPACK Teamb455b122019-09-27 18:10:33 -070015
16sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
17from primes import next_prime
18import xngen
Marat Dukhan918a4a62019-10-27 19:49:49 -070019import xnncommon
XNNPACK Teamb455b122019-09-27 18:10:33 -070020
# Command-line interface of the generator: one YAML spec file in, one or more
# generated C++ test source files out.
parser = argparse.ArgumentParser(description="XNNPACK generator")
parser.add_argument(
    "-s", "--spec", metavar="FILE", required=True, help="Spec (YAML) file")
# "-o" may be given several times; each occurrence is appended to
# options.output, so the option always parses to a list of paths.
parser.add_argument(
    "-o", "--output",
    action="append", metavar="FILE", required=True,
    help="Output (C++ source) file(s)")
parser.set_defaults(defines=[])
32
33
def split_ukernel_name(name):
  """Decomposes a micro-kernel name into its tile and target parameters.

  The name is split once on "__" into a common part and a target part. The
  last "_"-separated field of the common part is the tile spec, of the form
  "<MR>x<NR>", optionally followed by "c<KR>" and then "s<SR>".

  Args:
    name: full name of the micro-kernel; must contain "__".

  Returns:
    Tuple (mr, nr, kr, sr, xw, requantization, arch, isa) where kr and sr
    default to 1 when their suffix is absent, xw is True when the common part
    contains "gemm_xw_", requantization is the third-from-last field of the
    common part if it is "fp32" or "rndnu" (otherwise None), and arch/isa are
    whatever xnncommon.parse_target_name returns for the target part.
  """
  common_name, target_name = name.split("__", 1)
  fields = common_name.split("_")
  xw = "gemm_xw_" in common_name

  # Peel the optional suffixes off the tile spec from right to left:
  # first "s<SR>", then "c<KR>"; what remains is "<MR>x<NR>".
  tile, s_marker, s_value = fields[-1].partition("s")
  sr = int(s_value) if s_marker else 1
  tile, c_marker, c_value = tile.partition("c")
  kr = int(c_value) if c_marker else 1
  mr, nr = (int(dim) for dim in tile.split("x"))

  arch, isa = xnncommon.parse_target_name(target_name)

  # Only the known requantization schemes are reported; any other field in
  # that position is some other part of the name.
  scheme = fields[-3]
  requantization = scheme if scheme in ("fp32", "rndnu") else None

  return mr, nr, kr, sr, xw, requantization, arch, isa
XNNPACK Teamb455b122019-09-27 18:10:33 -070057
58
59GEMM_TEST_CODE = """\
60TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
61 $if ISA_CHECK:
62 ${ISA_CHECK};
63 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -070064 $if EXTENDED_WEIGHTS:
65 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -070066 .mr(${MR})
67 .nr(${NR})
68 .kr(${KR})
69 .sr(${SR})
70 .m(${MR})
71 .n(${NR})
72 .k(${KBLOCK})
73 .Test(${", ".join(TEST_ARGS)});
74}
75
76TEST(${TEST_NAME}, strided_cn) {
77 $if ISA_CHECK:
78 ${ISA_CHECK};
79 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -070080 $if EXTENDED_WEIGHTS:
81 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -070082 .mr(${MR})
83 .nr(${NR})
84 .kr(${KR})
85 .sr(${SR})
86 .m(${MR})
87 .n(${NR})
88 .k(${KBLOCK})
89 .cn_stride(${next_prime(NR + 1)})
90 .Test(${", ".join(TEST_ARGS)});
91}
92
93$if UKERNEL_TYPE != "IGEMM":
94 TEST(${TEST_NAME}, k_eq_${KBLOCK}_strided_a) {
95 $if ISA_CHECK:
96 ${ISA_CHECK};
97 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -070098 $if EXTENDED_WEIGHTS:
99 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700100 .mr(${MR})
101 .nr(${NR})
102 .kr(${KR})
103 .sr(${SR})
104 .m(${MR})
105 .n(${NR})
106 .k(${KBLOCK})
107 .a_stride(${next_prime(KBLOCK + 1)})
108 .Test(${", ".join(TEST_ARGS)});
109 }
110
111TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
112 $if ISA_CHECK:
113 ${ISA_CHECK};
Zhi An Ng83844ae2022-01-14 09:52:25 -0800114 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800115 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700116 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700117 $if EXTENDED_WEIGHTS:
118 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700119 .mr(${MR})
120 .nr(${NR})
121 .kr(${KR})
122 .sr(${SR})
123 .m(m)
124 .n(n)
125 .k(${KBLOCK})
126 .iterations(1)
127 .Test(${", ".join(TEST_ARGS)});
128 }
129 }
130}
131
132TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_m) {
133 $if ISA_CHECK:
134 ${ISA_CHECK};
135 for (uint32_t m = 1; m <= ${MR}; m++) {
136 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700137 $if EXTENDED_WEIGHTS:
138 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700139 .mr(${MR})
140 .nr(${NR})
141 .kr(${KR})
142 .sr(${SR})
143 .m(m)
144 .n(${NR})
145 .k(${KBLOCK})
146 .iterations(1)
147 .Test(${", ".join(TEST_ARGS)});
148 }
149}
150
151
152TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile_n) {
153 $if ISA_CHECK:
154 ${ISA_CHECK};
155 for (uint32_t n = 1; n <= ${NR}; n++) {
156 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700157 $if EXTENDED_WEIGHTS:
158 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700159 .mr(${MR})
160 .nr(${NR})
161 .kr(${KR})
162 .sr(${SR})
163 .m(${MR})
164 .n(n)
165 .k(${KBLOCK})
166 .iterations(1)
167 .Test(${", ".join(TEST_ARGS)});
168 }
169}
170
171$if IS_PIPELINED:
172 TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
173 $if ISA_CHECK:
174 ${ISA_CHECK};
175 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700176 $if EXTENDED_WEIGHTS:
177 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700178 .mr(${MR})
179 .nr(${NR})
180 .kr(${KR})
181 .sr(${SR})
182 .m(${MR})
183 .n(${NR})
184 .k(${KBLOCK * 2})
185 .Test(${", ".join(TEST_ARGS)});
186 }
187
188 $if UKERNEL_TYPE != "IGEMM":
189 TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_strided_a) {
190 $if ISA_CHECK:
191 ${ISA_CHECK};
192 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700193 $if EXTENDED_WEIGHTS:
194 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700195 .mr(${MR})
196 .nr(${NR})
197 .kr(${KR})
198 .sr(${SR})
199 .m(${MR})
200 .n(${NR})
201 .k(${KBLOCK * 2})
202 .a_stride(${next_prime(KBLOCK * 2 + 1)})
203 .Test(${", ".join(TEST_ARGS)});
204 }
205
206 TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
207 $if ISA_CHECK:
208 ${ISA_CHECK};
Zhi An Ng83844ae2022-01-14 09:52:25 -0800209 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800210 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700211 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700212 $if EXTENDED_WEIGHTS:
213 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700214 .mr(${MR})
215 .nr(${NR})
216 .kr(${KR})
217 .sr(${SR})
218 .m(m)
219 .n(n)
220 .k(${KBLOCK * 2})
221 .iterations(1)
222 .Test(${", ".join(TEST_ARGS)});
223 }
224 }
225 }
226
227$if KBLOCK > 1:
228 TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
229 $if ISA_CHECK:
230 ${ISA_CHECK};
231 for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
232 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700233 $if EXTENDED_WEIGHTS:
234 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700235 .mr(${MR})
236 .nr(${NR})
237 .kr(${KR})
238 .sr(${SR})
239 .m(${MR})
240 .n(${NR})
241 .k(k)
242 .Test(${", ".join(TEST_ARGS)});
243 }
244 }
245
246 $if UKERNEL_TYPE != "IGEMM":
247 TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_strided_a) {
248 $if ISA_CHECK:
249 ${ISA_CHECK};
250 for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
251 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700252 $if EXTENDED_WEIGHTS:
253 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700254 .mr(${MR})
255 .nr(${NR})
256 .kr(${KR})
257 .sr(${SR})
258 .m(${MR})
259 .n(${NR})
260 .k(k)
261 .a_stride(${next_prime(ADJKBLOCK + 1)})
262 .Test(${", ".join(TEST_ARGS)});
263 }
264 }
265
266 TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
267 $if ISA_CHECK:
268 ${ISA_CHECK};
269 for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800270 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800271 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700272 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700273 $if EXTENDED_WEIGHTS:
274 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700275 .mr(${MR})
276 .nr(${NR})
277 .kr(${KR})
278 .sr(${SR})
279 .m(m)
280 .n(n)
281 .k(k)
282 .iterations(1)
283 .Test(${", ".join(TEST_ARGS)});
284 }
285 }
286 }
287 }
288
289TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
290 $if ISA_CHECK:
291 ${ISA_CHECK};
Zhi An Ngc80ffb02021-12-22 13:06:25 -0800292 for (size_t k = ${ADJKBLOCK + 1}; k < ${ADJKBLOCK * 10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700293 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700294 $if EXTENDED_WEIGHTS:
295 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700296 .mr(${MR})
297 .nr(${NR})
298 .kr(${KR})
299 .sr(${SR})
300 .m(${MR})
301 .n(${NR})
302 .k(k)
303 .Test(${", ".join(TEST_ARGS)});
304 }
305}
306
307$if UKERNEL_TYPE.startswith("GEMM"):
Zhi An Ngc80ffb02021-12-22 13:06:25 -0800308 TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_strided_a) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700309 $if ISA_CHECK:
310 ${ISA_CHECK};
Zhi An Ngc80ffb02021-12-22 13:06:25 -0800311 for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700312 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700313 $if EXTENDED_WEIGHTS:
314 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700315 .mr(${MR})
316 .nr(${NR})
317 .kr(${KR})
318 .sr(${SR})
319 .m(${MR})
320 .n(${NR})
321 .k(k)
Zhi An Ngc80ffb02021-12-22 13:06:25 -0800322 .a_stride(${next_prime(10 if ADJKBLOCK == 1 else ADJKBLOCK * 2 + 1)})
XNNPACK Teamb455b122019-09-27 18:10:33 -0700323 .Test(${", ".join(TEST_ARGS)});
324 }
325 }
326
Zhi An Ngc80ffb02021-12-22 13:06:25 -0800327TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700328 $if ISA_CHECK:
329 ${ISA_CHECK};
Zhi An Ngc80ffb02021-12-22 13:06:25 -0800330 for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if ADJKBLOCK == 1 else ADJKBLOCK * 2}; k++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800331 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800332 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700333 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700334 $if EXTENDED_WEIGHTS:
335 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700336 .mr(${MR})
337 .nr(${NR})
338 .kr(${KR})
339 .sr(${SR})
340 .m(m)
341 .n(n)
342 .k(k)
343 .iterations(1)
344 .Test(${", ".join(TEST_ARGS)});
345 }
346 }
347 }
348}
349
350$if KBLOCK > 1:
351 TEST(${TEST_NAME}, k_div_${KBLOCK}) {
352 $if ISA_CHECK:
353 ${ISA_CHECK};
354 for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
355 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700356 $if EXTENDED_WEIGHTS:
357 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700358 .mr(${MR})
359 .nr(${NR})
360 .kr(${KR})
361 .sr(${SR})
362 .m(${MR})
363 .n(${NR})
364 .k(k)
365 .Test(${", ".join(TEST_ARGS)});
366 }
367 }
368
369 $if UKERNEL_TYPE.startswith("GEMM"):
370 TEST(${TEST_NAME}, k_div_${KBLOCK}_strided_a) {
371 $if ISA_CHECK:
372 ${ISA_CHECK};
373 for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
374 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700375 $if EXTENDED_WEIGHTS:
376 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700377 .mr(${MR})
378 .nr(${NR})
379 .kr(${KR})
380 .sr(${SR})
381 .m(${MR})
382 .n(${NR})
383 .k(k)
384 .a_stride(${next_prime(KBLOCK * 10 + 1)})
385 .Test(${", ".join(TEST_ARGS)});
386 }
387 }
388
389 TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
390 $if ISA_CHECK:
391 ${ISA_CHECK};
392 for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800393 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800394 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700395 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700396 $if EXTENDED_WEIGHTS:
397 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700398 .mr(${MR})
399 .nr(${NR})
400 .kr(${KR})
401 .sr(${SR})
402 .m(m)
403 .n(n)
404 .k(k)
405 .iterations(1)
406 .Test(${", ".join(TEST_ARGS)});
407 }
408 }
409 }
410 }
411
412TEST(${TEST_NAME}, n_gt_${NR}) {
413 $if ISA_CHECK:
414 ${ISA_CHECK};
415 for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
416 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
417 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700418 $if EXTENDED_WEIGHTS:
419 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700420 .mr(${MR})
421 .nr(${NR})
422 .kr(${KR})
423 .sr(${SR})
424 .m(${MR})
Zhi An Ngaf9ff852022-01-13 10:48:37 -0800425 .n(n)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700426 .k(k)
427 .Test(${", ".join(TEST_ARGS)});
428 }
429 }
430}
431
432TEST(${TEST_NAME}, n_gt_${NR}_strided_cn) {
433 $if ISA_CHECK:
434 ${ISA_CHECK};
435 for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
436 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
437 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700438 $if EXTENDED_WEIGHTS:
439 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700440 .mr(${MR})
441 .nr(${NR})
442 .kr(${KR})
443 .sr(${SR})
444 .m(${MR})
Zhi An Ngaf9ff852022-01-13 10:48:37 -0800445 .n(n)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700446 .k(k)
447 .cn_stride(${next_prime(NR + 1)})
448 .Test(${", ".join(TEST_ARGS)});
449 }
450 }
451}
452
453$if UKERNEL_TYPE != "IGEMM":
454 TEST(${TEST_NAME}, n_gt_${NR}_strided_a) {
455 $if ISA_CHECK:
456 ${ISA_CHECK};
457 for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
458 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
459 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700460 $if EXTENDED_WEIGHTS:
461 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700462 .mr(${MR})
463 .nr(${NR})
464 .kr(${KR})
465 .sr(${SR})
466 .m(${MR})
467 .n(n)
468 .k(k)
469 .a_stride(${next_prime(KBLOCK * 5 + 1)})
470 .Test(${", ".join(TEST_ARGS)});
471 }
472 }
473 }
474
475TEST(${TEST_NAME}, n_gt_${NR}_subtile) {
476 $if ISA_CHECK:
477 ${ISA_CHECK};
478 for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
479 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
480 for (uint32_t m = 1; m <= ${MR}; m++) {
481 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700482 $if EXTENDED_WEIGHTS:
483 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700484 .mr(${MR})
485 .nr(${NR})
486 .kr(${KR})
487 .sr(${SR})
488 .m(m)
489 .n(n)
490 .k(k)
491 .iterations(1)
492 .Test(${", ".join(TEST_ARGS)});
493 }
494 }
495 }
496}
497
498TEST(${TEST_NAME}, n_div_${NR}) {
499 $if ISA_CHECK:
500 ${ISA_CHECK};
501 for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
502 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
503 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700504 $if EXTENDED_WEIGHTS:
505 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700506 .mr(${MR})
507 .nr(${NR})
508 .kr(${KR})
509 .sr(${SR})
510 .m(${MR})
Zhi An Ngaf9ff852022-01-13 10:48:37 -0800511 .n(n)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700512 .k(k)
513 .Test(${", ".join(TEST_ARGS)});
514 }
515 }
516}
517
518TEST(${TEST_NAME}, n_div_${NR}_strided_cn) {
519 $if ISA_CHECK:
520 ${ISA_CHECK};
521 for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
522 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
523 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700524 $if EXTENDED_WEIGHTS:
525 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700526 .mr(${MR})
527 .nr(${NR})
528 .kr(${KR})
529 .sr(${SR})
530 .m(${MR})
531 .n(n)
532 .k(k)
533 .cn_stride(${next_prime(NR + 1)})
534 .Test(${", ".join(TEST_ARGS)});
535 }
536 }
537}
538
539$if UKERNEL_TYPE != "IGEMM":
540 TEST(${TEST_NAME}, n_div_${NR}_strided_a) {
541 $if ISA_CHECK:
542 ${ISA_CHECK};
543 for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
544 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
545 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700546 $if EXTENDED_WEIGHTS:
547 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700548 .mr(${MR})
549 .nr(${NR})
550 .kr(${KR})
551 .sr(${SR})
552 .m(${MR})
553 .n(n)
554 .k(k)
555 .a_stride(${next_prime(KBLOCK * 5 + 1)})
556 .Test(${", ".join(TEST_ARGS)});
557 }
558 }
559 }
560
561TEST(${TEST_NAME}, n_div_${NR}_subtile) {
562 $if ISA_CHECK:
563 ${ISA_CHECK};
564 for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
565 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
566 for (uint32_t m = 1; m <= ${MR}; m++) {
567 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700568 $if EXTENDED_WEIGHTS:
569 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700570 .mr(${MR})
571 .nr(${NR})
572 .kr(${KR})
573 .sr(${SR})
574 .m(m)
575 .n(n)
576 .k(k)
577 .iterations(1)
578 .Test(${", ".join(TEST_ARGS)});
579 }
580 }
581 }
582}
583
584$if UKERNEL_TYPE.startswith("IGEMM"):
585 TEST(${TEST_NAME}, small_kernel) {
586 $if ISA_CHECK:
587 ${ISA_CHECK};
588 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
589 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700590 $if EXTENDED_WEIGHTS:
591 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700592 .mr(${MR})
593 .nr(${NR})
594 .kr(${KR})
595 .sr(${SR})
596 .m(${MR})
597 .n(${NR})
598 .k(k)
599 .ks(3)
600 .Test(${", ".join(TEST_ARGS)});
601 }
602 }
603
604 TEST(${TEST_NAME}, small_kernel_subtile) {
605 $if ISA_CHECK:
606 ${ISA_CHECK};
607 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800608 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800609 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700610 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700611 $if EXTENDED_WEIGHTS:
612 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700613 .mr(${MR})
614 .nr(${NR})
615 .kr(${KR})
616 .sr(${SR})
617 .m(m)
618 .n(n)
619 .k(k)
620 .ks(3)
621 .iterations(1)
622 .Test(${", ".join(TEST_ARGS)});
623 }
624 }
625 }
626 }
627
628 TEST(${TEST_NAME}, n_gt_${NR}_small_kernel) {
629 $if ISA_CHECK:
630 ${ISA_CHECK};
631 for (uint32_t n = ${NR + 1}; n < ${NR * 2}; n++) {
632 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
633 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700634 $if EXTENDED_WEIGHTS:
635 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700636 .mr(${MR})
637 .nr(${NR})
638 .kr(${KR})
639 .sr(${SR})
640 .m(${MR})
Zhi An Ngaf9ff852022-01-13 10:48:37 -0800641 .n(n)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700642 .k(k)
643 .ks(3)
644 .Test(${", ".join(TEST_ARGS)});
645 }
646 }
647 }
648
649 TEST(${TEST_NAME}, n_div_${NR}_small_kernel) {
650 $if ISA_CHECK:
651 ${ISA_CHECK};
652 for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
653 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
654 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700655 $if EXTENDED_WEIGHTS:
656 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700657 .mr(${MR})
658 .nr(${NR})
659 .kr(${KR})
660 .sr(${SR})
661 .m(${MR})
Zhi An Ngaf9ff852022-01-13 10:48:37 -0800662 .n(n)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700663 .k(k)
664 .ks(3)
665 .Test(${", ".join(TEST_ARGS)});
666 }
667 }
668 }
669
670TEST(${TEST_NAME}, strided_cm_subtile) {
671 $if ISA_CHECK:
672 ${ISA_CHECK};
673 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800674 for (uint32_t n = 1; n <= ${NR}; n++) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800675 for (uint32_t m = 1; m <= ${MR}; m++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700676 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700677 $if EXTENDED_WEIGHTS:
678 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700679 .mr(${MR})
680 .nr(${NR})
681 .kr(${KR})
682 .sr(${SR})
683 .m(m)
684 .n(n)
685 .k(k)
686 .cm_stride(${next_prime(NR + 1)})
687 .iterations(1)
688 .Test(${", ".join(TEST_ARGS)});
689 }
690 }
691 }
692}
693
694$if UKERNEL_TYPE.startswith("IGEMM"):
695 TEST(${TEST_NAME}, a_offset) {
696 $if ISA_CHECK:
697 ${ISA_CHECK};
698 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
699 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700700 $if EXTENDED_WEIGHTS:
701 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700702 .mr(${MR})
703 .nr(${NR})
704 .kr(${KR})
705 .sr(${SR})
706 .m(${MR})
707 .n(${NR})
708 .k(k)
709 .ks(3)
710 .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
711 .Test(${", ".join(TEST_ARGS)});
712 }
713 }
714
715 TEST(${TEST_NAME}, zero) {
716 $if ISA_CHECK:
717 ${ISA_CHECK};
Zhi An Ng83844ae2022-01-14 09:52:25 -0800718 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
Zhi An Ng83844ae2022-01-14 09:52:25 -0800719 for (uint32_t mz = 0; mz < ${MR}; mz++) {
XNNPACK Teamb455b122019-09-27 18:10:33 -0700720 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700721 $if EXTENDED_WEIGHTS:
722 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700723 .mr(${MR})
724 .nr(${NR})
725 .kr(${KR})
726 .sr(${SR})
727 .m(${MR})
728 .n(${NR})
729 .k(k)
730 .ks(3)
731 .a_offset(${next_prime(MR * KBLOCK * 5 + 1)})
732 .zero_index(mz)
733 .Test(${", ".join(TEST_ARGS)});
734 }
735 }
736 }
737
Marat Dukhan163a7e62020-04-09 04:19:26 -0700738$if ACTIVATION == "MINMAX":
739 TEST(${TEST_NAME}, qmin) {
740 $if ISA_CHECK:
741 ${ISA_CHECK};
742 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700743 $if EXTENDED_WEIGHTS:
744 .extended_weights(true)
Marat Dukhan163a7e62020-04-09 04:19:26 -0700745 .mr(${MR})
746 .nr(${NR})
747 .kr(${KR})
748 .sr(${SR})
749 .m(${MR})
750 .n(${NR})
751 .k(${KBLOCK})
752 .qmin(128)
753 .Test(${", ".join(TEST_ARGS)});
754 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700755
Marat Dukhan163a7e62020-04-09 04:19:26 -0700756 TEST(${TEST_NAME}, qmax) {
757 $if ISA_CHECK:
758 ${ISA_CHECK};
759 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700760 $if EXTENDED_WEIGHTS:
761 .extended_weights(true)
Marat Dukhan163a7e62020-04-09 04:19:26 -0700762 .mr(${MR})
763 .nr(${NR})
764 .kr(${KR})
765 .sr(${SR})
766 .m(${MR})
767 .n(${NR})
768 .k(${KBLOCK})
769 .qmax(128)
770 .Test(${", ".join(TEST_ARGS)});
771 }
XNNPACK Teamb455b122019-09-27 18:10:33 -0700772
773TEST(${TEST_NAME}, strided_cm) {
774 $if ISA_CHECK:
775 ${ISA_CHECK};
776 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700777 $if EXTENDED_WEIGHTS:
778 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700779 .mr(${MR})
780 .nr(${NR})
781 .kr(${KR})
782 .sr(${SR})
783 .m(${MR})
784 .n(${NR})
785 .k(${KBLOCK})
786 .cm_stride(${next_prime(NR + 1)})
787 .Test(${", ".join(TEST_ARGS)});
788}
789
Marat Dukhan08b7a972020-07-14 18:17:29 -0700790$if DATATYPE == "qu8":
XNNPACK Teamb455b122019-09-27 18:10:33 -0700791 TEST(${TEST_NAME}, no_a_zero_point) {
792 $if ISA_CHECK:
793 ${ISA_CHECK};
794 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
795 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700796 $if EXTENDED_WEIGHTS:
797 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700798 .mr(${MR})
799 .nr(${NR})
800 .kr(${KR})
801 .sr(${SR})
802 .m(${MR})
803 .n(${NR})
804 .k(k)
805 .a_zero_point(0)
806 .Test(${", ".join(TEST_ARGS)});
807 }
808 }
809
810 TEST(${TEST_NAME}, no_b_zero_point) {
811 $if ISA_CHECK:
812 ${ISA_CHECK};
813 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
814 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700815 $if EXTENDED_WEIGHTS:
816 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700817 .mr(${MR})
818 .nr(${NR})
819 .kr(${KR})
820 .sr(${SR})
821 .m(${MR})
822 .n(${NR})
823 .k(k)
824 .b_zero_point(0)
825 .Test(${", ".join(TEST_ARGS)});
826 }
827 }
828
829 TEST(${TEST_NAME}, no_zero_point) {
830 $if ISA_CHECK:
831 ${ISA_CHECK};
832 for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
833 GemmMicrokernelTester()
Marat Dukhand5694df2021-05-20 17:10:40 -0700834 $if EXTENDED_WEIGHTS:
835 .extended_weights(true)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700836 .mr(${MR})
837 .nr(${NR})
838 .kr(${KR})
839 .sr(${SR})
840 .m(${MR})
841 .n(${NR})
842 .k(k)
843 .a_zero_point(0)
844 .b_zero_point(0)
845 .Test(${", ".join(TEST_ARGS)});
846 }
847 }
848"""
849
850
def generate_test_cases(ukernel, mr, nr, kr, sr, xw, k_block, init_fn,
                        requantization, is_pipelined, isa, jit):
  """Generates all test cases for a GEMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the GEMM micro-kernel.
    nr: NR parameter of the GEMM micro-kernel.
    kr: KR parameter of the GEMM micro-kernel.
    sr: SR parameter of the GEMM micro-kernel.
    xw: boolean indicator for microkernel with extended weights.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    init_fn: C name of the function to initialize microkernel parameters, or
      None if the spec does not provide one.
    requantization: name of the requantization scheme used by the microkernel,
      or None.
    is_pipelined: Indicates if the micro-kernel is implemented with software
      pipelining. Additional test cases are generated for software pipelined
      micro-kernels to separately test prologue + epilogue of the pipelined
      loop and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
      will skip execution if the host processor doesn't support this ISA.
    jit: if we are generating test code for JIT codegen.

  Returns:
    Code for the test case.
  """
  _, ukernel_name = ukernel.split("_", 1)

  if jit:
    # For JIT names the datatype is the third "_"-separated field and the
    # micro-kernel type the fourth; no activation field is read from the name.
    _, _, datatype, ukernel_type, _ = ukernel.split("_", 4)
    # The activation is inferred from the init function name instead. A
    # missing init function, or one without "minmax" in its name, falls back
    # to "linear" (the same default the non-JIT path uses) rather than
    # crashing on None below.
    activation = "minmax" if init_fn and "minmax" in init_fn else "linear"
  else:
    _, datatype, ukernel_type, activation, _ = ukernel.split("_", 4)
    # When the fourth field is already the literal "ukernel", the name carries
    # no activation component.
    if activation == "ukernel":
      activation = "linear"

  test_args = [ukernel]
  if init_fn:
    test_args.append(init_fn)
    if requantization:
      # qc8 kernels are tested with the qs8 requantization function.
      requantization_datatype = {"qc8": "qs8"}.get(datatype, datatype)
      test_args.append("xnn_%s_requantize_%s" %
                       (requantization_datatype, requantization))

  return xngen.preprocess(
      GEMM_TEST_CODE, {
          "TEST_NAME": ukernel_name.upper().replace("UKERNEL_", ""),
          "TEST_ARGS": test_args,
          "UKERNEL_TYPE": ukernel_type.upper(),
          "DATATYPE": datatype,
          "ACTIVATION": activation.upper(),
          "MR": mr,
          "NR": nr,
          "KR": kr,
          "SR": sr,
          "EXTENDED_WEIGHTS": xw,
          "KBLOCK": k_block,
          # Pipelined kernels need two K blocks before the steady-state loop
          # is exercised, so the k_lt/k_gt thresholds are doubled.
          "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
          "IS_PIPELINED": is_pipelined,
          "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
          "next_prime": next_prime,
      })
XNNPACK Teamb455b122019-09-27 18:10:33 -0700917
918
def main(args):
  """Generates the GEMM test sources described by a YAML spec.

  Parses the command line, loads the list of micro-kernel specs, renders the
  test cases for each micro-kernel and distributes them over the requested
  output files. An output file is only rewritten when its content changes.

  Args:
    args: command-line arguments, without the program name.

  Raises:
    ValueError: if the spec file does not contain a top-level list.
  """
  options = parser.parse_args(args)
  output_paths = options.output
  num_output_files = len(output_paths)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
    if not isinstance(spec_yaml, list):
      raise ValueError("expected a list of micro-kernels in the spec")

    header = """\
// Copyright (c) Facebook, Inc. and its affiliates.
// All rights reserved.
//
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
// Specification: {specification}
// Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/allocator.h>
#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/gemm.h>
#include <xnnpack/igemm.h>
#include <xnnpack/ppmm.h>
#include "gemm-microkernel-tester.h"
""".format(
    specification=options.spec, generator=sys.argv[0])

    # Every output file starts out as just the common header, so a file that
    # ends up with no micro-kernels is still written with the header alone.
    contents = {path: header for path in output_paths}

    for ukernel_spec in spec_yaml:
      name = ukernel_spec["name"]
      k_block = int(ukernel_spec["k-block"])
      init_fn = ukernel_spec.get("init")
      pipelined = bool(ukernel_spec.get("pipelined", False))
      assembly = bool(ukernel_spec.get("assembly", False))
      jit = name.startswith("xnn_generate")
      mr, nr, kr, sr, xw, requantization, arch, isa = split_ukernel_name(name)

      # specification can override architecture
      arch = ukernel_spec.get("arch", arch)

      test_case = generate_test_cases(name, mr, nr, kr, sr, xw, k_block,
                                      init_fn, requantization, pipelined, isa,
                                      jit)

      # Pick the output file from a CRC32 of the micro-kernel name, so the
      # assignment depends only on the name and the number of outputs.
      bucket = output_paths[
          zlib.crc32(name.encode("utf-8")) % num_output_files]
      contents[bucket] += "\n\n" + xnncommon.postprocess_test_case(
          test_case, arch, isa, assembly, jit)

    for path in output_paths:
      text = contents[path]
      if os.path.exists(path):
        with codecs.open(path, "r", encoding="utf-8") as existing:
          if existing.read() == text:
            # Content is unchanged: leave the file untouched.
            continue

      with codecs.open(path, "w", encoding="utf-8") as output_file:
        output_file.write(text)
XNNPACK Teamb455b122019-09-27 18:10:33 -0700988
Zhi An Ng74ddd272022-01-04 09:16:56 -0800989
# Script entry point: forward the command-line arguments (without the program
# name) to main() when this file is executed directly.
if __name__ == "__main__":
  main(sys.argv[1:])