blob: 5a15af1b667f324f20592bc497a6ae263ee99dc7 [file] [log] [blame]
#!/usr/bin/env python
# Copyright 2019 Google LLC
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
6
7import argparse
8import bisect
9import codecs
10import os
11import sys
12import yaml
13
14sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))
15from primes import next_prime
16import xngen
Marat Dukhan918a4a62019-10-27 19:49:49 -070017import xnncommon
XNNPACK Teamb455b122019-09-27 18:10:33 -070018
19
# Command-line interface of the generator: both the input YAML spec and the
# output C++ test file must be named explicitly.
parser = argparse.ArgumentParser(description='XNNPACK generator')
parser.add_argument(
    "-s", "--spec",
    metavar="FILE",
    required=True,
    help="Spec (YAML) file",
)
parser.add_argument(
    "-o", "--output",
    metavar="FILE",
    required=True,
    help='Output (C++ source) file',
)
# No option above sets `defines`; it always defaults to an empty list.
parser.set_defaults(defines=[])
26
27
def split_ukernel_name(name):
  """Splits a micro-kernel name into its MR parameter, architecture and ISA.

  The name must contain a "__" separator. Before it, the last "_"-separated
  token starts with the integer MR followed by an optional "x..." suffix.
  After it is the target name, which xnncommon.parse_target_name decodes.

  Args:
    name: C name of the micro-kernel function.

  Returns:
    A (mr, arch, isa) tuple.

  Raises:
    ValueError: if `name` has no "__" separator or MR is not an integer.
  """
  kernel_part, target_part = name.split("__", 1)
  tile_token = kernel_part.rsplit("_", 1)[-1]
  mr = int(tile_token.partition("x")[0])
  arch, isa = xnncommon.parse_target_name(target_part)
  return mr, arch, isa
35
36
# xngen template for the PACK micro-kernel unit tests. generate_test_cases()
# expands it once per micro-kernel via xngen.preprocess, which handles the
# `$if` directives and the `${...}` placeholders. Placeholders include
# expressions such as `${KBLOCK * 2}` and `${next_prime(...)}`.
#
# Generated tests (K = KBLOCK, MR = MR):
#   k_eq_K / k_eq_K_subtile     k == K, with m == MR / every m in [1, MR].
#   k_lt_K / k_lt_K_subtile     every k in [1, K)        (only if KBLOCK != 1).
#   k_gt_K / k_gt_K_subtile     every k in (K, 2K), or (1, 10) when K == 1.
#   k_div_K / k_div_K_subtile   k = 2K, 3K, ..., 9K      (only if KBLOCK > 1).
#   strided_x                   k = 1, K+2, 2K+3, ... while k <= 5K, with
#                               x_stride set to next_prime(5K + 1). This is
#                               presumably the input row stride; confirm in
#                               PackMicrokernelTester.
PACK_TEST_CODE = """\
TEST(${TEST_NAME}, k_eq_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  PackMicrokernelTester()
    .mr(${MR})
    .m(${MR})
    .k(${KBLOCK})
    .Test(${UKERNEL_NAME});
}

TEST(${TEST_NAME}, k_eq_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t m = 1; m <= ${MR}; m++) {
    PackMicrokernelTester()
      .mr(${MR})
      .m(m)
      .k(${KBLOCK})
      .Test(${UKERNEL_NAME});
  }
}

$if KBLOCK != 1:
  TEST(${TEST_NAME}, k_lt_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${KBLOCK}; k++) {
      PackMicrokernelTester()
        .mr(${MR})
        .m(${MR})
        .k(k)
        .Test(${UKERNEL_NAME});
    }
  }

  TEST(${TEST_NAME}, k_lt_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = 1; k < ${KBLOCK}; k++) {
      for (size_t m = 1; m <= ${MR}; m++) {
        PackMicrokernelTester()
          .mr(${MR})
          .m(m)
          .k(k)
          .Test(${UKERNEL_NAME});
      }
    }
  }

TEST(${TEST_NAME}, k_gt_${KBLOCK}) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${KBLOCK + 1}; k < ${10 if KBLOCK == 1 else KBLOCK * 2}; k++) {
    PackMicrokernelTester()
      .mr(${MR})
      .m(${MR})
      .k(k)
      .Test(${UKERNEL_NAME});
  }
}

TEST(${TEST_NAME}, k_gt_${KBLOCK}_subtile) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = ${KBLOCK + 1}; k < ${10 if KBLOCK == 1 else KBLOCK * 2}; k++) {
    for (size_t m = 1; m <= ${MR}; m++) {
      PackMicrokernelTester()
        .mr(${MR})
        .m(m)
        .k(k)
        .Test(${UKERNEL_NAME});
    }
  }
}

$if KBLOCK > 1:
  TEST(${TEST_NAME}, k_div_${KBLOCK}) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${KBLOCK * 2}; k < ${KBLOCK * 10}; k += ${KBLOCK}) {
      PackMicrokernelTester()
        .mr(${MR})
        .m(${MR})
        .k(k)
        .Test(${UKERNEL_NAME});
    }
  }

  TEST(${TEST_NAME}, k_div_${KBLOCK}_subtile) {
    $if ISA_CHECK:
      ${ISA_CHECK};
    for (size_t k = ${KBLOCK * 2}; k < ${KBLOCK * 10}; k += ${KBLOCK}) {
      for (size_t m = 1; m <= ${MR}; m++) {
        PackMicrokernelTester()
          .mr(${MR})
          .m(m)
          .k(k)
          .Test(${UKERNEL_NAME});
      }
    }
  }

TEST(${TEST_NAME}, strided_x) {
  $if ISA_CHECK:
    ${ISA_CHECK};
  for (size_t k = 1; k <= ${KBLOCK * 5}; k += ${KBLOCK + 1}) {
    PackMicrokernelTester()
      .mr(${MR})
      .m(${MR})
      .k(k)
      .x_stride(${next_prime(KBLOCK * 5 + 1)})
      .Test(${UKERNEL_NAME});
  }
}
"""
153
154
def generate_test_cases(ukernel, mr, k_block, isa):
  """Generates all test cases for a PACK micro-kernel.

  Args:
    ukernel: C name of the micro-kernel function. It must contain at least
      three "_" separators: the second and third "_"-separated tokens are used
      as the datatype and micro-kernel type.
    mr: MR parameter of the PACK micro-kernel.
    k_block: Number of K values processed per one iteration of the main loop of
      the micro-kernel.
    isa: instruction set required to run the micro-kernel. It is turned into an
      ISA-check statement by xnncommon.generate_isa_check_macro and emitted at
      the top of every generated test.

  Returns:
    Code for the test cases: PACK_TEST_CODE expanded by xngen.preprocess.

  Raises:
    ValueError: if `ukernel` has fewer than three "_" separators.
  """
  # Test suite name: everything after the first "_", upper-cased, with the
  # "UKERNEL_" token removed.
  _, test_name = ukernel.split("_", 1)
  _, datatype, ukernel_type, _ = ukernel.split("_", 3)
  return xngen.preprocess(PACK_TEST_CODE, {
      "TEST_NAME": test_name.upper().replace("UKERNEL_", ""),
      # UKERNEL_TYPE and DATATYPE are not referenced by PACK_TEST_CODE as
      # written; NOTE(review): presumably kept for parity with other
      # generators' templates.
      "UKERNEL_TYPE": ukernel_type.upper(),
      "UKERNEL_NAME": ukernel,
      "DATATYPE": datatype,
      "MR": mr,
      "KBLOCK": k_block,
      "ISA_CHECK": xnncommon.generate_isa_check_macro(isa),
      # Exposed so the template can compute the strided_x input stride.
      "next_prime": next_prime,
  })
