More Chris-inspired JumpThreading fixes: use ConstantExpr to correctly constant-fold undef, and be more careful with its return value.
This actually exposed an infinite recursion bug in ComputeValueKnownInPredecessors that theoretically already existed (in JumpThreading's
handling of and/or of i1 values) but had never manifested before.  This patch adds a tracking set of (Value, BasicBlock) pairs to break this recursion.


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@112589 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/JumpThreading.cpp b/lib/Transforms/Scalar/JumpThreading.cpp
index 7a4c685..afc1661 100644
--- a/lib/Transforms/Scalar/JumpThreading.cpp
+++ b/lib/Transforms/Scalar/JumpThreading.cpp
@@ -16,7 +16,6 @@
 #include "llvm/IntrinsicInst.h"
 #include "llvm/LLVMContext.h"
 #include "llvm/Pass.h"
-#include "llvm/Analysis/ConstantFolding.h"
 #include "llvm/Analysis/InstructionSimplify.h"
 #include "llvm/Analysis/LazyValueInfo.h"
 #include "llvm/Analysis/Loads.h"
@@ -25,6 +24,7 @@
 #include "llvm/Transforms/Utils/SSAUpdater.h"
 #include "llvm/Target/TargetData.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/Statistic.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SmallPtrSet.h"
@@ -78,6 +78,7 @@
 #else
     SmallSet<AssertingVH<BasicBlock>, 16> LoopHeaders;
 #endif
+    DenseSet<std::pair<Value*, BasicBlock*> > RecursionSet;
   public:
     static char ID; // Pass identification
     JumpThreading() : FunctionPass(ID) {}
@@ -270,12 +271,17 @@
 ///
 bool JumpThreading::
 ComputeValueKnownInPredecessors(Value *V, BasicBlock *BB,PredValueInfo &Result){
+  if (!RecursionSet.insert(std::make_pair(V, BB)).second)
+    return false;
+  
   // If V is a constantint, then it is known in all predecessors.
   if (isa<ConstantInt>(V) || isa<UndefValue>(V)) {
     ConstantInt *CI = dyn_cast<ConstantInt>(V);
     
     for (pred_iterator PI = pred_begin(BB), E = pred_end(BB); PI != E; ++PI)
       Result.push_back(std::make_pair(CI, *PI));
+    
+    RecursionSet.erase(std::make_pair(V, BB));
     return true;
   }
   
@@ -310,9 +316,11 @@
         Result.push_back(std::make_pair(dyn_cast<ConstantInt>(PredCst), P));
       }
       
+      RecursionSet.erase(std::make_pair(V, BB));
       return !Result.empty();
     }
     
+    RecursionSet.erase(std::make_pair(V, BB));
     return false;
   }
   
@@ -328,10 +336,15 @@
                                               PN->getIncomingBlock(i), BB);
         // LVI returns null is no value could be determined.
         if (!CI) continue;
-        ConstantInt *CInt = dyn_cast<ConstantInt>(CI);
-        Result.push_back(std::make_pair(CInt, PN->getIncomingBlock(i)));
+        if (ConstantInt *CInt = dyn_cast<ConstantInt>(CI))
+          Result.push_back(std::make_pair(CInt, PN->getIncomingBlock(i)));
+        else if (isa<UndefValue>(CI))
+           Result.push_back(std::make_pair((ConstantInt*)0,
+                                           PN->getIncomingBlock(i)));
       }
     }
+    
+    RecursionSet.erase(std::make_pair(V, BB));
     return !Result.empty();
   }
   
@@ -346,8 +359,10 @@
       ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals);
       ComputeValueKnownInPredecessors(I->getOperand(1), BB, RHSVals);
       
-      if (LHSVals.empty() && RHSVals.empty())
+      if (LHSVals.empty() && RHSVals.empty()) {
+        RecursionSet.erase(std::make_pair(V, BB));
         return false;
+      }
       
       ConstantInt *InterestingVal;
       if (I->getOpcode() == Instruction::Or)
@@ -374,6 +389,8 @@
             Result.back().first = InterestingVal;
           }
         }
+      
+      RecursionSet.erase(std::make_pair(V, BB));
       return !Result.empty();
     }
     
