# This script generates all variants of wmma builtins, verifies that clang calls
# correct LLVM intrinsics, and checks that availability of specific builtins is
# constrained by the correct PTX version and the target GPU variant.

# Dummy test run to avoid lit warnings.
# RUN: echo "This is not a real test. It's a generator for builtins-nvptx-mma.cu" >/dev/null

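# Example invocation (an illustrative sketch; it assumes this script is named
# builtins-nvptx-mma.py and that its stdout is redirected into the generated
# .cu test mentioned above):
#   python builtins-nvptx-mma.py --ptx=63 --gpu-arch=75 > builtins-nvptx-mma.cu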
from __future__ import print_function

import argparse
from collections import defaultdict
from itertools import product
from string import Template

class MMAFrag:
  def __init__(self, geom, frag, ptx_elt_type):
    self.geom = geom
    self.frag = frag
    self.ptx_type = ptx_elt_type

  def __repr__(self):
    return "%s:%s:%s" % (self.geom, self.frag, self.ptx_type)

class MMAOp:
  def __init__(self, a, b, c, d):
    self.a = a
    self.b = b
    self.c = c
    self.d = d

  def __repr__(self):
    return ("{A:%s, B:%s, C:%s, D:%s}" % (self.a, self.b, self.c, self.d))

def make_mma_ops(geoms, types_a, types_b, types_c, types_d):
  ops = []
  for geom, type_a, type_c in product(geoms, types_a, types_c):
    for type_b, type_d in product(types_b if types_b else [type_a],
                                  types_d if types_d else [type_c]):
      ops.append(MMAOp(MMAFrag(geom, "a", type_a),
                       MMAFrag(geom, "b", type_b),
                       MMAFrag(geom, "c", type_c),
                       MMAFrag(geom, "d", type_d)))
  return ops

def make_ldst_ops(geoms, frags, types):
  return [MMAFrag(geom, frag, ptx_type) for (geom, frag, ptx_type)
          in product(geoms, frags, types)]

def get_mma_ops():
  return (make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["f16"], [], ["f16", "f32"], ["f16", "f32"]) +
          make_mma_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                       ["s8", "u8"], [], ["s32"], []) +
          make_mma_ops(["m8n8k32"],
                       ["s4", "u4"], [], ["s32"], []) +
          make_mma_ops(["m8n8k128"],
                       ["b1"], [], ["s32"], []))

def get_ldst_ops():
  return (make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["a", "b"], ["f16", "u8", "s8"]) +
          make_ldst_ops(["m16n16k16", "m32n8k16", "m8n32k16"],
                        ["c", "d"], ["f16", "f32", "s32"]) +
          make_ldst_ops(["m8n8k32"], ["a", "b"], ["s4", "u4"]) +
          make_ldst_ops(["m8n8k128"], ["a", "b"], ["b1"]) +
          make_ldst_ops(["m8n8k32", "m8n8k128"], ["c", "d"], ["s32"]))

def is_geom_supported(geom):
  # Geometries for FP and ints.
  if geom in ["m8n32k16", "m32n8k16"]:
    return ptx_version >= 61
  # Geometries for sub-ints.
  if geom in ["m8n8k32", "m8n8k128"]:
    return ptx_version >= 63 and gpu_arch >= 75
  if geom == "m16n16k16":
    return ptx_version >= 60
  assert False  # Unexpected geometry.

def is_type_supported(ptx_type):
  if ptx_type in ["s8", "u8", "s32"]:
    return ptx_version >= 63 and gpu_arch >= 72
  if ptx_type in ["s4", "u4", "b1"]:
    return ptx_version >= 63 and gpu_arch >= 75
  return ptx_version >= 60 and gpu_arch >= 70

def is_mma_variant_supported(op, layout_a, layout_b, satf):
  if not (is_type_supported(op.a.ptx_type)
          and is_geom_supported(op.a.geom)):
    return False
  # Sub-integer types require row/col layout, and b1 has no satf variant.
  if op.a.ptx_type in ["s4", "u4", "b1"]:
    if op.a.ptx_type == "b1" and satf:
      return False
    return layout_a == "row" and layout_b == "col"
  return True

def is_ldst_variant_supported(frag, layout):
  if not (is_type_supported(frag.ptx_type)
          and is_geom_supported(frag.geom)):
    return False
  if frag.ptx_type in ["s4", "u4", "b1"]:
    # Sub-integer types require sm_75 and ptx63, with row layout for "a"
    # fragments and col layout for "b" fragments.
    return ((frag.frag == "a" and layout == "row")
            or (frag.frag == "b" and layout == "col")
            or frag.frag in ["c", "d"])
  return True

def get_builtin_prefix(frag):
  prefix = None
  if frag.geom in ["m16n16k16", "m32n8k16", "m8n32k16"]:
    if frag.ptx_type in ["f16", "f32"]:
      prefix = "__hmma"
    else:
      prefix = "__imma"
  elif frag.geom == "m8n8k32":
    prefix = "__imma"  # sub-integers
  elif frag.geom == "m8n8k128":
    prefix = "__bmma"
  assert prefix
  return prefix

def get_ldst_builtin_name(frag):
  prefix = get_builtin_prefix(frag)

  if prefix == "__hmma":
    suffix = "" if frag.frag in ["a", "b"] else frag.ptx_type
  elif prefix in ["__imma", "__bmma"]:
    suffix = "" if frag.frag in ["c"] else frag.ptx_type
    if suffix == "s32":
      suffix = "i32"
  if frag.frag == "d":
    ifrag = "c"
    op = "st"
  else:
    ifrag = frag.frag
    op = "ld"

  name = "%s_%s_%s_%s%s" % (prefix, frag.geom, op, ifrag,
                            "_" + suffix if suffix else "")
  return name
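# Illustrative examples of names produced above: an f16 "a" fragment load
# becomes "__hmma_m16n16k16_ld_a"; an s32 "d" fragment store becomes
# "__imma_m16n16k16_st_c_i32" ("d" stores reuse the "c" fragment name).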

def get_mma_builtin_name(op):
  prefix = get_builtin_prefix(op.a)

  if prefix == "__hmma":
    suffix = op.d.ptx_type + op.c.ptx_type
  else:
    suffix = op.a.ptx_type

  name = "%s_%s_mma%s_%s" % (prefix, op.a.geom,
                             "_xor_popc" if op.a.ptx_type == "b1" else "",
                             suffix)
  return name
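# Illustrative examples of names produced above: an f16 MMA with f32 D and
# f16 C fragments becomes "__hmma_m16n16k16_mma_f32f16"; the b1 variant
# becomes "__bmma_m8n8k128_mma_xor_popc_b1".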

def get_required_sm(frag):
  if frag.ptx_type in ["u4", "s4", "b1"]:
    return 75
  if frag.ptx_type in ["s8", "u8"]:
    return 72
  if frag.ptx_type == "s32":
    if frag.geom in ["m8n8k32", "m8n8k128"]:  # s4/u4/b1
      return 75
    else:  # s8/u8
      return 72
  if frag.ptx_type in ["f16", "f32"]:
    return 70
  assert False

def get_required_ptx(frag):
  if frag.ptx_type in ["f16", "f32"]:
    return 60 if frag.geom == "m16n16k16" else 61
  return 63

def gen_wmma_ldst_tests(results):
  load_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${src}, ldm, ${blayout});
""".rstrip()
  intrinsic_template = "llvm.nvvm.wmma.${geom}.${op}.${frag}.${ilayout}.stride.${itype}"

  for frag, layout in sorted(product(get_ldst_ops(), ["row", "col"]), key=str):

    if not is_ldst_variant_supported(frag, layout):
      continue

    is_fp = frag.ptx_type == "f32"
    min_sm = get_required_sm(frag)
    min_ptx = get_required_ptx(frag)
    params = {
        "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
        "builtin" : get_ldst_builtin_name(frag),
        "min_ptx" : min_ptx,
        "min_sm" : min_sm,
        "dst": "fdst" if is_fp else "dst",
        "src": "fsrc" if is_fp else "src",
        "blayout" : 0 if layout == "row" else 1,
        "intrinsic" : Template(intrinsic_template).substitute({
            "frag" : frag.frag,
            "geom" : frag.geom,
            "ilayout" : layout,
            "itype" : frag.ptx_type,
            "op" : "store" if frag.frag == "d" else "load",
        })
    }
    results[(min_ptx, min_sm)] += Template(load_template).substitute(params)

  return results

def mma_signature(op):
  if op.a.ptx_type in ["s8", "u8", "s4", "u4", "b1"]:
    # Int and sub-int ops are identified by input type.
    return op.a.ptx_type
  else:
    # The rest are FP ops identified by accumulator & result type.
    return "%s.%s" % (op.d.ptx_type, op.c.ptx_type)
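# Illustrative examples of the signatures produced above: an s8 integer op
# yields "s8"; an FP op with f32 D and f16 C fragments yields "f32.f16".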

# Get numeric value for the rowcol parameter of the builtin.
# AFAICT it uses the encoding accepted by NVVM intrinsics:
# https://docs.nvidia.com/cuda/nvvm-ir-spec/index.html#nvvm-intrin-warp-level-matrix-mma
def get_ilayout(a, b):
  return {
      "row.row" : 0,
      "row.col" : 1,
      "col.row" : 2,
      "col.col" : 3
  }[a + "." + b]
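# For example, get_ilayout("row", "col") returns 1, per the encoding above.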

def gen_wmma_mma_tests(results):
  mma_template = """
  // CHECK${check_suffix}: call {{.*}} @${intrinsic}
  // expected-error-re@+1 {{'${builtin}' needs target feature sm_${min_sm}{{.*}},ptx${min_ptx}{{.*}}}}
  ${builtin}(${dst}, ${asrc}, ${asrc}, ${csrc}, ${ilayout}${maybe_isatf});
""".rstrip()
  intrinsic_template = "llvm.nvvm.wmma.${geom}.mma.${alayout}.${blayout}.${intrinsic_signature}${satf}"

  for op, alayout, blayout, satf in sorted(product(get_mma_ops(),
                                                   ["row", "col"],
                                                   ["row", "col"],
                                                   [".satfinite", ""]),
                                           key=str):

    if not is_mma_variant_supported(op, alayout, blayout, satf):
      continue

    a_is_fp = op.a.ptx_type == "f32"
    c_is_fp = op.c.ptx_type == "f32"
    d_is_fp = op.d.ptx_type == "f32"
    min_sm = get_required_sm(op.a)
    min_ptx = get_required_ptx(op.a)
    if op.a.ptx_type == "b1":  # .b1 MMA has no satf argument.
      isatf_arg = ""
    else:
      isatf_arg = ", 1" if satf else ", 0"
    params = {
        "check_suffix" : "_PTX%d_SM%d" % (min_ptx, min_sm),
        "builtin" : get_mma_builtin_name(op),
        "min_ptx" : min_ptx,
        "min_sm" : min_sm,
        "dst": "fdst" if d_is_fp else "dst",
        "asrc": "fsrc" if a_is_fp else "src",
        "csrc": "fsrc" if c_is_fp else "src",
        "ilayout" : get_ilayout(alayout, blayout),
        "maybe_isatf" : isatf_arg,
        "intrinsic" : Template(intrinsic_template).substitute({
            "geom" : op.a.geom,
            "alayout" : alayout,
            "blayout" : blayout,
            "intrinsic_signature" : mma_signature(op),
            "satf" : satf,
        })
    }
    results[(min_ptx, min_sm)] += Template(mma_template).substitute(params)

  return results

def gen_tests():
  results = gen_wmma_ldst_tests(defaultdict(str))
  results = gen_wmma_mma_tests(results)

  run_template = r"""
//
// *** DO NOT EDIT ***
//
// This test has been automatically generated by
// builtins-nvptx-mma.py --ptx=${ptx} --gpu-arch=${sm}
//
// Make sure we can handle all builtins available on sm_${sm} with PTX${ptx}
// ${run}: %clang_cc1 -triple nvptx64-unknown-unknown -target-cpu sm_${sm} \
// ${run}: -fcuda-is-device -target-feature +ptx${ptx} \
// ${run}: -DPTX=${ptx} -DSM=${sm} \
// ${run}: -S -emit-llvm -o - -x cuda %s \
// ${run}: | FileCheck -check-prefixes=${check_labels} %s
// Verify that all builtins have correct constraints.
// ${run}: %clang_cc1 -triple nvptx-unknown-unknown \
// ${run}: -target-cpu sm_60 -target-feature +ptx42 \
// ${run}: -DPTX=${ptx} -DSM=${sm} -fcuda-is-device -S -o /dev/null -x cuda \
// ${run}: -verify %s
"""
  def supported_variants(ptx, sm, results):
    return [(ptx_, sm_) for ptx_, sm_ in results if ptx_ <= ptx and sm_ <= sm]

  print(Template(run_template).substitute({
      "run" : "RUN",  # To avoid lit misinterpreting the template.
      "ptx" : ptx_version,
      "sm" : gpu_arch,
      "check_labels" : ",".join(["CHECK_PTX%d_SM%d" % (ptx_, sm_)
                                 for ptx_, sm_
                                 in supported_variants(ptx_version, gpu_arch,
                                                       results)])
  }))

  print("""
#if !defined(CUDA_VERSION)
#define __device__ __attribute__((device))
#define __global__ __attribute__((global))
#define __shared__ __attribute__((shared))
#define __constant__ __attribute__((constant))

typedef unsigned long long uint64_t;
#endif

// CHECK-LABEL: test_wmma_builtins
__device__ void test_wmma_builtins(int *src, int *dst,
                                   float *fsrc, float *fdst, int ldm) {
""")

  for (ptx, sm), tests in sorted(results.items()):
    print()
    print("#if (PTX >= %d) && (SM >= %d)" % (ptx, sm))
    print(tests)
    print("#endif // (PTX >= %d) && (SM >= %d)" % (ptx, sm))

  print("}")

parser = argparse.ArgumentParser()
parser.add_argument("--ptx", type=int, default=60)
parser.add_argument("--gpu-arch", type=int, default=70)
args = parser.parse_args()
ptx_version = args.ptx
gpu_arch = args.gpu_arch

gen_tests()