[NVPTX] Make tensor shape part of WMMA intrinsic's name.
This is needed for the upcoming implementation of the
new 8x32x16 and 32x8x16 variants of WMMA instructions
introduced in CUDA 9.1.
Differential Revision: https://reviews.llvm.org/D44719
llvm-svn: 328158
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index 996e5e7..d3ea1f2 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10515,23 +10515,23 @@
unsigned NumResults;
switch (BuiltinID) {
case NVPTX::BI__hmma_m16n16k16_ld_a:
- IID = isColMajor ? Intrinsic::nvvm_wmma_load_a_f16_col_stride
- : Intrinsic::nvvm_wmma_load_a_f16_row_stride;
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_col_stride
+ : Intrinsic::nvvm_wmma_m16n16k16_load_a_f16_row_stride;
NumResults = 8;
break;
case NVPTX::BI__hmma_m16n16k16_ld_b:
- IID = isColMajor ? Intrinsic::nvvm_wmma_load_b_f16_col_stride
- : Intrinsic::nvvm_wmma_load_b_f16_row_stride;
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_col_stride
+ : Intrinsic::nvvm_wmma_m16n16k16_load_b_f16_row_stride;
NumResults = 8;
break;
case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f16_col_stride
- : Intrinsic::nvvm_wmma_load_c_f16_row_stride;
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_col_stride
+ : Intrinsic::nvvm_wmma_m16n16k16_load_c_f16_row_stride;
NumResults = 4;
break;
case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f32_col_stride
- : Intrinsic::nvvm_wmma_load_c_f32_row_stride;
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_col_stride
+ : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
NumResults = 8;
break;
default:
@@ -10566,13 +10566,13 @@
// for some reason nvcc builtins use _c_.
switch (BuiltinID) {
case NVPTX::BI__hmma_m16n16k16_st_c_f16:
- IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f16_col_stride
- : Intrinsic::nvvm_wmma_store_d_f16_row_stride;
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_col_stride
+ : Intrinsic::nvvm_wmma_m16n16k16_store_d_f16_row_stride;
NumResults = 4;
break;
case NVPTX::BI__hmma_m16n16k16_st_c_f32:
- IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f32_col_stride
- : Intrinsic::nvvm_wmma_store_d_f32_row_stride;
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
+ : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
break;
default:
llvm_unreachable("Unexpected builtin ID.");
@@ -10591,8 +10591,8 @@
return Result;
}
- // BI__hmma_m16n16k16_mma_<Dtype><CType>(d, a, b, c, layout, satf)
- // --> Intrinsic::nvvm_wmma_mma_sync<layout A,B><DType><CType><Satf>
+ // BI__hmma_m16n16k16_mma_<Dtype><CType>(d, a, b, c, layout, satf) -->
+ // Intrinsic::nvvm_wmma_m16n16k16_mma<layout A,B><DType><CType><Satf>
case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
@@ -10613,15 +10613,15 @@
bool Satf = SatfArg.getSExtValue();
// clang-format off
-#define MMA_VARIANTS(type) {{ \
- Intrinsic::nvvm_wmma_mma_sync_row_row_##type, \
- Intrinsic::nvvm_wmma_mma_sync_row_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_mma_sync_row_col_##type, \
- Intrinsic::nvvm_wmma_mma_sync_row_col_##type##_satfinite, \
- Intrinsic::nvvm_wmma_mma_sync_col_row_##type, \
- Intrinsic::nvvm_wmma_mma_sync_col_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_mma_sync_col_col_##type, \
- Intrinsic::nvvm_wmma_mma_sync_col_col_##type##_satfinite \
+#define MMA_VARIANTS(type) {{ \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type, \
+ Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite \
}}
// clang-format on