[InstCombine] add helper function for icmp+zext/sext; NFC

llvm-svn: 369421
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
index dc6dcc3..b318d0b 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCompares.cpp
@@ -4026,6 +4026,87 @@
   return nullptr;
 }
 
+static Instruction *foldICmpWithZextOrSext(ICmpInst &ICmp) {
+  assert(isa<CastInst>(ICmp.getOperand(0)) && "Expected cast for operand 0");
+  auto *CastOp0 = cast<CastInst>(ICmp.getOperand(0));
+  Value *X;
+  if (!match(CastOp0, m_ZExtOrSExt(m_Value(X))))
+    return nullptr;
+
+  bool IsSignedExt = CastOp0->getOpcode() == Instruction::SExt;
+  bool IsSignedCmp = ICmp.isSigned();
+  if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) {
+    // Both casts must be the same kind of extension. This rejects zext vs.
+    // sext as well as an operand 1 cast that is not an extension at all.
+    if (CastOp0->getOpcode() != CastOp1->getOpcode())
+      return nullptr;
+
+    // The narrow compare needs both source values to have the same type.
+    // TODO: Handle this by extending the narrower operand to the type of
+    //       the wider operand.
+    Value *Y = CastOp1->getOperand(0);
+    if (X->getType() != Y->getType())
+      return nullptr;
+
+    // (zext X) ==/!= (zext Y) --> X ==/!= Y
+    // (sext X) ==/!= (sext Y) --> X ==/!= Y
+    if (ICmp.isEquality())
+      return new ICmpInst(ICmp.getPredicate(), X, Y);
+
+    // A signed comparison of sign-extended values keeps its signed predicate
+    // when it is narrowed to the source values.
+    if (IsSignedCmp && IsSignedExt)
+      return new ICmpInst(ICmp.getPredicate(), X, Y);
+
+    // The other three cases (any zext, unsigned cmp of sext) fold to unsigned.
+    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Y);
+  }
+
+  // From here on, the only remaining fold is a compare against a constant.
+  auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
+  if (!C)
+    return nullptr;
+
+  // Round-trip the constant: truncate it to SrcTy (Res1), then re-extend it
+  // to DestTy with the same kind of cast as operand 0 (Res2).
+  Type *SrcTy = CastOp0->getSrcTy();
+  Type *DestTy = CastOp0->getDestTy();
+  Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
+  Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);
+
+  // If the round trip reproduced C, then Res1 represents C exactly in SrcTy.
+  if (Res2 == C) {
+    if (ICmp.isEquality())
+      return new ICmpInst(ICmp.getPredicate(), X, Res1);
+
+    // A signed comparison of a sign-extended value keeps its signed predicate
+    // when it is narrowed to the source type.
+    if (IsSignedExt && IsSignedCmp)
+      return new ICmpInst(ICmp.getPredicate(), X, Res1);
+
+    // The other three cases (any zext, unsigned cmp of sext) fold to unsigned.
+    return new ICmpInst(ICmp.getUnsignedPredicate(), X, Res1);
+  }
+
+  // The re-extended constant changed, partly changed (in the case of a vector),
+  // or could not be determined to be equal (in the case of a constant
+  // expression), so the constant cannot be represented in the shorter type.
+  // All the cases that fold to true or false will have already been handled
+  // by SimplifyICmpInst, so only deal with the tricky case.
+  if (IsSignedCmp || !IsSignedExt || !isa<ConstantInt>(C))
+    return nullptr;
+
+  // Is source op non-negative?
+  // icmp ult (sext X), C --> icmp sgt X, -1
+  if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
+    return new ICmpInst(CmpInst::ICMP_SGT, X, Constant::getAllOnesValue(SrcTy));
+
+  // Is source op negative?
+  // icmp ugt (sext X), C --> icmp slt X, 0
+  assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
+  return new ICmpInst(CmpInst::ICMP_SLT, X, Constant::getNullValue(SrcTy));
+}
+
 /// Handle icmp (cast x), (cast or constant).
 Instruction *InstCombiner::foldICmpWithCastOp(ICmpInst &ICmp) {
   auto *CastOp0 = dyn_cast<CastInst>(ICmp.getOperand(0));
@@ -4067,82 +4148,7 @@
       return new ICmpInst(ICmp.getPredicate(), Op0Src, NewOp1);
   }
 