182
183
def main(args):
  """Generates a C++ unit-test source file for the micro-kernels in a spec.

  Reads the YAML spec named by --spec, which must be a list of entries with
  "name" and "k-block" keys and an optional "arch" override. It emits one group
  of test cases per entry and writes the result to the file named by --output.

  Args:
    args: command-line arguments, without the program name.

  Raises:
    ValueError: if the spec does not contain a list of micro-kernels.
  """
  options = parser.parse_args(args)

  # yaml.safe_load consumes the whole stream, so the spec file can be closed
  # before code generation starts.
  with codecs.open(options.spec, "r", encoding="utf-8") as spec_file:
    spec_yaml = yaml.safe_load(spec_file)
  if not isinstance(spec_yaml, list):
    raise ValueError("expected a list of micro-kernels in the spec")

  header = """\
// Copyright 2019 Google LLC
//
// This source code is licensed under the BSD-style license found in the
// LICENSE file in the root directory of this source tree.
//
// Auto-generated file. Do not edit!
//   Specification: {specification}
//   Generator: {generator}


#include <gtest/gtest.h>

#include <xnnpack/common.h>
#include <xnnpack/isa-checks.h>

#include <xnnpack/packx.h>
#include "pack-microkernel-tester.h"
""".format(specification=options.spec, generator=sys.argv[0])

  # Collect the pieces and join them once instead of growing a single string
  # with `+=`, which is quadratic in the number of micro-kernels.
  sections = [header]
  for ukernel_spec in spec_yaml:
    name = ukernel_spec["name"]
    k_block = int(ukernel_spec["k-block"])
    mr, arch, isa = split_ukernel_name(name)

    # specification can override architecture
    arch = ukernel_spec.get("arch", arch)

    test_case = generate_test_cases(name, mr, k_block, isa)
    sections.append(
        "\n\n" + xnncommon.postprocess_test_case(test_case, arch, isa))

  with codecs.open(options.output, "w", encoding="utf-8") as output_file:
    output_file.write("".join(sections))


if __name__ == "__main__":
  main(sys.argv[1:])