[InstCombine] Replacing X86-specific rounding intrinsics with generic floor-ceil

This patch replaces calls to X86-specific rounding intrinsics whose immediate
encodes floor or ceil semantics with calls to the target-independent
@llvm.floor.* and @llvm.ceil.* intrinsics. This doesn't affect the resulting
machine code, since those intrinsics are lowered to the same instructions, but
it exposes these rounding cases to generic optimizations.
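
For example (the value names here are only illustrative), a call such as

  %r = call <4 x float> @llvm.x86.sse41.round.ps(<4 x float> %x, i32 1)

becomes

  %r = call <4 x float> @llvm.floor.v4f32(<4 x float> %x)

and a round-control immediate of 2 is rewritten to @llvm.ceil.* in the same way.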

Differential Revision: https://reviews.llvm.org/D48067

llvm-svn: 335039
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 9eb8d5d..9e046c9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -576,6 +576,111 @@
   return ConstantVector::get(Vals);
 }
 
+// Replace X86-specific intrinsics with generic floor-ceil where applicable.
+static Value *simplifyX86round(IntrinsicInst &II,
+                               InstCombiner::BuilderTy &Builder) {
+  ConstantInt *Arg = nullptr;
+  Intrinsic::ID IntrinsicID = II.getIntrinsicID();
+
+  if (IntrinsicID == Intrinsic::x86_sse41_round_ss ||
+      IntrinsicID == Intrinsic::x86_sse41_round_sd)
+    Arg = dyn_cast<ConstantInt>(II.getArgOperand(2));
+  else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+           IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd)
+    Arg = dyn_cast<ConstantInt>(II.getArgOperand(4));
+  else
+    Arg = dyn_cast<ConstantInt>(II.getArgOperand(1));
+  if (!Arg)
+    return nullptr;
+  unsigned RoundControl = Arg->getZExtValue();
+
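+  // Locate the SAE/rounding-mode operand; intrinsics that lack one behave as
+  // if it were 4 (the default, _MM_FROUND_CUR_DIRECTION).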
+  Arg = nullptr;
+  unsigned SAE = 0;
+  if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 ||
+      IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512)
+    Arg = dyn_cast<ConstantInt>(II.getArgOperand(4));
+  else if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+           IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd)
+    Arg = dyn_cast<ConstantInt>(II.getArgOperand(5));
+  else
+    SAE = 4;
+  if (!SAE) {
+    if (!Arg)
+      return nullptr;
+    SAE = Arg->getZExtValue();
+  }
+
+  if (SAE != 4 || (RoundControl != 2 /*ceil*/ && RoundControl != 1 /*floor*/))
+    return nullptr;
+
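+  // Gather the value to round, the pass-through destination and, where
+  // present, the write mask. The scalar variants only round element 0.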
+  Value *Src, *Dst, *Mask;
+  bool IsScalar = false;
+  if (IntrinsicID == Intrinsic::x86_sse41_round_ss ||
+      IntrinsicID == Intrinsic::x86_sse41_round_sd ||
+      IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+      IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) {
+    IsScalar = true;
+    if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+        IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) {
+      Mask = II.getArgOperand(3);
+      Value *Zero = Constant::getNullValue(Mask->getType());
+      Mask = Builder.CreateAnd(Mask, 1);
+      Mask = Builder.CreateICmp(ICmpInst::ICMP_NE, Mask, Zero);
+      Dst = II.getArgOperand(2);
+    } else
+      Dst = II.getArgOperand(0);
+    Src = Builder.CreateExtractElement(II.getArgOperand(1), (uint64_t)0);
+  } else {
+    Src = II.getArgOperand(0);
+    if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_128 ||
+        IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_256 ||
+        IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ps_512 ||
+        IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_128 ||
+        IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_256 ||
+        IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_pd_512) {
+      Dst = II.getArgOperand(2);
+      Mask = II.getArgOperand(3);
+    } else {
+      Dst = Src;
+      Mask = ConstantInt::getAllOnesValue(
+          Builder.getIntNTy(Src->getType()->getVectorNumElements()));
+    }
+  }
+
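+  // Emit the generic floor/ceil, then re-apply the mask (packed forms) or
+  // reinsert the rounded scalar into element 0 of the destination.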
+  Intrinsic::ID ID = (RoundControl == 2) ? Intrinsic::ceil : Intrinsic::floor;
+  Value *Res = Builder.CreateIntrinsic(ID, {Src}, &II);
+  if (!IsScalar) {
+    if (auto *C = dyn_cast<Constant>(Mask))
+      if (C->isAllOnesValue())
+        return Res;
+    auto *MaskTy = VectorType::get(
+        Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth());
+    Mask = Builder.CreateBitCast(Mask, MaskTy);
+    unsigned Width = Src->getType()->getVectorNumElements();
+    if (MaskTy->getVectorNumElements() > Width) {
+      uint32_t Indices[4];
+      for (unsigned i = 0; i != Width; ++i)
+        Indices[i] = i;
+      Mask = Builder.CreateShuffleVector(Mask, Mask,
+                                         makeArrayRef(Indices, Width));
+    }
+    return Builder.CreateSelect(Mask, Res, Dst);
+  }
+  if (IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_ss ||
+      IntrinsicID == Intrinsic::x86_avx512_mask_rndscale_sd) {
+    Dst = Builder.CreateExtractElement(Dst, (uint64_t)0);
+    Res = Builder.CreateSelect(Mask, Res, Dst);
+    Dst = II.getArgOperand(0);
+  }
+  return Builder.CreateInsertElement(Dst, Res, (uint64_t)0);
+}
+
 static Value *simplifyX86movmsk(const IntrinsicInst &II) {
   Value *Arg = II.getArgOperand(0);
   Type *ResTy = II.getType();
@@ -2222,6 +2321,22 @@
     break;
   }
 
+  case Intrinsic::x86_sse41_round_ps:
+  case Intrinsic::x86_sse41_round_pd:
+  case Intrinsic::x86_avx_round_ps_256:
+  case Intrinsic::x86_avx_round_pd_256:
+  case Intrinsic::x86_avx512_mask_rndscale_ps_128:
+  case Intrinsic::x86_avx512_mask_rndscale_ps_256:
+  case Intrinsic::x86_avx512_mask_rndscale_ps_512:
+  case Intrinsic::x86_avx512_mask_rndscale_pd_128:
+  case Intrinsic::x86_avx512_mask_rndscale_pd_256:
+  case Intrinsic::x86_avx512_mask_rndscale_pd_512:
+  case Intrinsic::x86_avx512_mask_rndscale_ss:
+  case Intrinsic::x86_avx512_mask_rndscale_sd:
+    if (Value *V = simplifyX86round(*II, Builder))
+      return replaceInstUsesWith(*II, V);
+    break;
+
   case Intrinsic::x86_mmx_pmovmskb:
   case Intrinsic::x86_sse_movmsk_ps:
   case Intrinsic::x86_sse2_movmsk_pd:
@@ -2438,8 +2553,6 @@
   case Intrinsic::x86_sse2_cmp_sd:
   case Intrinsic::x86_sse2_min_sd:
   case Intrinsic::x86_sse2_max_sd:
-  case Intrinsic::x86_sse41_round_ss:
-  case Intrinsic::x86_sse41_round_sd:
   case Intrinsic::x86_xop_vfrcz_ss:
   case Intrinsic::x86_xop_vfrcz_sd: {
    unsigned VWidth = II->getType()->getVectorNumElements();
@@ -2452,6 +2565,19 @@
    }
    break;
   }
+  case Intrinsic::x86_sse41_round_ss:
+  case Intrinsic::x86_sse41_round_sd: {
+    unsigned VWidth = II->getType()->getVectorNumElements();
+    APInt UndefElts(VWidth, 0);
+    APInt AllOnesEltMask(APInt::getAllOnesValue(VWidth));
+    if (Value *V = SimplifyDemandedVectorElts(II, AllOnesEltMask, UndefElts)) {
+      if (V != II)
+        return replaceInstUsesWith(*II, V);
+      return II;
+    } else if (Value *V = simplifyX86round(*II, Builder))
+      return replaceInstUsesWith(*II, V);
+    break;
+  }
 
   // Constant fold ashr( <A x Bi>, Ci ).
   // Constant fold lshr( <A x Bi>, Ci ).