[X86] Constant folding of adds/subs intrinsics

Summary: This adds constant folding of the X86 signed saturating add/sub intrinsics (SSE2/AVX2 padds/psubs and their AVX-512 masked variants).

Reviewers: craig.topper, spatel, RKSimon, chandlerc, efriedma

Reviewed By: craig.topper

Subscribers: rnk, llvm-commits

Differential Revision: https://reviews.llvm.org/D50499

llvm-svn: 339659
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index f0c4ceb..38d6f29 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -243,6 +243,113 @@
   return nullptr;
 }
 
+/// Constant fold the X86 signed saturating add/sub intrinsics (padds/psubs),
+/// including the AVX-512 masked forms.
+///
+/// Each lane is added or subtracted with signed overflow detection and, on
+/// overflow, clamped to the signed minimum or maximum of the element type.
+/// For the masked forms the folded vector is blended with the passthrough
+/// operand (operand 2) under the integer mask (operand 3).
+///
+/// Returns the folded value, or nullptr if the operands are not foldable
+/// constants.
+static Value *simplifyX86AddsSubs(const IntrinsicInst &II,
+                                  InstCombiner::BuilderTy &Builder) {
+  bool IsAddition = false;
+  bool IsMasked = false;
+
+  switch (II.getIntrinsicID()) {
+  default: llvm_unreachable("Unexpected intrinsic!");
+  case Intrinsic::x86_sse2_padds_b:
+  case Intrinsic::x86_sse2_padds_w:
+  case Intrinsic::x86_avx2_padds_b:
+  case Intrinsic::x86_avx2_padds_w:
+    IsAddition = true; IsMasked = false;
+    break;
+  case Intrinsic::x86_sse2_psubs_b:
+  case Intrinsic::x86_sse2_psubs_w:
+  case Intrinsic::x86_avx2_psubs_b:
+  case Intrinsic::x86_avx2_psubs_w:
+    IsAddition = false; IsMasked = false;
+    break;
+  case Intrinsic::x86_avx512_mask_padds_b_512:
+  case Intrinsic::x86_avx512_mask_padds_w_512:
+    IsAddition = true; IsMasked = true;
+    break;
+  case Intrinsic::x86_avx512_mask_psubs_b_512:
+  case Intrinsic::x86_avx512_mask_psubs_w_512:
+    IsAddition = false; IsMasked = true;
+    break;
+  }
+
+  auto *Arg0 = dyn_cast<Constant>(II.getOperand(0));
+  auto *Arg1 = dyn_cast<Constant>(II.getOperand(1));
+  auto *VT = cast<VectorType>(II.getType());
+  Type *SVT = VT->getElementType();
+  unsigned NumElems = VT->getNumElements();
+
+  // Only fold when both inputs (and, for the masked forms, the passthrough)
+  // are constants.
+  if (!Arg0 || !Arg1 || (IsMasked && !isa<Constant>(II.getOperand(2))))
+    return nullptr;
+
+  SmallVector<Constant *, 64> Result;
+
+  unsigned BitWidth = SVT->getIntegerBitWidth();
+  APInt MaxValue = APInt::getSignedMaxValue(BitWidth);
+  APInt MinValue = APInt::getSignedMinValue(BitWidth);
+  for (unsigned i = 0; i < NumElems; ++i) {
+    auto *Elt0 = Arg0->getAggregateElement(i);
+    auto *Elt1 = Arg1->getAggregateElement(i);
+    // getAggregateElement returns null when it cannot split the operand into
+    // lanes (e.g. a vector-typed ConstantExpr); give up instead of crashing.
+    if (!Elt0 || !Elt1)
+      return nullptr;
+
+    // An undef lane may not be folded to undef: the saturated result cannot
+    // take every value (e.g. 127 +sat undef is never -128 for i8). Instead
+    // pick a concrete value for undef: ~X for adds (X + ~X = -1, which never
+    // overflows) and X for subs (X - X = 0).
+    if (isa<UndefValue>(Elt0) || isa<UndefValue>(Elt1)) {
+      Result.push_back(IsAddition ? Constant::getAllOnesValue(SVT)
+                                  : Constant::getNullValue(SVT));
+      continue;
+    }
+
+    if (!isa<ConstantInt>(Elt0) || !isa<ConstantInt>(Elt1))
+      return nullptr;
+
+    const APInt &Val0 = cast<ConstantInt>(Elt0)->getValue();
+    const APInt &Val1 = cast<ConstantInt>(Elt1)->getValue();
+    bool Overflow = false;
+    APInt ResultElem = IsAddition ? Val0.sadd_ov(Val1, Overflow)
+                                  : Val0.ssub_ov(Val1, Overflow);
+    // On signed overflow the exact result lies past the bound on the same
+    // side as Val0's sign, so that sign picks the saturation value.
+    if (Overflow)
+      ResultElem = Val0.isNegative() ? MinValue : MaxValue;
+    Result.push_back(Constant::getIntegerValue(SVT, ResultElem));
+  }
+
+  Value *ResultVec = ConstantVector::get(Result);
+
+  if (IsMasked) {
+    Value *Src = II.getOperand(2);
+    Value *Mask = II.getOperand(3);
+    // With an all-ones mask every lane comes from the folded result.
+    if (auto *C = dyn_cast<Constant>(Mask))
+      if (C->isAllOnesValue())
+        return ResultVec;
+    // Reinterpret the iN mask as <N x i1> so it can drive a per-lane select.
+    auto *MaskTy = VectorType::get(
+        Builder.getInt1Ty(), cast<IntegerType>(Mask->getType())->getBitWidth());
+    Mask = Builder.CreateBitCast(Mask, MaskTy);
+    ResultVec = Builder.CreateSelect(Mask, ResultVec, Src);
+  }
+
+  return ResultVec;
+}
+
 static Value *simplifyX86immShift(const IntrinsicInst &II,
                                   InstCombiner::BuilderTy &Builder) {
   bool LogicalShift = false;
@@ -2525,6 +2605,24 @@
     break;
   }
 
+  // Constant fold add/sub with saturation intrinsics.
+  case Intrinsic::x86_sse2_padds_b:
+  case Intrinsic::x86_sse2_padds_w:
+  case Intrinsic::x86_sse2_psubs_b:
+  case Intrinsic::x86_sse2_psubs_w:
+  case Intrinsic::x86_avx2_padds_b:
+  case Intrinsic::x86_avx2_padds_w:
+  case Intrinsic::x86_avx2_psubs_b:
+  case Intrinsic::x86_avx2_psubs_w:
+  case Intrinsic::x86_avx512_mask_padds_b_512:
+  case Intrinsic::x86_avx512_mask_padds_w_512:
+  case Intrinsic::x86_avx512_mask_psubs_b_512:
+  case Intrinsic::x86_avx512_mask_psubs_w_512:
+    if (Value *V = simplifyX86AddsSubs(*II, Builder))
+      return replaceInstUsesWith(*II, V);
+    break;
+    
+
   // Constant fold ashr( <A x Bi>, Ci ).
   // Constant fold lshr( <A x Bi>, Ci ).
   // Constant fold shl( <A x Bi>, Ci ).