Handle floating-point induction variables (IVs) during doInitialization().


git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@59466 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Transforms/Scalar/IndVarSimplify.cpp b/lib/Transforms/Scalar/IndVarSimplify.cpp
index 2af10a6..d580978 100644
--- a/lib/Transforms/Scalar/IndVarSimplify.cpp
+++ b/lib/Transforms/Scalar/IndVarSimplify.cpp
@@ -96,7 +96,8 @@
     void DeleteTriviallyDeadInstructions(SmallPtrSet<Instruction*, 16> &Insts);
 
     void OptimizeCanonicalIVType(Loop *L);
-    void HandleFloatingPointIV(Loop *L);
+    void HandleFloatingPointIV(Loop *L, PHINode *PH,
+                               SmallPtrSet<Instruction*, 16> &DeadInsts);
   };
 }
 
@@ -433,6 +434,8 @@
     PHINode *PN = cast<PHINode>(I);
     if (isa<PointerType>(PN->getType()))
       EliminatePointerRecurrence(PN, Preheader, DeadInsts);
+    else if (PN->getType()->isFloatingPoint())
+      HandleFloatingPointIV(L, PN, DeadInsts);
   }
 
   if (!DeadInsts.empty())
@@ -468,7 +471,6 @@
   // auxillary induction variables.
   std::vector<std::pair<PHINode*, SCEVHandle> > IndVars;
 