@@ -382,38 +399,48 @@
         isa<ConstantInt>(I->getOperand(1)) &&
         cast<ConstantInt>(I->getOperand(1))->isOne()) {
       ComputeValueKnownInPredecessors(I->getOperand(0), BB, Result);
-      if (Result.empty())
+      if (Result.empty()) {
+        RecursionSet.erase(std::make_pair(V, BB));
         return false;
+      }
 
       // Invert the known values.
       for (unsigned i = 0, e = Result.size(); i != e; ++i)
         if (Result[i].first)
           Result[i].first =
             cast<ConstantInt>(ConstantExpr::getNot(Result[i].first));
+      
+      RecursionSet.erase(std::make_pair(V, BB));
       return true;
     }
   
   // Try to simplify some other binary operator values.
   } else if (BinaryOperator *BO = dyn_cast<BinaryOperator>(I)) {
-    // AND or OR of a value with itself is that value.
     ConstantInt *CI = dyn_cast<ConstantInt>(BO->getOperand(1));
-    if (CI && (BO->getOpcode() == Instruction::And ||
-        BO->getOpcode() == Instruction::Or)) {
+    if (CI) {
       SmallVector<std::pair<ConstantInt*, BasicBlock*>, 8> LHSVals;
       ComputeValueKnownInPredecessors(BO->getOperand(0), BB, LHSVals);
-      for (unsigned i = 0, e = LHSVals.size(); i != e; ++i)
+    
+      // Try to use constant folding to simplify the binary operator.
+      for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) {
+        Constant *Folded = 0;
         if (LHSVals[i].first == 0) {
-          ConstantInt *Zero =
-            cast<ConstantInt>(ConstantInt::get(BO->getType(), 0));
-          Result.push_back(std::make_pair(Zero, LHSVals[i].second));
-        } else if (Constant *Folded = ConstantExpr::get(BO->getOpcode(),
-                                                   LHSVals[i].first, CI)) {
-          Result.push_back(std::make_pair(cast<ConstantInt>(Folded), 
-                                          LHSVals[i].second));
+          Folded = ConstantExpr::get(BO->getOpcode(),
+                                     UndefValue::get(BO->getType()),
+                                     CI);
+        } else {
+          Folded = ConstantExpr::get(BO->getOpcode(), LHSVals[i].first, CI);
         }
-      
-      return !Result.empty();
+        
+        if (ConstantInt *FoldedCInt = dyn_cast<ConstantInt>(Folded))
+          Result.push_back(std::make_pair(FoldedCInt, LHSVals[i].second));
+        else if (isa<UndefValue>(Folded))
+          Result.push_back(std::make_pair((ConstantInt*)0, LHSVals[i].second));
+      }
     }
+      
+    RecursionSet.erase(std::make_pair(V, BB));
+    return !Result.empty();
   }
   
   // Handle compare with phi operand, where the PHI is defined in this block.
@@ -446,6 +473,7 @@
           Result.push_back(std::make_pair(CI, PredBB));
       }
       
+      RecursionSet.erase(std::make_pair(V, BB));
       return !Result.empty();
     }
     
@@ -472,25 +500,32 @@
           Result.push_back(std::make_pair(cast<ConstantInt>(ResC), P));
         }
 
+        RecursionSet.erase(std::make_pair(V, BB));
         return !Result.empty();
       }
       
-      // Try to find a constant value for the LHS of an equality comparison,
+      // Try to find a constant value for the LHS of a comparison,
       // and evaluate it statically if we can.
       if (Constant *CmpConst = dyn_cast<Constant>(Cmp->getOperand(1))) {
         SmallVector<std::pair<ConstantInt*, BasicBlock*>, 8> LHSVals;
         ComputeValueKnownInPredecessors(I->getOperand(0), BB, LHSVals);
         
         for (unsigned i = 0, e = LHSVals.size(); i != e; ++i) {
+          Constant * Folded = 0;
           if (LHSVals[i].first == 0)
-            Result.push_back(std::make_pair((ConstantInt*)0,
-                                            LHSVals[i].second));
-          else if (Constant *Folded = ConstantExpr::getCompare(
-                               Cmp->getPredicate(), LHSVals[i].first, CmpConst))
-            Result.push_back(std::make_pair(cast<ConstantInt>(Folded),
-                                            LHSVals[i].second));
+            Folded = ConstantExpr::getCompare(Cmp->getPredicate(),
+                                UndefValue::get(CmpConst->getType()), CmpConst);
+          else
+            Folded = ConstantExpr::getCompare(Cmp->getPredicate(),   
+                                              LHSVals[i].first, CmpConst);
+          
+          if (ConstantInt *FoldedCInt = dyn_cast<ConstantInt>(Folded))
+            Result.push_back(std::make_pair(FoldedCInt, LHSVals[i].second));
+          else if (isa<UndefValue>(Folded))
+            Result.push_back(std::make_pair((ConstantInt*)0,LHSVals[i].second));
         }
         
+        RecursionSet.erase(std::make_pair(V, BB));
         return !Result.empty();
       }
     }
@@ -505,9 +540,11 @@
         Result.push_back(std::make_pair(CInt, *PI));
     }
     
+    RecursionSet.erase(std::make_pair(V, BB));
     return !Result.empty();
   }
   
+  RecursionSet.erase(std::make_pair(V, BB));
   return false;
 }
 
@@ -1126,8 +1163,9 @@
     return false;
   
   SmallVector<std::pair<ConstantInt*, BasicBlock*>, 8> PredValues;
-  if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues))
+  if (!ComputeValueKnownInPredecessors(Cond, BB, PredValues)) {
     return false;
+  }
   assert(!PredValues.empty() &&
          "ComputeValueKnownInPredecessors returned true with no values");