[InstCombine] Pass a proper context instruction to all of the calls into InstSimplify

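Summary: Every call from InstCombine into InstSimplify now passes the instruction being visited as the context instruction (via SQ.getWithInstruction). This matches the behavior we already had for compare instructions and makes InstCombine consistent everywhere.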

Reviewers: dberlin, hfinkel, spatel

Reviewed By: dberlin

Subscribers: llvm-commits

Differential Revision: https://reviews.llvm.org/D33604

llvm-svn: 305049
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
index ed6386c..287a516 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAddSub.cpp
@@ -1009,8 +1009,9 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), SQ))
+  if (Value *V =
+          SimplifyAddInst(LHS, RHS, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+                          SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
    // (A*B)+(A*C) -> A*(B+C) etc
@@ -1295,7 +1296,8 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(), SQ))
+  if (Value *V = SimplifyFAddInst(LHS, RHS, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (isa<Constant>(RHS))
@@ -1485,8 +1487,9 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), SQ))
+  if (Value *V =
+          SimplifySubInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+                          SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // (A*B)-(A*C) -> A*(B-C) etc
@@ -1691,7 +1694,8 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(), SQ))
+  if (Value *V = SimplifyFSubInst(Op0, Op1, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // fsub nsz 0, X ==> fsub nsz -0.0, X
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
index bab28c4..4fe3225 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineAndOrXor.cpp
@@ -1235,7 +1235,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyAndInst(Op0, Op1, SQ))
+  if (Value *V = SimplifyAndInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // See if we can simplify any instructions used by the instruction whose sole
@@ -1963,7 +1963,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyOrInst(Op0, Op1, SQ))
+  if (Value *V = SimplifyOrInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // See if we can simplify any instructions used by the instruction whose sole
@@ -2345,7 +2345,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyXorInst(Op0, Op1, SQ))
+  if (Value *V = SimplifyXorInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *NewXor = foldXorToXor(I))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
index 79f4a47..e9645e9 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineCalls.cpp
@@ -1814,8 +1814,8 @@
 /// lifting.
 Instruction *InstCombiner::visitCallInst(CallInst &CI) {
   auto Args = CI.arg_operands();
-  if (Value *V =
-          SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(), SQ))
+  if (Value *V = SimplifyCall(CI.getCalledValue(), Args.begin(), Args.end(),
+                              SQ.getWithInstruction(&CI)))
     return replaceInstUsesWith(CI, V);
 
   if (isFreeCall(&CI, &TLI))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
index 2fcfe46..365c4ba 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineMulDivRem.cpp
@@ -176,7 +176,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyMulInst(Op0, Op1, SQ))
+  if (Value *V = SimplifyMulInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Value *V = SimplifyUsingDistributiveLaws(I))
@@ -599,7 +599,8 @@
   if (isa<Constant>(Op0))
     std::swap(Op0, Op1);
 
-  if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(), SQ))
+  if (Value *V = SimplifyFMulInst(Op0, Op1, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   bool AllowReassociate = I.hasUnsafeAlgebra();
@@ -1103,7 +1104,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyUDivInst(Op0, Op1, SQ))
+  if (Value *V = SimplifyUDivInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // Handle the integer div common cases
@@ -1176,7 +1177,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifySDivInst(Op0, Op1, SQ))
+  if (Value *V = SimplifySDivInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // Handle the integer div common cases
@@ -1288,7 +1289,8 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(), SQ))
+  if (Value *V = SimplifyFDivInst(Op0, Op1, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (isa<Constant>(Op0))
@@ -1472,7 +1474,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyURemInst(Op0, Op1, SQ))
+  if (Value *V = SimplifyURemInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *common = commonIRemTransforms(I))
@@ -1515,7 +1517,7 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifySRemInst(Op0, Op1, SQ))
+  if (Value *V = SimplifySRemInst(Op0, Op1, SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // Handle the integer rem common cases
@@ -1588,7 +1590,8 @@
   if (Value *V = SimplifyVectorOp(I))
     return replaceInstUsesWith(I, V);
 
-  if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(), SQ))
+  if (Value *V = SimplifyFRemInst(Op0, Op1, I.getFastMathFlags(),
+                                  SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   // Handle cases involving: rem X, (select Cond, Y, Z)
diff --git a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
index 5c9daeb..5dbf1e8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombinePHI.cpp
@@ -880,7 +880,7 @@
 // PHINode simplification
 //
 Instruction *InstCombiner::visitPHINode(PHINode &PN) {
-  if (Value *V = SimplifyInstruction(&PN, SQ))
+  if (Value *V = SimplifyInstruction(&PN, SQ.getWithInstruction(&PN)))
     return replaceInstUsesWith(PN, V);
 
   if (Instruction *Result = FoldPHIArgZextsIntoPHI(PN))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
index 229d59a..b9674d8 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineSelect.cpp
@@ -1121,7 +1121,8 @@
   Value *FalseVal = SI.getFalseValue();
   Type *SelType = SI.getType();
 
-  if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal, SQ))
+  if (Value *V = SimplifySelectInst(CondVal, TrueVal, FalseVal,
+                                    SQ.getWithInstruction(&SI)))
     return replaceInstUsesWith(SI, V);
 
   if (Instruction *I = canonicalizeSelectToShuffle(SI))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
index dc0032b..3f2ddca 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineShifts.cpp
@@ -520,8 +520,9 @@
     return replaceInstUsesWith(I, V);
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-  if (Value *V = SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(),
-                                 I.hasNoUnsignedWrap(), SQ))
+  if (Value *V =
+          SimplifyShlInst(Op0, Op1, I.hasNoSignedWrap(), I.hasNoUnsignedWrap(),
+                          SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *V = commonShiftTransforms(I))
@@ -619,7 +620,8 @@
     return replaceInstUsesWith(I, V);
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-  if (Value *V = SimplifyLShrInst(Op0, Op1, I.isExact(), SQ))
+  if (Value *V =
+          SimplifyLShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *R = commonShiftTransforms(I))
@@ -722,7 +724,8 @@
     return replaceInstUsesWith(I, V);
 
   Value *Op0 = I.getOperand(0), *Op1 = I.getOperand(1);
-  if (Value *V = SimplifyAShrInst(Op0, Op1, I.isExact(), SQ))
+  if (Value *V =
+          SimplifyAShrInst(Op0, Op1, I.isExact(), SQ.getWithInstruction(&I)))
     return replaceInstUsesWith(I, V);
 
   if (Instruction *R = commonShiftTransforms(I))
diff --git a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
index abfa39d..926e466 100644
--- a/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstCombineVectorOps.cpp
@@ -145,7 +145,8 @@
 
 Instruction *InstCombiner::visitExtractElementInst(ExtractElementInst &EI) {
   if (Value *V = SimplifyExtractElementInst(EI.getVectorOperand(),
-                                            EI.getIndexOperand(), SQ))
+                                            EI.getIndexOperand(),
+                                            SQ.getWithInstruction(&EI)))
     return replaceInstUsesWith(EI, V);
 
   // If vector val is constant with all elements the same, replace EI with
@@ -1140,8 +1141,8 @@
   SmallVector<int, 16> Mask = SVI.getShuffleMask();
   Type *Int32Ty = Type::getInt32Ty(SVI.getContext());
 
-  if (auto *V =
-          SimplifyShuffleVectorInst(LHS, RHS, SVI.getMask(), SVI.getType(), SQ))
+  if (auto *V = SimplifyShuffleVectorInst(
+          LHS, RHS, SVI.getMask(), SVI.getType(), SQ.getWithInstruction(&SVI)))
     return replaceInstUsesWith(SVI, V);
 
   bool MadeChange = false;
diff --git a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
index 2678ae0..65e6d2e 100644
--- a/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
+++ b/llvm/lib/Transforms/InstCombine/InstructionCombining.cpp
@@ -256,7 +256,7 @@
         Value *C = I.getOperand(1);
 
         // Does "B op C" simplify?
-        if (Value *V = SimplifyBinOp(Opcode, B, C, SQ)) {
+        if (Value *V = SimplifyBinOp(Opcode, B, C, SQ.getWithInstruction(&I))) {
           // It simplifies to V.  Form "A op V".
           I.setOperand(0, A);
           I.setOperand(1, V);
@@ -285,7 +285,7 @@
         Value *C = Op1->getOperand(1);
 
         // Does "A op B" simplify?
-        if (Value *V = SimplifyBinOp(Opcode, A, B, SQ)) {
+        if (Value *V = SimplifyBinOp(Opcode, A, B, SQ.getWithInstruction(&I))) {
           // It simplifies to V.  Form "V op C".
           I.setOperand(0, V);
           I.setOperand(1, C);
@@ -313,7 +313,7 @@
         Value *C = I.getOperand(1);
 
         // Does "C op A" simplify?
-        if (Value *V = SimplifyBinOp(Opcode, C, A, SQ)) {
+        if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) {
           // It simplifies to V.  Form "V op B".
           I.setOperand(0, V);
           I.setOperand(1, B);
@@ -333,7 +333,7 @@
         Value *C = Op1->getOperand(1);
 
         // Does "C op A" simplify?
-        if (Value *V = SimplifyBinOp(Opcode, C, A, SQ)) {
+        if (Value *V = SimplifyBinOp(Opcode, C, A, SQ.getWithInstruction(&I))) {
           // It simplifies to V.  Form "B op V".
           I.setOperand(0, B);
           I.setOperand(1, V);
@@ -521,7 +521,7 @@
         std::swap(C, D);
       // Consider forming "A op' (B op D)".
       // If "B op D" simplifies then it can be formed with no cost.
-      V = SimplifyBinOp(TopLevelOpcode, B, D, SQ);
+      V = SimplifyBinOp(TopLevelOpcode, B, D, SQ.getWithInstruction(&I));
       // If "B op D" doesn't simplify then only go on if both of the existing
       // operations "A op' B" and "C op' D" will be zapped as no longer used.
       if (!V && LHS->hasOneUse() && RHS->hasOneUse())
@@ -540,7 +540,7 @@
         std::swap(C, D);
       // Consider forming "(A op C) op' B".
       // If "A op C" simplifies then it can be formed with no cost.
-      V = SimplifyBinOp(TopLevelOpcode, A, C, SQ);
+      V = SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I));
 
       // If "A op C" doesn't simplify then only go on if both of the existing
       // operations "A op' B" and "C op' D" will be zapped as no longer used.
@@ -638,8 +638,10 @@
     Instruction::BinaryOps InnerOpcode = Op0->getOpcode(); // op'
 
     // Do "A op C" and "B op C" both simplify?
-    if (Value *L = SimplifyBinOp(TopLevelOpcode, A, C, SQ))
-      if (Value *R = SimplifyBinOp(TopLevelOpcode, B, C, SQ)) {
+    if (Value *L =
+            SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I)))
+      if (Value *R =
+              SimplifyBinOp(TopLevelOpcode, B, C, SQ.getWithInstruction(&I))) {
         // They do! Return "L op' R".
         ++NumExpand;
         C = Builder->CreateBinOp(InnerOpcode, L, R);
@@ -655,8 +657,10 @@
     Instruction::BinaryOps InnerOpcode = Op1->getOpcode(); // op'
 
     // Do "A op B" and "A op C" both simplify?
-    if (Value *L = SimplifyBinOp(TopLevelOpcode, A, B, SQ))
-      if (Value *R = SimplifyBinOp(TopLevelOpcode, A, C, SQ)) {
+    if (Value *L =
+            SimplifyBinOp(TopLevelOpcode, A, B, SQ.getWithInstruction(&I)))
+      if (Value *R =
+              SimplifyBinOp(TopLevelOpcode, A, C, SQ.getWithInstruction(&I))) {
         // They do! Return "L op' R".
         ++NumExpand;
         A = Builder->CreateBinOp(InnerOpcode, L, R);
@@ -671,15 +675,17 @@
     if (auto *SI1 = dyn_cast<SelectInst>(RHS)) {
       if (SI0->getCondition() == SI1->getCondition()) {
         Value *SI = nullptr;
-        if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(),
-                                     SI1->getFalseValue(), SQ))
+        if (Value *V =
+                SimplifyBinOp(TopLevelOpcode, SI0->getFalseValue(),
+                              SI1->getFalseValue(), SQ.getWithInstruction(&I)))
           SI = Builder->CreateSelect(SI0->getCondition(),
                                      Builder->CreateBinOp(TopLevelOpcode,
                                                           SI0->getTrueValue(),
                                                           SI1->getTrueValue()),
                                      V);
-        if (Value *V = SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(),
-                                     SI1->getTrueValue(), SQ))
+        if (Value *V =
+                SimplifyBinOp(TopLevelOpcode, SI0->getTrueValue(),
+                              SI1->getTrueValue(), SQ.getWithInstruction(&I)))
           SI = Builder->CreateSelect(
               SI0->getCondition(), V,
               Builder->CreateBinOp(TopLevelOpcode, SI0->getFalseValue(),
@@ -1399,7 +1405,8 @@
 Instruction *InstCombiner::visitGetElementPtrInst(GetElementPtrInst &GEP) {
   SmallVector<Value*, 8> Ops(GEP.op_begin(), GEP.op_end());
 
-  if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops, SQ))
+  if (Value *V = SimplifyGEPInst(GEP.getSourceElementType(), Ops,
+                                 SQ.getWithInstruction(&GEP)))
     return replaceInstUsesWith(GEP, V);
 
   Value *PtrOp = GEP.getOperand(0);
@@ -1588,7 +1595,8 @@
       if (SO1->getType() != GO1->getType())
         return nullptr;
 
-      Value *Sum = SimplifyAddInst(GO1, SO1, false, false, SQ);
+      Value *Sum =
+          SimplifyAddInst(GO1, SO1, false, false, SQ.getWithInstruction(&GEP));
       // Only do the combine when we are sure the cost after the
       // merge is never more than that before the merge.
       if (Sum == nullptr)
@@ -2283,7 +2291,8 @@
   if (!EV.hasIndices())
     return replaceInstUsesWith(EV, Agg);
 
-  if (Value *V = SimplifyExtractValueInst(Agg, EV.getIndices(), SQ))
+  if (Value *V = SimplifyExtractValueInst(Agg, EV.getIndices(),
+                                          SQ.getWithInstruction(&EV)))
     return replaceInstUsesWith(EV, V);
 
   if (InsertValueInst *IV = dyn_cast<InsertValueInst>(Agg)) {