-  HandleFloatingPointIV(L);
   for (BasicBlock::iterator I = Header->begin(); isa<PHINode>(I); ++I) {
     PHINode *PN = cast<PHINode>(I);
     if (PN->getType()->isInteger()) { // FIXME: when we have fast-math, enable!
@@ -723,149 +725,152 @@
 
 /// HandleFloatingPointIV - If the loop has floating induction variable
 /// then insert corresponding integer induction variable if possible.
-void IndVarSimplify::HandleFloatingPointIV(Loop *L) {
-  BasicBlock *Header = L->getHeader();
-  SmallVector <PHINode *, 4> FPHIs;
-  Instruction *NonPHIInsn = NULL;
+/// The IV must have exactly two incoming values, start at +0.0, and be
+/// incremented by exactly 1.0; the increment must feed an fcmp against a
+/// non-negative constant that converts exactly to a 32-bit unsigned integer.
+/// For example,
+/// for(double i = 0; i < 10000; ++i)
+///   bar(i);
+/// is converted into
+/// for(int i = 0; i < 10000; ++i)
+///   bar((double)i);
+/// The replaced FP compare, increment and PHI are not erased here; they are
+/// inserted into DeadInsts.
+///
+void IndVarSimplify::HandleFloatingPointIV(Loop *L, PHINode *PH,
+                                   SmallPtrSet<Instruction*, 16> &DeadInsts) {
 
-  // Collect all floating point IVs first.
-  BasicBlock::iterator I = Header->begin();
-  while(true) {
-    if (!isa<PHINode>(I)) {
-      NonPHIInsn = I;
-      break;
-    }
-    PHINode *PH = cast<PHINode>(I);
-    if (PH->getType()->isFloatingPoint())
-      FPHIs.push_back(PH);
-    ++I;
+  // The edge indexing below assumes one edge from outside the loop and one
+  // backedge, so only handle PHIs with exactly two incoming values.
+  if (PH->getNumIncomingValues() != 2) return;
+  unsigned IncomingEdge = L->contains(PH->getIncomingBlock(0));
+  unsigned BackEdge     = IncomingEdge^1;
+
+  // Check incoming value: the IV must start at +0.0.
+  ConstantFP *CZ = dyn_cast<ConstantFP>(PH->getIncomingValue(IncomingEdge));
+  if (!CZ) return;
+  APFloat PHInit = CZ->getValueAPF();
+  if (!PHInit.isPosZero()) return;
+
+  // Check IV increment: Incr must be PH + 1.0, in either operand order.
+  BinaryOperator *Incr =
+    dyn_cast<BinaryOperator>(PH->getIncomingValue(BackEdge));
+  if (!Incr) return;
+  if (Incr->getOpcode() != Instruction::Add) return;
+  unsigned IncrVIndex = 1;
+  if (Incr->getOperand(1) == PH)
+    IncrVIndex = 0;
+  // The other operand must be PH itself, or Incr does not step this IV.
+  if (Incr->getOperand(IncrVIndex^1) != PH) return;
+  ConstantFP *IncrValue = dyn_cast<ConstantFP>(Incr->getOperand(IncrVIndex));
+  if (!IncrValue) return;
+  APFloat IVAPF = IncrValue->getValueAPF();
+  APFloat One = APFloat(IVAPF.getSemantics(), 1);
+  if (!IVAPF.bitwiseIsEqual(One)) return;
+
+  // Incr must have exactly two uses: PH (its backedge value) and, below, the
+  // exit comparison.
+  Value::use_iterator IncrUse = Incr->use_begin();
+  Instruction *U1 = cast<Instruction>(IncrUse++);
+  if (IncrUse == Incr->use_end()) return;
+  Instruction *U2 = cast<Instruction>(IncrUse++);
+  if (IncrUse != Incr->use_end()) return;
+
+  // Find exit condition.
+  FCmpInst *EC = dyn_cast<FCmpInst>(U1);
+  if (!EC)
+    EC = dyn_cast<FCmpInst>(U2);
+  if (!EC) return;
+
+  // If EC's block ends in a branch, it must be a conditional branch on EC.
+  if (BranchInst *BI = dyn_cast<BranchInst>(EC->getParent()->getTerminator())) {
+    if (!BI->isConditional()) return;
+    if (BI->getCondition() != EC) return;
   }
-   
-  for (SmallVector<PHINode *, 4>::iterator I = FPHIs.begin(), E = FPHIs.end();
-       I != E; ++I) {
-    PHINode *PH = *I;
-    unsigned IncomingEdge = L->contains(PH->getIncomingBlock(0));
-    unsigned BackEdge     = IncomingEdge^1;
 
-    // Check incoming value.
-    ConstantFP *CZ = dyn_cast<ConstantFP>(PH->getIncomingValue(IncomingEdge));
-    if (!CZ) continue;
-    APFloat PHInit = CZ->getValueAPF();
-    if (!PHInit.isPosZero()) continue;
-
-    // Check IV increment.
-    BinaryOperator *Incr = 
-      dyn_cast<BinaryOperator>(PH->getIncomingValue(BackEdge));
-    if (!Incr) continue;
-    if (Incr->getOpcode() != Instruction::Add) continue;
-    ConstantFP *IncrValue = NULL;
-    unsigned IncrVIndex = 1;
-    if (Incr->getOperand(1) == PH)
-      IncrVIndex = 0;
-    IncrValue = dyn_cast<ConstantFP>(Incr->getOperand(IncrVIndex));
-    if (!IncrValue) continue;
-    APFloat IVAPF = IncrValue->getValueAPF();
-    APFloat One = APFloat(IVAPF.getSemantics(), 1);
-    if (!IVAPF.bitwiseIsEqual(One)) continue;
-
-    // Check Incr uses.
-    Value::use_iterator IncrUse = Incr->use_begin();
-    Instruction *U1 = cast<Instruction>(IncrUse++);
-    if (IncrUse == Incr->use_end()) continue;
-    Instruction *U2 = cast<Instruction>(IncrUse++);
-    if (IncrUse != Incr->use_end()) continue;
-
-    // Find exict condition.
-    FCmpInst *EC = dyn_cast<FCmpInst>(U1);
-    if (!EC)
-      EC = dyn_cast<FCmpInst>(U2);
-    if (!EC) continue;
-    bool skip = false;
-    Instruction *Terminator = EC->getParent()->getTerminator();
-    for(Value::use_iterator ECUI = EC->use_begin(), ECUE = EC->use_end();
-        ECUI != ECUE; ++ECUI) {
-      Instruction *U = cast<Instruction>(ECUI);
-      if (U != Terminator) { 
-        skip = true;
-        break;
-      }
-    }
-    if (skip) continue;
-
-    // Find exit value.
-    ConstantFP *EV = NULL;
-    unsigned EVIndex = 1;
-    if (EC->getOperand(1) == Incr)
-      EVIndex = 0;
-    EV = dyn_cast<ConstantFP>(EC->getOperand(EVIndex));
-    if (!EV) continue;
-    APFloat EVAPF = EV->getValueAPF();
-    if (EVAPF.isNegative()) continue;
-
-    // Find corresponding integer exit value.
-    uint64_t integerVal = Type::Int32Ty->getPrimitiveSizeInBits();
-    bool isExact = false;
-    if (EVAPF.convertToInteger(&integerVal, 32, false, APFloat::rmTowardZero, &isExact)
-        != APFloat::opOK)
-      continue;
-    if (!isExact) continue;
-
-    // Find new predicate for integer comparison.
-    CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
-    switch (EC->getPredicate()) {
-    case CmpInst::FCMP_OEQ:
-    case CmpInst::FCMP_UEQ:
-      NewPred = CmpInst::ICMP_EQ;
-      break;
-    case CmpInst::FCMP_OGT:
-    case CmpInst::FCMP_UGT:
-      NewPred = CmpInst::ICMP_UGT;
-      break;
-    case CmpInst::FCMP_OGE:
-    case CmpInst::FCMP_UGE:
-      NewPred = CmpInst::ICMP_UGE;
-      break;
-    case CmpInst::FCMP_OLT:
-    case CmpInst::FCMP_ULT:
-      NewPred = CmpInst::ICMP_ULT;
-      break;
-    case CmpInst::FCMP_OLE:
-    case CmpInst::FCMP_ULE:
-      NewPred = CmpInst::ICMP_ULE;
-      break;
-    default:
-      break;
-    }
-    if (NewPred == CmpInst::BAD_ICMP_PREDICATE) continue;
-
-    // Insert new integer induction variable.
-    SCEVExpander Rewriter(*SE, *LI);
-    PHINode *NewIV = 
-      cast<PHINode>(Rewriter.getOrInsertCanonicalInductionVariable(L,Type::Int32Ty));
-    ConstantInt *NewEV = ConstantInt::get(Type::Int32Ty, integerVal);
-    Value *LHS = (EVIndex == 1 ? NewIV->getIncomingValue(BackEdge) : NewEV);
-    Value *RHS = (EVIndex == 1 ? NewEV : NewIV->getIncomingValue(BackEdge));
-    ICmpInst *NewEC = new ICmpInst(NewPred, LHS, RHS, EC->getNameStart(), 
-                                   EC->getParent()->getTerminator());
-
-    // Delete old, floating point, exit comparision instruction.
-    SE->deleteValueFromRecords(EC);
-    EC->replaceAllUsesWith(NewEC);
-    EC->eraseFromParent();
-
-    // Delete old, floating point, increment instruction.
-    SE->deleteValueFromRecords(Incr);
-    Incr->replaceAllUsesWith(UndefValue::get(Incr->getType()));
-    Incr->eraseFromParent();
-
-    // Replace floating induction variable.
-    UIToFPInst *Conv = new UIToFPInst(NewIV, PH->getType(), "indvar.conv", 
-                                      NonPHIInsn);
-    PH->replaceAllUsesWith(Conv);
-
-    SE->deleteValueFromRecords(PH);
-    PH->removeIncomingValue((unsigned)0);
-    PH->removeIncomingValue((unsigned)0);
+  // Find exit value.
+  ConstantFP *EV = NULL;
+  unsigned EVIndex = 1;
+  if (EC->getOperand(1) == Incr)
+    EVIndex = 0;
+  EV = dyn_cast<ConstantFP>(EC->getOperand(EVIndex));
+  if (!EV) return;
+  APFloat EVAPF = EV->getValueAPF();
+  if (EVAPF.isNegative()) return;
+
+  // Find corresponding integer exit value.
+  uint64_t intEV = 0; // Written by convertToInteger below.
+  bool isExact = false;
+  if (EVAPF.convertToInteger(&intEV, 32, false, APFloat::rmTowardZero, &isExact)
+      != APFloat::opOK)
+    return;
+  if (!isExact) return;
+  // NOTE(review): this assumes the FP IV and the i32 IV hold equal values on
+  // every executed iteration. That breaks if the loop can run past 2^32-1
+  // iterations (the i32 IV wraps, e.g. 'i <= 4294967295.0') or, for 'float',
+  // past 2^24 (adding 1.0 no longer changes the value). Confirm whether a
+  // range/trip-count check is needed.
+
+  // Find new predicate for integer comparison.
+  CmpInst::Predicate NewPred = CmpInst::BAD_ICMP_PREDICATE;
+  switch (EC->getPredicate()) {
+  case CmpInst::FCMP_OEQ:
+  case CmpInst::FCMP_UEQ:
+    NewPred = CmpInst::ICMP_EQ;
+    break;
+  case CmpInst::FCMP_OGT:
+  case CmpInst::FCMP_UGT:
+    NewPred = CmpInst::ICMP_UGT;
+    break;
+  case CmpInst::FCMP_OGE:
+  case CmpInst::FCMP_UGE:
+    NewPred = CmpInst::ICMP_UGE;
+    break;
+  case CmpInst::FCMP_OLT:
+  case CmpInst::FCMP_ULT:
+    NewPred = CmpInst::ICMP_ULT;
+    break;
+  case CmpInst::FCMP_OLE:
+  case CmpInst::FCMP_ULE:
+    NewPred = CmpInst::ICMP_ULE;
+    break;
+  default:
+    break;
   }
+  if (NewPred == CmpInst::BAD_ICMP_PREDICATE) return;
+
+  // Insert new integer induction variable.
+  PHINode *NewPHI = PHINode::Create(Type::Int32Ty,
+                                    PH->getName()+".int", PH);
+  NewPHI->addIncoming(Constant::getNullValue(NewPHI->getType()),
+                      PH->getIncomingBlock(IncomingEdge));
+
+  Value *NewAdd = BinaryOperator::CreateAdd(NewPHI,
+                                            ConstantInt::get(Type::Int32Ty, 1),
+                                            Incr->getName()+".int", Incr);
+  NewPHI->addIncoming(NewAdd, PH->getIncomingBlock(BackEdge));
+
+  // Compare the integer increment with the integer exit value, keeping EC's
+  // operand order. NewPHI's incoming order (0 = entry, 1 = backedge) need not
+  // match PH's, so NewPHI->getIncomingValue(BackEdge) could be the initial
+  // zero; refer to NewAdd directly.
+  ConstantInt *NewEV = ConstantInt::get(Type::Int32Ty, intEV);
+  Value *LHS = (EVIndex == 1 ? NewAdd : NewEV);
+  Value *RHS = (EVIndex == 1 ? NewEV : NewAdd);
+  // Insert right before EC: NewAdd precedes Incr, which dominates EC, and
+  // all of EC's users are dominated by EC, so they may use NewEC instead.
+  ICmpInst *NewEC = new ICmpInst(NewPred, LHS, RHS, EC->getNameStart(), EC);
+
+  // Replace the old floating point exit comparison and queue it for deletion.
+  EC->replaceAllUsesWith(NewEC);
+  DeadInsts.insert(EC);
+
+  // Detach the old floating point increment and queue it for deletion.
+  Incr->replaceAllUsesWith(UndefValue::get(Incr->getType()));
+  DeadInsts.insert(Incr);
+
+  // Replace the floating induction variable with a conversion of the new one.
+  UIToFPInst *Conv = new UIToFPInst(NewPHI, PH->getType(), "indvar.conv",
+                                    PH->getParent()->getFirstNonPHI());
+  PH->replaceAllUsesWith(Conv);
+  DeadInsts.insert(PH);
 }