[X86] Add back _mask, _maskz, and _mask3 builtins for some 512-bit fmadd/fmsub/fmaddsub/fmsubadd builtins.

Summary:
We recently switched to using selects in the intrinsics header files for FMA instructions. But the 512-bit versions support flavors with a rounding mode, which must be an Integer Constant Expression. This has forced those intrinsics to be implemented as macros. As it stands now, the mask and mask3 intrinsics evaluate one of their macro arguments twice. If that argument itself is another intrinsic macro, we can end up over-expanding macros. Or if it's something we can CSE later, it would show up multiple times when it shouldn't.

I tried adding __extension__ around the macro and making it an expression statement and declaring a local variable. But whatever name you choose for the local variable can never be used as the name of an input to the macro in user code. If that happens you would end up with the same name on the LHS and RHS of an assignment after expansion. We might be safe if we use __ in front of the variable names because those names are reserved and user code shouldn't use that, but I wasn't sure I wanted to make that claim.

The other option, which I've chosen here, is to add back _mask, _maskz, and _mask3 flavors of the builtin, which we will expand in CGBuiltin.cpp to replicate the argument as needed and insert any fneg needed on the third operand to make a subtract. The _maskz isn't truly necessary if we have an unmasked version or if we use the masked version with a -1 mask and wrap a select around it. But I've chosen to make things more uniform.

I separated out the scalar builtin handling to avoid too many things going on in EmitX86FMAExpr. It was different enough due to the extract and insert that the minor duplication of the CreateCall was probably worth it.

Reviewers: tkrupa, RKSimon, spatel, GBuella

Reviewed By: tkrupa

Subscribers: cfe-commits

Differential Revision: https://reviews.llvm.org/D47724

