blob: ea537ef8e82c4df8b13c481fe618b014e28b387e [file] [log] [blame]
XNNPACK Teamb455b122019-09-27 18:10:33 -07001#!/usr/bin/env python
2# Copyright 2019 Google LLC
3#
4# This source code is licensed under the BSD-style license found in the
5# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import os
11import sys
12import yaml
13
14sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15from primes import next_prime
16import xngen
Marat Dukhan918a4a62019-10-27 19:49:49 -070017import xnncommon
XNNPACK Teamb455b122019-09-27 18:10:33 -070018
19
# Command-line interface: both the YAML spec and the C++ output path are
# mandatory.
parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument(
    "-s", "--spec",
    metavar="FILE",
    required=True,
    help="Spec (YAML) file")
parser.add_argument(
    "-o", "--output",
    metavar="FILE",
    required=True,
    help='Output (C++ source) file')
# No option in this script fills `defines`; it only defaults to an empty list.
parser.set_defaults(defines=[])
26
27
def split_ukernel_name(name):
  """Splits a micro-kernel name into its tile sizes and its target.

  The part of `name` before the first "__" must end in an "<MR>x<NR>" tile
  spec (the last "_"-separated token); the part after it is the target name,
  which is decoded by xnncommon.parse_target_name().

  Args:
    name: micro-kernel name, e.g. "xnn_f32_spmm_ukernel_4x1__scalar".

  Returns:
    A (mr, nr, arch, isa) tuple: mr and nr as ints, arch and isa as returned
    by xnncommon.parse_target_name().

  Raises:
    ValueError: if `name` has no "__" separator, or the tile spec is not
      exactly two "x"-separated integers.
  """
  kernel_part, target_part = name.split("__", 1)
  tile_spec = kernel_part.rsplit("_", 1)[-1]
  # Convert the tile sizes before decoding the target, as the original did.
  mr, nr = (int(dim) for dim in tile_spec.split("x"))
  arch, isa = xnncommon.parse_target_name(target_part)
  return mr, nr, arch, isa
35
36
# Per-kernel GoogleTest template. `${expr}` placeholders and `$if` blocks are
# expanded by xngen.preprocess() using the variables built in
# generate_test_cases(); ADJKBLOCK equals KBLOCK, doubled for pipelined kernels.
#
# The k_gt_* loops start at ADJKBLOCK + 1, so their upper bound is derived from
# ADJKBLOCK as well: bounding by KBLOCK * 2 would produce an empty loop (no test
# coverage) for pipelined kernels with KBLOCK > 1. Every test sets .sparsity()
# explicitly so no case depends on the tester's default.
TEST_TEMPLATE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  SpMMMicrokernelTester()
    .mr(${MR})
    .nr(${NR})
    .m(${MR})
    .n(${NR})
    .k(${KBLOCK})
    .sparsity(0.0f)
    .Test(${", ".join(TEST_ARGS)});
}

$if NR > 1:
  TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = 1; n <= ${NR}; n++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(${KBLOCK})
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

$if IS_PIPELINED:
  TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(${KBLOCK * 2})
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_eq_${KBLOCK * 2}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(${KBLOCK * 2})
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_lt_${ADJKBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = 1; k < ${ADJKBLOCK}; k++) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
    SpMMMicrokernelTester()
      .mr(${MR})
      .nr(${NR})
      .m(${MR})
      .n(${NR})
      .k(k)
      .sparsity(0.0f)
      .Test(${", ".join(TEST_ARGS)});
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, k_gt_${ADJKBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + 1}; k < ${10 if KBLOCK == 1 else ADJKBLOCK * 2}; k++) {
      for (uint32_t n = 1; n <= ${NR}; n++) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(${NR})
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }

  $if NR > 1:
    TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
      $if ISA_CHECK:
        ${ISA_CHECK};
      for (size_t k = ${ADJKBLOCK + KBLOCK}; k <= ${KBLOCK * 10}; k += ${KBLOCK}) {
        for (uint32_t n = 1; n <= ${NR}; n++) {
          SpMMMicrokernelTester()
            .mr(${MR})
            .nr(${NR})
            .m(${MR})
            .n(n)
            .k(k)
            .sparsity(0.0f)
            .Test(${", ".join(TEST_ARGS)});
        }
      }
    }

