# This script generates all variants of wmma builtins, verifies that clang calls
# correct LLVM intrinsics, and checks that availability of specific builtins is
# constrained by the correct PTX version and the target GPU variant.

# Dummy test run to avoid lit warnings.
# RUN: echo "This is not a real test. It's a generator for builtins-nvptx-mma.cu" >/dev/null
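#
# Example invocation (a sketch of assumed typical usage, not a normative
# command; --ptx and --gpu-arch are the flags defined at the bottom of this
# script, and the generated test is written to stdout):
#   python builtins-nvptx-mma.py --ptx=63 --gpu-arch=75 > builtins-nvptx-mma.cu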

from __future__ import print_function

import argparse
from collections import defaultdict
from itertools import product
from string import Template

class MMAFrag:
  def __init__(self, geom, frag, ptx_elt_type):
    self.geom = geom
    self.frag = frag
    self.ptx_type = ptx_elt_type

  def __repr__(self):
    return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)

class MMAOp:
  def __init__(self, a, b, c, d):
    self.a = a
    self.b = b
    self.c = c
    self.d = d

  def __repr__(self):
    return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d))

def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
  ops = []
  for geom, type_a, type_c in product(geoms, types_a, types_c):
    for type_b, type_d in product(types_b if types_b else [type_a],
                                  types_d if types_d else [type_c]):
      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
                       MMAFrag(geom, "b", type_b),
                       MMAFrag(geom, "c", type_c),
                       MMAFrag(geom, "d", type_d)))
  return ops
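# For illustration (derived from the definitions above, not part of the
# generated output): make_mma_ops(["m16n16k16"], ["f16"], [], ["f16"], [])
# yields a single MMAOp whose repr() is
#   {A:m16n16k16:a:f16, B:m16n16k16:b:f16, C:m16n16k16:c:f16, D:m16n16k16:d:f16}
# An empty types_b/types_d list defaults to types_a/types_c respectively.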

def make_ldst_ops(geoms, frags, types):
  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
          in product(geoms, frags, types)]

def get_mma_ops():
  return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["s8", "u8"], [], ["s32"], []) +
          make_mma_ops(["m8n8k32"],
                       ["s4", "u4"], [], ["s32"], []) +
          make_mma_ops(["m8n8k128"],
                       ["b1"], [], ["s32"], []))

def get_ldst_ops():
  return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["a", "b"], ["f16", "u8", "s8"]) +
          make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["c", "d"], ["f16", "f32", "s32"]) +
          make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"]) +
          make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
          make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))

def is_geom_supported(geom):
  # geometries for FP and ints.
  if geom in ["m8n32k16", "m32n8k16"]:
    return ptx_version >= 61
  # geometries for sub-ints.
  if geom in ["m8n8k32", "m8n8k128"]:
    return ptx_version >= 63 and gpu_arch >= 75
  if geom == "m16n16k16":
    return ptx_version >= 60
  assert(False) # Unexpected geometry.

def is_type_supported(ptx_type):
  if ptx_type in ["s8", "u8", "s32"]:
    return ptx_version >= 63 and gpu_arch >= 72
  if ptx_type in ["s4", "u4", "b1"]:
    return ptx_version >= 63 and gpu_arch >= 75
  return ptx_version >= 60 and gpu_arch >= 70
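# Restating the checks above for readability (no new logic): m16n16k16 needs
# ptx60, m8n32k16/m32n8k16 need ptx61, and m8n8k32/m8n8k128 need ptx63 on
# sm_75; f16/f32 types need sm_70, s8/u8/s32 need ptx63 on sm_72, and
# s4/u4/b1 need ptx63 on sm_75.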

def is_mma_variant_supported(op, layout_a, layout_b, satf):
  if not (is_type_supported(op.a.ptx_type)
          and is_geom_supported(op.a.geom)):
    return False
  # sub-integer ops require row/col layout; b1 ops have no satf.
  if op.a.ptx_type in ["s4", "u4", "b1"]:
    if op.a.ptx_type == "b1" and satf:
      return False
    return layout_a == "row" and layout_b == "col"
  return True

def is_ldst_variant_supported(frag, layout):
  if not (is_type_supported(frag.ptx_type)
          and is_geom_supported(frag.geom)):
    return False
  if frag.ptx_type in ["s4", "u4", "b1"]:
    # sub-integer ops require sm_75 and ptx63, and row/col layout for a/b.
    return ((frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"])
  return True

def get_builtin_prefix(frag):
  prefix = None
  if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
    if frag.ptx_type in ["f16", "f32"]:
      prefix = "__hmma"
    else:
      prefix = "__imma"
  elif frag.geom == "m8n8k32":
    prefix = "__imma" # sub-integers
  elif frag.geom == "m8n8k128":
    prefix = "__bmma"
  assert prefix
  return prefix

def get_ldst_builtin_name(frag):
  prefix = get_builtin_prefix(frag)

  if prefix == "__hmma":
    suffix = "" if frag.frag in ["a","b"] else frag.ptx_type
  elif prefix in ["__imma", "__bmma"]:
    suffix = "" if frag.frag in ["c"] else frag.ptx_type
    if suffix == "s32":
      suffix = "i32"
  if frag.frag == "d":
    ifrag = "c"
    op = "st"
  else:
    ifrag = frag.frag
    op = "ld"

  name = "%s_%s_%s_%s%s" % (prefix, frag.geom, op, ifrag,
                            "_" + suffix if suffix else "")
  return name
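# E.g. this mapping produces __hmma_m16n16k16_ld_a for the (m16n16k16, "a",
# f16) fragment and __imma_m16n16k16_st_c_i32 for the (m16n16k16, "d", s32)
# fragment.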

def get_mma_builtin_name(op):
  prefix = get_builtin_prefix(op.a)

  if prefix == "__hmma":
    suffix = op.d.ptx_type + op.c.ptx_type
  else:
    suffix = op.a.ptx_type

  name = "%s_%s_mma%s_%s" % (prefix, op.a.geom,
                             "_xor_popc" if op.a.ptx_type == "b1" else "",
                             suffix)
  return name
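# E.g. an f16 op accumulating into f32 and producing an f32 result maps to
# __hmma_m16n16k16_mma_f32f32, and a b1 op maps to
# __bmma_m8n8k128_mma_xor_popc_b1.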


def get_required_sm(frag):
  if frag.ptx_type in ["u4", "s4", "b1"]:
    return 75
  if frag.ptx_type in ["s8", "u8"]:
    return 72
  if frag.ptx_type == "s32":
    if frag.geom in ["m8n8k32", "m8n8k128"]: # s4/u4/b1
      return 75
    else: # s8/u8
      return 72
  if frag.ptx_type in ["f16", "f32"]:
    return 70
  assert(False)

def get_required_ptx(frag):
  if frag.ptx_type in ["f16", "f32"]:
    return 60 if frag.geom == "m16n16k16" else 61
  return 63

def gen_wmma_ldst_tests(results):
  load_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${src}, ldm, ${blayout});
""".rstrip()
  intrinsic_template = "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"

  for frag, layout in sorted(product(get_ldst_ops(), ["row","col"]), key=str):

    if not is_ldst_variant_supported(frag, layout):
      continue

    is_fp = frag.ptx_type == "f32"
    min_sm = get_required_sm(frag)
    min_ptx = get_required_ptx(frag)
    params = {
        "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
        "builtin" : get_ldst_builtin_name(frag),
        "min_ptx" : min_ptx,
        "min_sm" : min_sm,
        "dst": "fdst" if is_fp else "dst",
        "src": "fsrc" if is_fp else "src",
        "blayout" : 0 if layout == "row" else 1,
        "intrinsic" : Template(intrinsic_template).substitute({
            "frag" : frag.frag,
            "geom" : frag.geom,
            "ilayout" : layout,
            "itype" : frag.ptx_type,
            "op" : "store" if frag.frag == "d" else "load",
        })
    }
    results[(min_ptx,min_sm)] += Template(load_template).substitute(params)

  return results
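# A sketch of one instantiation of load_template (the f16 "a" fragment of
# m16n16k16 with column-major layout):
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.load.a.col.stride.f16
#   // expected-error-re@+1 {{'__hmma_m16n16k16_ld_a' needs target feature sm_70{{.*}},ptx60{{.*}}}}
#   __hmma_m16n16k16_ld_a(dst, src, ldm, 1);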

def mma_signature(op):
  if op.a.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
    # int and sub-int ops are identified by input type.
    return op.a.ptx_type
  else:
    # the rest are FP ops identified by accumulator & result type.
    return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)

# Get numeric value for rowcol parameter of the builtin.
# AFAICT it uses the encoding accepted by NVVM intrinsics:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma
def get_ilayout(a, b):
  return {
      "row.row" : 0,
      "row.col" : 1,
      "col.row" : 2,
      "col.col" : 3
  }[a + "." + b]

def gen_wmma_mma_tests(results):
  mma_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_isatf});
""".rstrip()
  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"

  for op, alayout, blayout, satf in sorted(product(get_mma_ops(),
                                                   ["row","col"],
                                                   ["row","col"],
                                                   [".satfinite", ""]),
                                           key=str):

    if not is_mma_variant_supported(op, alayout, blayout, satf):
      continue

    a_is_fp = op.a.ptx_type == "f32"
    c_is_fp = op.c.ptx_type == "f32"
    d_is_fp = op.d.ptx_type == "f32"
    min_sm = get_required_sm(op.a)
    min_ptx = get_required_ptx(op.a)
    if op.a.ptx_type == "b1": # .b1 MMA has no satf argument.
      isatf_arg = ""
    else:
      isatf_arg = ", 1" if satf else ", 0"
    params = {
        "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
        "builtin" : get_mma_builtin_name(op),
        "min_ptx" : min_ptx,
        "min_sm" : min_sm,
        "dst": "fdst" if d_is_fp else "dst",
        "asrc": "fsrc" if a_is_fp else "src",
        "csrc": "fsrc" if c_is_fp else "src",
        "ilayout" : get_ilayout(alayout, blayout),
        "maybe_isatf" : isatf_arg,
        "intrinsic" : Template(intrinsic_template).substitute({
            "geom" : op.a.geom,
            "alayout" : alayout,
            "blayout" : blayout,
            "intrinsic_signature" : mma_signature(op),
            "satf" : satf,
        })
    }
    results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

  return results
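# A sketch of one instantiation of mma_template (f16 inputs with f16
# accumulator and result, row/col layouts, satf disabled):
#   // CHECK_PTX60_SM70: call {{.*}} @llvm.nvvm.wmma.m16n16k16.mma.row.col.f16.f16
#   // expected-error-re@+1 {{'__hmma_m16n16k16_mma_f16f16' needs target feature sm_70{{.*}},ptx60{{.*}}}}
#   __hmma_m16n16k16_mma_f16f16(dst, src, src, src, 1, 0);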

def gen_tests():
  results = gen_wmma_ldst_tests(defaultdict(str))
  results = gen_wmma_mma_tests(results)

  run_template = r"""
//
// *** DO NOT EDIT ***
//
// This test has been automatically generated by
// builtins-nvptx-mma.py --ptx=${ptx} --gpu-arch=${sm}
//
// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx}
// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \
// ${run}:            -fcuda-is-device -target-feature +ptx${ptx} \
// ${run}:            -DPTX=${ptx} -DSM=${sm} \
// ${run}:            -S -emit-llvm -o - -x cuda %s \
// ${run}:            | FileCheck -check-prefixes=${check_labels} %s
// Verify that all builtins have correct constraints.
// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \
// ${run}:            -target-cpu sm_60 -target-feature +ptx42 \
// ${run}:            -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \
// ${run}:            -verify %s
"""
  def supported_variants(ptx, sm, results):
    return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm]

  print(Template(run_template).substitute({
      "run" : "RUN", # To avoid lit misinterpreting the template
      "ptx" : ptx_version,
      "sm" : gpu_arch,
      "check_labels" : ",".join(["CHECK_PTX%d_SM%d" % (ptx_, sm_)
                                 for ptx_, sm_
                                 in supported_variants(ptx_version, gpu_arch,
                                                       results)])
  }))

  print("""
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
#define __shared__ __attribute__((shared))
#define __constant__ __attribute__((constant))

typedef unsigned long long uint64_t;
#endif

// CHECK-LABEL: test_wmma_builtins
__device__ void test_wmma_builtins(int *src, int *dst,
                                   float *fsrc, float *fdst, int ldm) {
""")

  for (ptx, sm), tests in sorted(results.items()):
    print()
    print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm))
    print(tests)
    print("#endif // (PTX >= %d) && (SM >= %d)" % (ptx, sm))

  print("}")

parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()