-  // The code below only handles extension cast instructions, so far.
-  // Enforce this.
-  if (CastOp0->getOpcode() != Instruction::ZExt &&
-      CastOp0->getOpcode() != Instruction::SExt)
-    return nullptr;
-
-  bool isSignedExt = CastOp0->getOpcode() == Instruction::SExt;
-  bool isSignedCmp = ICmp.isSigned();
-
-  if (auto *CastOp1 = dyn_cast<CastInst>(ICmp.getOperand(1))) {
-    // Not an extension from the same type?
-    Value *Op1Src = CastOp1->getOperand(0);
-    if (Op1Src->getType() != Op0Src->getType())
-      return nullptr;
-
-    // If the signedness of the two casts doesn't agree (i.e. one is a sext
-    // and the other is a zext), then we can't handle this.
-    if (CastOp1->getOpcode() != CastOp0->getOpcode())
-      return nullptr;
-
-    // Deal with equality cases early.
-    if (ICmp.isEquality())
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Op1Src);
-
-    // A signed comparison of sign extended values simplifies into a
-    // signed comparison.
-    if (isSignedCmp && isSignedExt)
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Op1Src);
-
-    // The other three cases all fold into an unsigned comparison.
-    return new ICmpInst(ICmp.getUnsignedPredicate(), Op0Src, Op1Src);
-  }
-
-  // If we aren't dealing with a constant on the RHS, exit early.
-  auto *C = dyn_cast<Constant>(ICmp.getOperand(1));
-  if (!C)
-    return nullptr;
-
-  // Compute the constant that would happen if we truncated to SrcTy then
-  // re-extended to DestTy.
-  Constant *Res1 = ConstantExpr::getTrunc(C, SrcTy);
-  Constant *Res2 = ConstantExpr::getCast(CastOp0->getOpcode(), Res1, DestTy);
-
-  // If the re-extended constant didn't change...
-  if (Res2 == C) {
-    // Deal with equality cases early.
-    if (ICmp.isEquality())
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Res1);
-
-    // A signed comparison of sign extended values simplifies into a
-    // signed comparison.
-    if (isSignedExt && isSignedCmp)
-      return new ICmpInst(ICmp.getPredicate(), Op0Src, Res1);
-
-    // The other three cases all fold into an unsigned comparison.
-    return new ICmpInst(ICmp.getUnsignedPredicate(), Op0Src, Res1);
-  }
-
-  // The re-extended constant changed, partly changed (in the case of a vector),
-  // or could not be determined to be equal (in the case of a constant
-  // expression), so the constant cannot be represented in the shorter type.
-  // All the cases that fold to true or false will have already been handled
-  // by SimplifyICmpInst, so only deal with the tricky case.
-  if (isSignedCmp || !isSignedExt || !isa<ConstantInt>(C))
-    return nullptr;
-
-  // Is source op positive?
-  // icmp ult (sext X), C --> icmp sgt X, -1
-  if (ICmp.getPredicate() == ICmpInst::ICMP_ULT)
-    return new ICmpInst(CmpInst::ICMP_SGT, Op0Src,
-                        Constant::getAllOnesValue(SrcTy));
-
-  // Is source op negative?
-  // icmp ugt (sext X), C --> icmp slt X, 0
-  assert(ICmp.getPredicate() == ICmpInst::ICMP_UGT && "ICmp should be folded!");
-  return new ICmpInst(CmpInst::ICMP_SLT, Op0Src, Constant::getNullValue(SrcTy));
+  return foldICmpWithZextOrSext(ICmp);
 }
 
 static bool isNeutralValue(Instruction::BinaryOps BinaryOp, Value *RHS) {