[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.
These new instruction variants were introduced in CUDA 9.1 and are available on sm_70+ GPUs.
Differential Revision: https://reviews.llvm.org/D45068
llvm-svn: 330296
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index fffc242..6a2f2b0 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10715,7 +10715,15 @@
case NVPTX::BI__hmma_m16n16k16_ld_a:
case NVPTX::BI__hmma_m16n16k16_ld_b:
case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
- case NVPTX::BI__hmma_m16n16k16_ld_c_f32: {
+ case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
+ case NVPTX::BI__hmma_m32n8k16_ld_a:
+ case NVPTX::BI__hmma_m32n8k16_ld_b:
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+ case NVPTX::BI__hmma_m8n32k16_ld_a:
+ case NVPTX::BI__hmma_m8n32k16_ld_b:
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f32: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Value *Src = EmitScalarExpr(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -10746,6 +10754,46 @@
: Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
NumResults = 8;
break;
+ case NVPTX::BI__hmma_m32n8k16_ld_a:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_ld_b:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_a:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_b:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride;
+ NumResults = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride;
+ NumResults = 8;
+ break;
default:
llvm_unreachable("Unexpected builtin ID.");
}
@@ -10764,7 +10812,11 @@
}
case NVPTX::BI__hmma_m16n16k16_st_c_f16:
- case NVPTX::BI__hmma_m16n16k16_st_c_f32: {
+ case NVPTX::BI__hmma_m16n16k16_st_c_f32:
+ case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+ case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+ case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+ case NVPTX::BI__hmma_m8n32k16_st_c_f32: {
Value *Dst = EmitScalarExpr(E->getArg(0));
Address Src = EmitPointerWithAlignment(E->getArg(1));
Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -10786,6 +10838,24 @@
IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
: Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
break;
+ case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride
+ : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride;
+ NumResults = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+ IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride
+ : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride;
+ break;
default:
llvm_unreachable("Unexpected builtin ID.");
}
@@ -10808,7 +10878,15 @@
case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
- case NVPTX::BI__hmma_m16n16k16_mma_f16f32: {
+ case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32: {
Address Dst = EmitPointerWithAlignment(E->getArg(0));
Address SrcA = EmitPointerWithAlignment(E->getArg(1));
Address SrcB = EmitPointerWithAlignment(E->getArg(2));
@@ -10825,15 +10903,15 @@
bool Satf = SatfArg.getSExtValue();
// clang-format off
-#define MMA_VARIANTS(type) {{ \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type, \
- Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite \
+#define MMA_VARIANTS(geom, type) {{ \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type, \
+ Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite \
}}
// clang-format on
@@ -10847,22 +10925,62 @@
unsigned NumEltsD;
switch (BuiltinID) {
case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16));
NumEltsC = 4;
NumEltsD = 4;
break;
case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
- IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16));
NumEltsC = 4;
NumEltsD = 8;
break;
case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32));
NumEltsC = 8;
NumEltsD = 4;
break;
case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
- IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32));
+ IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32));
+ NumEltsC = 8;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16));
+ NumEltsC = 4;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16));
+ NumEltsC = 4;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32));
+ NumEltsC = 8;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32));
+ NumEltsC = 8;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16));
+ NumEltsC = 4;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16));
+ NumEltsC = 4;
+ NumEltsD = 8;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32));
+ NumEltsC = 8;
+ NumEltsD = 4;
+ break;
+ case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+ IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32));
NumEltsC = 8;
NumEltsD = 8;
break;