TEST(${TEST_NAME}, n_gt_${NR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = ${NR + 1}; n < ${max(10, NR * 2)}; n++) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

$if NR > 1:
  TEST(${TEST_NAME}, n_div_${NR}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (uint32_t n = ${2 * NR}; n <= ${3 * NR}; n += ${NR}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(${MR})
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }

TEST(${TEST_NAME}, m_lt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = 1; m < ${MR}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_div_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR * 2}; m <= ${MR * 3}; m += ${MR}) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, m_gt_${MR}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t m = ${MR + 1}; m < ${MR * 2}; m++) {
    for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
      for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
        SpMMMicrokernelTester()
          .mr(${MR})
          .nr(${NR})
          .m(m)
          .n(n)
          .k(k)
          .sparsity(0.0f)
          .Test(${", ".join(TEST_ARGS)});
      }
    }
  }
}

TEST(${TEST_NAME}, qmin) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmin(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, qmax) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.0f)
        .qmax(128)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, half_sparse) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(0.5f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}

TEST(${TEST_NAME}, zero_weights) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (uint32_t n = 1; n < ${max(10, NR * 5)}; n += ${NR + 1}) {
    for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
      SpMMMicrokernelTester()
        .mr(${MR})
        .nr(${NR})
        .m(${MR * 2})
        .n(n)
        .k(k)
        .sparsity(1.0f)
        .Test(${", ".join(TEST_ARGS)});
    }
  }
}
"""
359
360
def generate_test_cases(ukernel, mr, nr, k_block, is_pipelined, isa):
  """Generates all test cases for an SpMM micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function.
    mr: MR parameter of the SpMM micro-kernel.
    nr: NR parameter of the SpMM micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop of
             the micro-kernel.
    is_pipelined: Indicates if the micro-kernel is implemented with software
                  pipelining. Additional test cases are generated for software
                  pipelined micro-kernels to separately test prologue + epilogue
                  of the pipelined loop and iteration of the pipelined loop.
    isa: instruction set required to run the micro-kernel. Generated unit test
         will skip execution if the host processor doesn't support this ISA.

  Returns:
    Code for the test case.
  """
  # Drop the leading prefix token (e.g. "xnn") and the "UKERNEL_" infix to form
  # the GoogleTest suite name, e.g. "F32_SPMM_4X1__SCALAR".
  _, test_name = ukernel.split("_", 1)
  _, datatype, ukernel_type, _ = ukernel.split("_", 3)
  test_args = [ukernel]
  # Kernels without an ISA requirement, or targeting psimd, are exercised with
  # the tester's Scalar variant.
  if not isa or isa == "psimd":
    test_args.append("SpMMMicrokernelTester::Variant::Scalar")
  return xngen.preprocess(TEST_TEMPLATE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      "TEST_ARGS": test_args,
      "UKERNEL_TYPE": ukernel_type.upper(),
      "DATATYPE": datatype,
      "MR": mr,
      "NR": nr,
      "KBLOCK": k_block,
      # Pipelined kernels consume two K blocks per prologue + epilogue.
      "ADJKBLOCK": 2 * k_block if is_pipelined else k_block,
      "IS_PIPELINED": is_pipelined,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      "next_prime": next_prime,
  })
398
399
def main(args):
  """Writes the SpMM micro-kernel unit-test source described by a YAML spec.

  Args:
    args: command-line arguments without the program name; parsed by the
      module-level `parser` (requires --spec and --output).

  Raises:
    ValueError: if the top-level YAML value of the spec is not a list.
  """
  options = parser.parse_args(args)

  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/spmm.h>
#include "spmm-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the header plus one chunk per micro-kernel, joined once at the end.
  chunks = [header]
  for entry in spec_yaml:
    kernel_name = entry["name"]
    block_size = int(entry["k-block"])
    is_pipelined = bool(entry.get("pipelined", False))
    mr, nr, default_arch, isa = split_ukernel_name(kernel_name)

    # An explicit "arch" key in the spec entry overrides the name-derived one.
    arch = entry.get("arch", default_arch)

    test_case = generate_test_cases(
        kernel_name, mr, nr, block_size, is_pipelined, isa)
    chunks.append("\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write("".join(chunks))
442
443
if __name__ == "__main__":
  # main() parses only the options, so skip the program name in argv[0].
  main(args=sys.argv[1:])