[NVPTX, CUDA] Added support for m8n32k16 and m32n8k16 variants of wmma instructions.

The new instructions were added for sm_70+ GPUs in CUDA-9.1.

Differential Revision: https://reviews.llvm.org/D45068

llvm-svn: 330296
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index fffc242..6a2f2b0 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -10715,7 +10715,15 @@
   case NVPTX::BI__hmma_m16n16k16_ld_a:
   case NVPTX::BI__hmma_m16n16k16_ld_b:
   case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
-  case NVPTX::BI__hmma_m16n16k16_ld_c_f32: {
+  case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
+  case NVPTX::BI__hmma_m32n8k16_ld_a:
+  case NVPTX::BI__hmma_m32n8k16_ld_b:
+  case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+  case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+  case NVPTX::BI__hmma_m8n32k16_ld_a:
+  case NVPTX::BI__hmma_m8n32k16_ld_b:
+  case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+  case NVPTX::BI__hmma_m8n32k16_ld_c_f32: {
     Address Dst = EmitPointerWithAlignment(E->getArg(0));
     Value *Src = EmitScalarExpr(E->getArg(1));
     Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -10746,6 +10754,46 @@
                        : Intrinsic::nvvm_wmma_m16n16k16_load_c_f32_row_stride;
       NumResults = 8;
       break;
+    case NVPTX::BI__hmma_m32n8k16_ld_a:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m32n8k16_load_a_f16_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_ld_b:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m32n8k16_load_b_f16_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_ld_c_f16:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m32n8k16_load_c_f16_row_stride;
+      NumResults = 4;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_ld_c_f32:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m32n8k16_load_c_f32_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_ld_a:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m8n32k16_load_a_f16_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_ld_b:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m8n32k16_load_b_f16_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_ld_c_f16:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m8n32k16_load_c_f16_row_stride;
+      NumResults = 4;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_ld_c_f32:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m8n32k16_load_c_f32_row_stride;
+      NumResults = 8;
+      break;
     default:
       llvm_unreachable("Unexpected builtin ID.");
     }
@@ -10764,7 +10812,11 @@
   }
 
   case NVPTX::BI__hmma_m16n16k16_st_c_f16:
-  case NVPTX::BI__hmma_m16n16k16_st_c_f32: {
+  case NVPTX::BI__hmma_m16n16k16_st_c_f32:
+  case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+  case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+  case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+  case NVPTX::BI__hmma_m8n32k16_st_c_f32: {
     Value *Dst = EmitScalarExpr(E->getArg(0));
     Address Src = EmitPointerWithAlignment(E->getArg(1));
     Value *Ldm = EmitScalarExpr(E->getArg(2));
@@ -10786,6 +10838,24 @@
       IID = isColMajor ? Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_col_stride
                        : Intrinsic::nvvm_wmma_m16n16k16_store_d_f32_row_stride;
       break;
+    case NVPTX::BI__hmma_m32n8k16_st_c_f16:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m32n8k16_store_d_f16_row_stride;
+      NumResults = 4;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_st_c_f32:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m32n8k16_store_d_f32_row_stride;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_st_c_f16:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_col_stride
+                       : Intrinsic::nvvm_wmma_m8n32k16_store_d_f16_row_stride;
+      NumResults = 4;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_st_c_f32:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_col_stride
+                       : Intrinsic::nvvm_wmma_m8n32k16_store_d_f32_row_stride;
+      break;
     default:
       llvm_unreachable("Unexpected builtin ID.");
     }
@@ -10808,7 +10878,15 @@
   case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
   case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
   case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
-  case NVPTX::BI__hmma_m16n16k16_mma_f16f32: {
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+  case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+  case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+  case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+  case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+  case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+  case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+  case NVPTX::BI__hmma_m8n32k16_mma_f16f32: {
     Address Dst = EmitPointerWithAlignment(E->getArg(0));
     Address SrcA = EmitPointerWithAlignment(E->getArg(1));
     Address SrcB = EmitPointerWithAlignment(E->getArg(2));
@@ -10825,15 +10903,15 @@
     bool Satf = SatfArg.getSExtValue();
 
     // clang-format off
-#define MMA_VARIANTS(type) {{                                        \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type,             \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_row_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type,             \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_row_col_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type,             \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_col_row_##type##_satfinite, \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type,             \
-      Intrinsic::nvvm_wmma_m16n16k16_mma_col_col_##type##_satfinite  \
+#define MMA_VARIANTS(geom, type) {{                                 \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_row_col_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type,             \
+      Intrinsic::nvvm_wmma_##geom##_mma_col_col_##type##_satfinite  \
     }}
     // clang-format on
 
@@ -10847,22 +10925,62 @@
     unsigned NumEltsD;
     switch (BuiltinID) {
     case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16));
+      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f16));
       NumEltsC = 4;
       NumEltsD = 4;
       break;
     case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
-      IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16));
+      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f16));
       NumEltsC = 4;
       NumEltsD = 8;
       break;
     case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32));
+      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f16_f32));
       NumEltsC = 8;
       NumEltsD = 4;
       break;
     case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
-      IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32));
+      IID = getMMAIntrinsic(MMA_VARIANTS(m16n16k16, f32_f32));
+      NumEltsC = 8;
+      NumEltsD = 8;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_mma_f16f16:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f16));
+      NumEltsC = 4;
+      NumEltsD = 4;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_mma_f32f16:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f16));
+      NumEltsC = 4;
+      NumEltsD = 8;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_mma_f16f32:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f16_f32));
+      NumEltsC = 8;
+      NumEltsD = 4;
+      break;
+    case NVPTX::BI__hmma_m32n8k16_mma_f32f32:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m32n8k16, f32_f32));
+      NumEltsC = 8;
+      NumEltsD = 8;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_mma_f16f16:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f16));
+      NumEltsC = 4;
+      NumEltsD = 4;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_mma_f32f16:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f16));
+      NumEltsC = 4;
+      NumEltsD = 8;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_mma_f16f32:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f16_f32));
+      NumEltsC = 8;
+      NumEltsD = 4;
+      break;
+    case NVPTX::BI__hmma_m8n32k16_mma_f32f32:
+      IID = getMMAIntrinsic(MMA_VARIANTS(m8n32k16, f32_f32));
       NumEltsC = 8;
       NumEltsD = 8;
       break;