[CUDA] Added __hmma_m16n16k16_* builtins to support mma instructions on sm_70

Differential Revision: https://reviews.llvm.org/D38742

llvm-svn: 315624
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index f25a70a..e8d2c59 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -9731,6 +9731,204 @@
     Builder.CreateStore(Pred, PredOutPtr);
     return Builder.CreateExtractValue(ResultPair, 0);
   }
+  case NVPTX::BI__hmma_m16n16k16_ld_a:
+  case NVPTX::BI__hmma_m16n16k16_ld_b:
+  case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
+  case NVPTX::BI__hmma_m16n16k16_ld_c_f32: {
+    Address Dst = EmitPointerWithAlignment(E->getArg(0));
+    Value *Src = EmitScalarExpr(E->getArg(1));
+    Value *Ldm = EmitScalarExpr(E->getArg(2));
+    llvm::APSInt isColMajorArg;
+    if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext()))
+      return nullptr; // The layout flag must be an integer constant expression.
+    bool isColMajor = isColMajorArg.getSExtValue();
+    unsigned IID;
+    unsigned NumResults; // Number of elements extracted from the call result.
+    switch (BuiltinID) {
+    case NVPTX::BI__hmma_m16n16k16_ld_a:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_load_a_f16_col_stride
+                       : Intrinsic::nvvm_wmma_load_a_f16_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_ld_b:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_load_b_f16_col_stride
+                       : Intrinsic::nvvm_wmma_load_b_f16_row_stride;
+      NumResults = 8;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_ld_c_f16:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f16_col_stride
+                       : Intrinsic::nvvm_wmma_load_c_f16_row_stride;
+      NumResults = 4;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_ld_c_f32:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_load_c_f32_col_stride
+                       : Intrinsic::nvvm_wmma_load_c_f32_row_stride;
+      NumResults = 8;
+      break;
+    default:
+      llvm_unreachable("Unexpected builtin ID.");
+    }
+    Value *Result =
+        Builder.CreateCall(CGM.getIntrinsic(IID),
+                           {Builder.CreatePointerCast(Src, VoidPtrTy), Ldm});
+
+    // Store element i of the call result into Dst[i], for each of NumResults.
+    for (unsigned i = 0; i < NumResults; ++i) {
+      Builder.CreateAlignedStore(
+          Builder.CreateBitCast(Builder.CreateExtractValue(Result, i),
+                                Dst.getElementType()),
+          Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)),
+          CharUnits::fromQuantity(4));
+    }
+    return Result;
+  }
+
+  case NVPTX::BI__hmma_m16n16k16_st_c_f16:
+  case NVPTX::BI__hmma_m16n16k16_st_c_f32: {
+    Value *Dst = EmitScalarExpr(E->getArg(0));
+    Address Src = EmitPointerWithAlignment(E->getArg(1));
+    Value *Ldm = EmitScalarExpr(E->getArg(2));
+    llvm::APSInt isColMajorArg;
+    if (!E->getArg(3)->isIntegerConstantExpr(isColMajorArg, getContext()))
+      return nullptr; // The layout flag must be an integer constant expression.
+    bool isColMajor = isColMajorArg.getSExtValue();
+    unsigned IID;
+    unsigned NumResults = 8; // Used by the f32 case; the f16 case overrides to 4.
+    // PTX instructions (and LLVM intrinsics) are defined for slice _d_, yet
+    // for some reason nvcc builtins use _c_.
+    switch (BuiltinID) {
+    case NVPTX::BI__hmma_m16n16k16_st_c_f16:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f16_col_stride
+                       : Intrinsic::nvvm_wmma_store_d_f16_row_stride;
+      NumResults = 4;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_st_c_f32:
+      IID = isColMajor ? Intrinsic::nvvm_wmma_store_d_f32_col_stride
+                       : Intrinsic::nvvm_wmma_store_d_f32_row_stride;
+      break;
+    default:
+      llvm_unreachable("Unexpected builtin ID.");
+    }
+    Function *Intrinsic = CGM.getIntrinsic(IID);
+    llvm::Type *ParamType = Intrinsic->getFunctionType()->getParamType(1); // Operand 0 is the pointer.
+    SmallVector<Value *, 10> Values;
+    Values.push_back(Builder.CreatePointerCast(Dst, VoidPtrTy));
+    for (unsigned i = 0; i < NumResults; ++i) {
+      Value *V = Builder.CreateAlignedLoad(
+          Builder.CreateGEP(Src.getPointer(), llvm::ConstantInt::get(IntTy, i)),
+          CharUnits::fromQuantity(4));
+      Values.push_back(Builder.CreateBitCast(V, ParamType));
+    }
+    Values.push_back(Ldm);
+    Value *Result = Builder.CreateCall(Intrinsic, Values);
+    return Result;
+  }
+
+  // BI__hmma_m16n16k16_mma_<Dtype><CType>(d, a, b, c, layout, satf)
+  //  --> Intrinsic::nvvm_wmma_mma_sync<layout A,B><DType><CType><Satf>
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
+  case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
+  case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
+  case NVPTX::BI__hmma_m16n16k16_mma_f16f32: {
+    Address Dst = EmitPointerWithAlignment(E->getArg(0));
+    Address SrcA = EmitPointerWithAlignment(E->getArg(1));
+    Address SrcB = EmitPointerWithAlignment(E->getArg(2));
+    Address SrcC = EmitPointerWithAlignment(E->getArg(3));
+    llvm::APSInt LayoutArg;
+    if (!E->getArg(4)->isIntegerConstantExpr(LayoutArg, getContext()))
+      return nullptr;
+    int Layout = LayoutArg.getSExtValue();
+    if (Layout < 0 || Layout > 3)
+      return nullptr;
+    llvm::APSInt SatfArg;
+    if (!E->getArg(5)->isIntegerConstantExpr(SatfArg, getContext()))
+      return nullptr;
+    bool Satf = SatfArg.getSExtValue();
+
+    // clang-format off
+#define MMA_VARIANTS(type) {{                                   \
+      Intrinsic::nvvm_wmma_mma_sync_row_row_##type,             \
+      Intrinsic::nvvm_wmma_mma_sync_row_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_mma_sync_row_col_##type,             \
+      Intrinsic::nvvm_wmma_mma_sync_row_col_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_mma_sync_col_row_##type,             \
+      Intrinsic::nvvm_wmma_mma_sync_col_row_##type##_satfinite, \
+      Intrinsic::nvvm_wmma_mma_sync_col_col_##type,             \
+      Intrinsic::nvvm_wmma_mma_sync_col_col_##type##_satfinite  \
+    }}
+    // clang-format on
+
+    auto getMMAIntrinsic = [Layout, Satf](std::array<unsigned, 8> Variants) {
+      unsigned Index = Layout * 2 + Satf; // Matches the MMA_VARIANTS entry order.
+      assert(Index < 8);
+      return Variants[Index];
+    };
+    unsigned IID;
+    unsigned NumEltsC;
+    unsigned NumEltsD;
+    switch (BuiltinID) {
+    case NVPTX::BI__hmma_m16n16k16_mma_f16f16:
+      IID = getMMAIntrinsic(MMA_VARIANTS(f16_f16));
+      NumEltsC = 4;
+      NumEltsD = 4;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_mma_f32f16:
+      IID = getMMAIntrinsic(MMA_VARIANTS(f32_f16));
+      NumEltsC = 4;
+      NumEltsD = 8;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_mma_f16f32:
+      IID = getMMAIntrinsic(MMA_VARIANTS(f16_f32));
+      NumEltsC = 8;
+      NumEltsD = 4;
+      break;
+    case NVPTX::BI__hmma_m16n16k16_mma_f32f32:
+      IID = getMMAIntrinsic(MMA_VARIANTS(f32_f32));
+      NumEltsC = 8;
+      NumEltsD = 8;
+      break;
+    default:
+      llvm_unreachable("Unexpected builtin ID.");
+    }
+#undef MMA_VARIANTS
+
+    SmallVector<Value *, 24> Values;
+    Function *Intrinsic = CGM.getIntrinsic(IID);
+    llvm::Type *ABType = Intrinsic->getFunctionType()->getParamType(0);
+    // Load the 8 elements of A, bitcasting each to the intrinsic's operand type.
+    for (unsigned i = 0; i < 8; ++i) {
+      Value *V = Builder.CreateAlignedLoad(
+          Builder.CreateGEP(SrcA.getPointer(),
+                            llvm::ConstantInt::get(IntTy, i)),
+          CharUnits::fromQuantity(4));
+      Values.push_back(Builder.CreateBitCast(V, ABType));
+    }
+    // Load the 8 elements of B; they share A's operand type (ABType).
+    for (unsigned i = 0; i < 8; ++i) {
+      Value *V = Builder.CreateAlignedLoad(
+          Builder.CreateGEP(SrcB.getPointer(),
+                            llvm::ConstantInt::get(IntTy, i)),
+          CharUnits::fromQuantity(4));
+      Values.push_back(Builder.CreateBitCast(V, ABType));
+    }
+    // Load C. Its operand type is parameter 16: the 8 A and 8 B values come first.
+    llvm::Type *CType = Intrinsic->getFunctionType()->getParamType(16);
+    for (unsigned i = 0; i < NumEltsC; ++i) {
+      Value *V = Builder.CreateAlignedLoad(
+          Builder.CreateGEP(SrcC.getPointer(),
+                            llvm::ConstantInt::get(IntTy, i)),
+          CharUnits::fromQuantity(4));
+      Values.push_back(Builder.CreateBitCast(V, CType));
+    }
+    Value *Result = Builder.CreateCall(Intrinsic, Values);
+    llvm::Type *DType = Dst.getElementType(); // Each result element is bitcast to this.
+    for (unsigned i = 0; i < NumEltsD; ++i)
+      Builder.CreateAlignedStore(
+          Builder.CreateBitCast(Builder.CreateExtractValue(Result, i), DType),
+          Builder.CreateGEP(Dst.getPointer(), llvm::ConstantInt::get(IntTy, i)),
+          CharUnits::fromQuantity(4));
+    return Result;
+  }
   default:
     return nullptr;
   }