llvm-svn: 334159
diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp
index a09fa7a..a086b4d 100644
--- a/clang/lib/CodeGen/CGBuiltin.cpp
+++ b/clang/lib/CodeGen/CGBuiltin.cpp
@@ -8555,79 +8555,110 @@
 
 // Lowers X86 FMA intrinsics to IR.
 static Value *EmitX86FMAExpr(CodeGenFunction &CGF, ArrayRef<Value *> Ops,
-                             unsigned BuiltinID) {
+                             unsigned BuiltinID, bool IsAddSub) {
 
-  bool IsAddSub = false;
-  bool IsScalar = false;
-
-  // 4 operands always means rounding mode without a mask here.
-  bool IsRound = Ops.size() == 4;
-
-  Intrinsic::ID ID;
+  bool Subtract = false;
+  Intrinsic::ID IID = Intrinsic::not_intrinsic;
   switch (BuiltinID) {
   default: break;
-  case clang::X86::BI__builtin_ia32_vfmaddss3: IsScalar = true; break;
-  case clang::X86::BI__builtin_ia32_vfmaddsd3: IsScalar = true; break;
-  case clang::X86::BI__builtin_ia32_vfmaddps512:
-    ID = llvm::Intrinsic::x86_avx512_vfmadd_ps_512; break;
-  case clang::X86::BI__builtin_ia32_vfmaddpd512:
-    ID = llvm::Intrinsic::x86_avx512_vfmadd_pd_512; break;
-  case clang::X86::BI__builtin_ia32_vfmaddsubps: IsAddSub = true; break;
-  case clang::X86::BI__builtin_ia32_vfmaddsubpd: IsAddSub = true; break;
-  case clang::X86::BI__builtin_ia32_vfmaddsubps256: IsAddSub = true; break;
-  case clang::X86::BI__builtin_ia32_vfmaddsubpd256: IsAddSub = true; break;
-  case clang::X86::BI__builtin_ia32_vfmaddsubps512: {
-    ID = llvm::Intrinsic::x86_avx512_vfmaddsub_ps_512;
-    IsAddSub = true;
+  case clang::X86::BI__builtin_ia32_vfmsubps512_mask3:
+    Subtract = true;
+    LLVM_FALLTHROUGH;
+  case clang::X86::BI__builtin_ia32_vfmaddps512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddps512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddps512_mask3:
+    IID = llvm::Intrinsic::x86_avx512_vfmadd_ps_512; break;
+  case clang::X86::BI__builtin_ia32_vfmsubpd512_mask3:
+    Subtract = true;
+    LLVM_FALLTHROUGH;
+  case clang::X86::BI__builtin_ia32_vfmaddpd512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddpd512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddpd512_mask3:
+    IID = llvm::Intrinsic::x86_avx512_vfmadd_pd_512; break;
+  case clang::X86::BI__builtin_ia32_vfmsubaddps512_mask3:
+    Subtract = true;
+    LLVM_FALLTHROUGH;
+  case clang::X86::BI__builtin_ia32_vfmaddsubps512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddsubps512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddsubps512_mask3:
+    IID = llvm::Intrinsic::x86_avx512_vfmaddsub_ps_512;
     break;
-  }
-  case clang::X86::BI__builtin_ia32_vfmaddsubpd512: {
-    ID = llvm::Intrinsic::x86_avx512_vfmaddsub_pd_512;
-    IsAddSub = true;
+  case clang::X86::BI__builtin_ia32_vfmsubaddpd512_mask3:
+    Subtract = true;
+    LLVM_FALLTHROUGH;
+  case clang::X86::BI__builtin_ia32_vfmaddsubpd512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddsubpd512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddsubpd512_mask3:
+    IID = llvm::Intrinsic::x86_avx512_vfmaddsub_pd_512;
     break;
   }
-  }
-
-  // Only handle in case of _MM_FROUND_CUR_DIRECTION/4 (no rounding).
-  if (IsRound) {
-    Function *Intr = CGF.CGM.getIntrinsic(ID);
-    if (cast<llvm::ConstantInt>(Ops[3])->getZExtValue() != (uint64_t)4)
-      return CGF.Builder.CreateCall(Intr, Ops);
-  }
 
   Value *A = Ops[0];
   Value *B = Ops[1];
   Value *C = Ops[2];
 
-  if (IsScalar) {
-    A = CGF.Builder.CreateExtractElement(A, (uint64_t)0);
-    B = CGF.Builder.CreateExtractElement(B, (uint64_t)0);
-    C = CGF.Builder.CreateExtractElement(C, (uint64_t)0);
-  }
+  if (Subtract)
+    C = CGF.Builder.CreateFNeg(C);
 
-  llvm::Type *Ty = A->getType();
-  Function *FMA = CGF.CGM.getIntrinsic(Intrinsic::fma, Ty);
-  Value *Res = CGF.Builder.CreateCall(FMA, {A, B, C} );
+  Value *Res;
 
-  if (IsScalar)
-    return CGF.Builder.CreateInsertElement(Ops[0], Res, (uint64_t)0);
+  // Only handle in case of _MM_FROUND_CUR_DIRECTION/4 (no rounding).
+  if (IID != Intrinsic::not_intrinsic &&
+      cast<llvm::ConstantInt>(Ops.back())->getZExtValue() != (uint64_t)4) {
+    Function *Intr = CGF.CGM.getIntrinsic(IID);
+    Res = CGF.Builder.CreateCall(Intr, {A, B, C, Ops.back() });
+  } else {
+    llvm::Type *Ty = A->getType();
+    Function *FMA = CGF.CGM.getIntrinsic(Intrinsic::fma, Ty);
+    Res = CGF.Builder.CreateCall(FMA, {A, B, C} );
 
-  if (IsAddSub) {
-    // Negate even elts in C using a mask.
-    unsigned NumElts = Ty->getVectorNumElements();
-    SmallVector<Constant *, 16> NMask;
-    Constant *Zero = ConstantInt::get(CGF.Builder.getInt1Ty(), 0);
-    Constant *One = ConstantInt::get(CGF.Builder.getInt1Ty(), 1);
-    for (unsigned i = 0; i < NumElts; ++i) {
-      NMask.push_back(i % 2 == 0 ? One : Zero);
+    if (IsAddSub) {
+      // Negate even elts in C using a mask.
+      unsigned NumElts = Ty->getVectorNumElements();
+      SmallVector<Constant *, 16> NMask;
+      Constant *Zero = ConstantInt::get(CGF.Builder.getInt1Ty(), 0);
+      Constant *One = ConstantInt::get(CGF.Builder.getInt1Ty(), 1);
+      for (unsigned i = 0; i < NumElts; ++i) {
+        NMask.push_back(i % 2 == 0 ? One : Zero);
+      }
+      Value *NegMask = ConstantVector::get(NMask);
+
+      Value *NegC = CGF.Builder.CreateFNeg(C);
+      Value *FMSub = CGF.Builder.CreateCall(FMA, {A, B, NegC} );
+      Res = CGF.Builder.CreateSelect(NegMask, FMSub, Res);
     }
-    Value *NegMask = ConstantVector::get(NMask);
-
-    Value *NegC = CGF.Builder.CreateFNeg(C);
-    Value *FMSub = CGF.Builder.CreateCall(FMA, {A, B, NegC} );
-    Res = CGF.Builder.CreateSelect(NegMask, FMSub, Res);
   }
 
+  // Handle any required masking.
+  Value *MaskFalseVal = nullptr;
+  switch (BuiltinID) {
+  case clang::X86::BI__builtin_ia32_vfmaddps512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddpd512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddsubps512_mask:
+  case clang::X86::BI__builtin_ia32_vfmaddsubpd512_mask:
+    MaskFalseVal = Ops[0];
+    break;
+  case clang::X86::BI__builtin_ia32_vfmaddps512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddpd512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddsubps512_maskz:
+  case clang::X86::BI__builtin_ia32_vfmaddsubpd512_maskz:
+    MaskFalseVal = Constant::getNullValue(Ops[0]->getType());
+    break;
+  case clang::X86::BI__builtin_ia32_vfmsubps512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmaddps512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmsubpd512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmaddpd512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmsubaddps512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmaddsubps512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmsubaddpd512_mask3:
+  case clang::X86::BI__builtin_ia32_vfmaddsubpd512_mask3:
+    MaskFalseVal = Ops[2];
+    break;
+  }
+
+  if (MaskFalseVal)
+    return EmitX86Select(CGF, Ops[3], Res, MaskFalseVal);
+
   return Res;
 }
 
@@ -9046,20 +9077,40 @@
     return EmitX86ConvertToMask(*this, Ops[0]);
 
   case X86::BI__builtin_ia32_vfmaddss3:
-  case X86::BI__builtin_ia32_vfmaddsd3:
+  case X86::BI__builtin_ia32_vfmaddsd3: {
+    Value *A = Builder.CreateExtractElement(Ops[0], (uint64_t)0);
+    Value *B = Builder.CreateExtractElement(Ops[1], (uint64_t)0);
+    Value *C = Builder.CreateExtractElement(Ops[2], (uint64_t)0);
+    Function *FMA = CGM.getIntrinsic(Intrinsic::fma, A->getType());
+    Value *Res = Builder.CreateCall(FMA, {A, B, C} );
+    return Builder.CreateInsertElement(Ops[0], Res, (uint64_t)0);
+  }
   case X86::BI__builtin_ia32_vfmaddps:
   case X86::BI__builtin_ia32_vfmaddpd:
   case X86::BI__builtin_ia32_vfmaddps256:
   case X86::BI__builtin_ia32_vfmaddpd256:
-  case X86::BI__builtin_ia32_vfmaddps512:
-  case X86::BI__builtin_ia32_vfmaddpd512:
+  case X86::BI__builtin_ia32_vfmaddps512_mask:
+  case X86::BI__builtin_ia32_vfmaddps512_maskz:
+  case X86::BI__builtin_ia32_vfmaddps512_mask3:
+  case X86::BI__builtin_ia32_vfmsubps512_mask3:
+  case X86::BI__builtin_ia32_vfmaddpd512_mask:
+  case X86::BI__builtin_ia32_vfmaddpd512_maskz:
+  case X86::BI__builtin_ia32_vfmaddpd512_mask3:
+  case X86::BI__builtin_ia32_vfmsubpd512_mask3:
+    return EmitX86FMAExpr(*this, Ops, BuiltinID, /*IsAddSub*/false);
   case X86::BI__builtin_ia32_vfmaddsubps:
   case X86::BI__builtin_ia32_vfmaddsubpd:
   case X86::BI__builtin_ia32_vfmaddsubps256:
   case X86::BI__builtin_ia32_vfmaddsubpd256:
-  case X86::BI__builtin_ia32_vfmaddsubps512:
-  case X86::BI__builtin_ia32_vfmaddsubpd512:
-    return EmitX86FMAExpr(*this, Ops, BuiltinID);
+  case X86::BI__builtin_ia32_vfmaddsubps512_mask:
+  case X86::BI__builtin_ia32_vfmaddsubps512_maskz:
+  case X86::BI__builtin_ia32_vfmaddsubps512_mask3:
+  case X86::BI__builtin_ia32_vfmsubaddps512_mask3:
+  case X86::BI__builtin_ia32_vfmaddsubpd512_mask:
+  case X86::BI__builtin_ia32_vfmaddsubpd512_maskz:
+  case X86::BI__builtin_ia32_vfmaddsubpd512_mask3:
+  case X86::BI__builtin_ia32_vfmsubaddpd512_mask3:
+    return EmitX86FMAExpr(*this, Ops, BuiltinID, /*IsAddSub*/true);
 
   case X86::BI__builtin_ia32_movdqa32store128_mask:
   case X86::BI__builtin_ia32_movdqa64store128_mask: