[LV] Support efficient vectorization of an induction with redundant casts

D30041 extended SCEVPredicateRewriter to improve handling of Phi nodes whose
update chain involves casts; PSCEV can now build an AddRecurrence for some
forms of such phi nodes, under the proper runtime overflow test. This means
that we can identify such phi nodes as an induction, and the loop-vectorizer
can now vectorize such inductions, however inefficiently. The vectorizer
doesn't know that it can ignore the casts, and so it vectorizes them.

This patch records the casts in the InductionDescriptor, so that they can
be marked to be ignored for cost calculation (we use VecValuesToIgnore for
that) and ignored for vectorization/widening/scalarization (i.e. treated as
TriviallyDead).

In addition to marking all these casts to be ignored, we also need to make
sure that each cast is mapped to the right vector value in the vector loop body
(be it a widened, vectorized, or scalarized induction). So whenever an
induction phi is mapped to a vector value (during vectorization/widening/
scalarization), we also map the respective cast instruction (if it exists) to that
vector value. (If the phi-update sequence of an induction involves more than one
cast, then the above mapping to vector value is relevant only for the last cast
of the sequence as we allow only the "last cast" to be used outside the
induction update chain itself).

This is the last step in addressing PR30654.

llvm-svn: 320672
diff --git a/llvm/lib/Transforms/Utils/LoopUtils.cpp b/llvm/lib/Transforms/Utils/LoopUtils.cpp
index 7df6f91d..c3fa05a 100644
--- a/llvm/lib/Transforms/Utils/LoopUtils.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUtils.cpp
@@ -678,7 +678,8 @@
 }
 
 InductionDescriptor::InductionDescriptor(Value *Start, InductionKind K,
-                                         const SCEV *Step, BinaryOperator *BOp)
+                                         const SCEV *Step, BinaryOperator *BOp,
+                                         SmallVectorImpl<Instruction *> *Casts)
   : StartValue(Start), IK(K), Step(Step), InductionBinOp(BOp) {
   assert(IK != IK_NoInduction && "Not an induction");
 
@@ -705,6 +706,12 @@
           (InductionBinOp->getOpcode() == Instruction::FAdd ||
            InductionBinOp->getOpcode() == Instruction::FSub))) &&
          "Binary opcode should be specified for FP induction");
+
+  if (Casts) {
+    for (auto &Inst : *Casts) {
+      RedundantCasts.push_back(Inst);
+    }
+  }
 }
 
 int InductionDescriptor::getConsecutiveDirection() const {
@@ -808,7 +815,7 @@
     StartValue = Phi->getIncomingValue(1);
   } else {
     assert(TheLoop->contains(Phi->getIncomingBlock(1)) &&
-           "Unexpected Phi node in the loop"); 
+           "Unexpected Phi node in the loop");
     BEValue = Phi->getIncomingValue(1);
     StartValue = Phi->getIncomingValue(0);
   }
@@ -841,6 +848,110 @@
   return true;
 }
 
+/// This function is called when we suspect that the update-chain of a phi node
+/// (whose symbolic SCEV expression is in \p PhiScev) contains redundant casts,
+/// that can be ignored. (This can happen when the PSCEV rewriter adds a runtime 
+/// predicate P under which the SCEV expression for the phi can be the 
+/// AddRecurrence \p AR; See createAddRecFromPHIWithCasts). We want to find the
+/// cast instructions that are involved in the update-chain of this induction. 
+/// A caller that adds the required runtime predicate can be free to drop these 
+/// cast instructions, and compute the phi using \p AR (instead of some scev 
+/// expression with casts).
+///
+/// For example, without a predicate the scev expression can take the following
+/// form:
+///      (Ext ix (Trunc iy ( Start + i*Step ) to ix) to iy)
+///
+/// It corresponds to the following IR sequence:
+/// %for.body:
+///   %x = phi i64 [ 0, %ph ], [ %add, %for.body ]
+///   %casted_phi = "ExtTrunc i64 %x"
+///   %add = add i64 %casted_phi, %step
+///
+/// where %x is given in \p PN,
+/// PSE.getSCEV(%x) is equal to PSE.getSCEV(%casted_phi) under a predicate,
+/// and the IR sequence that "ExtTrunc i64 %x" represents can take one of
+/// several forms, for example, such as:
+///   ExtTrunc1:    %casted_phi = and  %x, 2^n-1
+/// or:
+///   ExtTrunc2:    %t = shl %x, m
+///                 %casted_phi = ashr %t, m
+///
+/// If we are able to find such a sequence, we return the instructions
+/// we found, namely %casted_phi and the instructions on its use-def chain up
+/// to the phi (not including the phi).
+bool getCastsForInductionPHI(
+    PredicatedScalarEvolution &PSE, const SCEVUnknown *PhiScev,
+    const SCEVAddRecExpr *AR, SmallVectorImpl<Instruction *> &CastInsts) {
+
+  assert(CastInsts.empty() && "CastInsts is expected to be empty.");
+  auto *PN = cast<PHINode>(PhiScev->getValue());
+  assert(PSE.getSCEV(PN) == AR && "Unexpected phi node SCEV expression");
+  const Loop *L = AR->getLoop();
+
+  // Find any cast instructions that participate in the def-use chain of 
+  // PhiScev in the loop.
+  // FORNOW/TODO: We currently expect the def-use chain to include only
+  // two-operand instructions, where one of the operands is an invariant.
+  // createAddRecFromPHIWithCasts() currently does not support anything more
+  // involved than that, so we keep the search simple. This can be
+  // extended/generalized as needed.
+
+  auto getDef = [&](const Value *Val) -> Value * {
+    const BinaryOperator *BinOp = dyn_cast<BinaryOperator>(Val);
+    if (!BinOp)
+      return nullptr;
+    Value *Op0 = BinOp->getOperand(0);
+    Value *Op1 = BinOp->getOperand(1);
+    Value *Def = nullptr;
+    if (L->isLoopInvariant(Op0))
+      Def = Op1;
+    else if (L->isLoopInvariant(Op1))
+      Def = Op0;
+    return Def;
+  };
+
+  // Look for the instruction that defines the induction via the
+  // loop backedge.
+  BasicBlock *Latch = L->getLoopLatch();
+  if (!Latch)
+    return false;
+  Value *Val = PN->getIncomingValueForBlock(Latch);
+  if (!Val)
+    return false;
+
+  // Follow the def-use chain until the induction phi is reached.
+  // If on the way we encounter a Value that has the same SCEV Expr as the
+  // phi node, we can consider the instructions we visit from that point
+  // as part of the cast-sequence that can be ignored.
+  bool InCastSequence = false;
+  auto *Inst = dyn_cast<Instruction>(Val);
+  while (Val != PN) {
+    // If we encountered a phi node other than PN, or if we left the loop,
+    // we bail out.
+    if (!Inst || !L->contains(Inst)) {
+      return false;
+    }
+    auto *AddRec = dyn_cast<SCEVAddRecExpr>(PSE.getSCEV(Val));
+    if (AddRec && PSE.areAddRecsEqualWithPreds(AddRec, AR))
+      InCastSequence = true;
+    if (InCastSequence) {
+      // Only the last instruction in the cast sequence is expected to have
+      // uses outside the induction def-use chain.
+      if (!CastInsts.empty())
+        if (!Inst->hasOneUse())
+          return false;
+      CastInsts.push_back(Inst);
+    }
+    Val = getDef(Val);
+    if (!Val)
+      return false;
+    Inst = dyn_cast<Instruction>(Val);
+  }
+
+  return InCastSequence;
+}
+
 bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
                                          PredicatedScalarEvolution &PSE,
                                          InductionDescriptor &D,
@@ -870,13 +981,26 @@
     return false;
   }
 
+  // Record any Cast instructions that participate in the induction update
+  const auto *SymbolicPhi = dyn_cast<SCEVUnknown>(PhiScev);
+  // If we started from an UnknownSCEV, and managed to build an addRecurrence
+  // only after enabling Assume with PSCEV, this means we may have encountered
+  // cast instructions that required adding a runtime check in order to
+  // guarantee the correctness of the AddRecurrence representation of the
+  // induction.
+  if (PhiScev != AR && SymbolicPhi) {
+    SmallVector<Instruction *, 2> Casts;
+    if (getCastsForInductionPHI(PSE, SymbolicPhi, AR, Casts))
+      return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR, &Casts);
+  }
+
   return isInductionPHI(Phi, TheLoop, PSE.getSE(), D, AR);
 }
 
-bool InductionDescriptor::isInductionPHI(PHINode *Phi, const Loop *TheLoop,
-                                         ScalarEvolution *SE,
-                                         InductionDescriptor &D,
-                                         const SCEV *Expr) {
+bool InductionDescriptor::isInductionPHI(
+    PHINode *Phi, const Loop *TheLoop, ScalarEvolution *SE,
+    InductionDescriptor &D, const SCEV *Expr,
+    SmallVectorImpl<Instruction *> *CastsToIgnore) {
   Type *PhiTy = Phi->getType();
   // We only handle integer and pointer inductions variables.
   if (!PhiTy->isIntegerTy() && !PhiTy->isPointerTy())
@@ -895,7 +1019,7 @@
     // FIXME: We should treat this as a uniform. Unfortunately, we
     // don't currently know how to handled uniform PHIs.
     DEBUG(dbgs() << "LV: PHI is a recurrence with respect to an outer loop.\n");
-    return false;    
+    return false;
   }
 
   Value *StartValue =
@@ -908,7 +1032,8 @@
     return false;
 
   if (PhiTy->isIntegerTy()) {
-    D = InductionDescriptor(StartValue, IK_IntInduction, Step);
+    D = InductionDescriptor(StartValue, IK_IntInduction, Step, /*BOp=*/ nullptr,
+                            CastsToIgnore);
     return true;
   }
 
diff --git a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
index 440641b..5deb279 100644
--- a/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
+++ b/llvm/lib/Transforms/Vectorize/LoopVectorize.cpp
@@ -599,6 +599,20 @@
   /// Returns true if we should generate a scalar version of \p IV.
   bool needsScalarInduction(Instruction *IV) const;
 
+  /// If there is a cast involved in the induction variable \p ID, which should 
+  /// be ignored in the vectorized loop body, this function records the 
+  /// VectorLoopValue of the respective Phi also as the VectorLoopValue of the 
+  /// cast. We had already proved that the casted Phi is equal to the uncasted 
+  /// Phi in the vectorized loop (under a runtime guard), and therefore 
+  /// there is no need to vectorize the cast - the same value can be used in the 
+  /// vector loop for both the Phi and the cast. 
+  /// If \p VectorLoopValue is a scalarized value, \p Lane is also specified.
+  /// Otherwise, \p VectorLoopValue is a widened/vectorized value.
+  void recordVectorLoopValueForInductionCast (const InductionDescriptor &ID,
+                                              Value *VectorLoopValue, 
+                                              unsigned Part, 
+                                              unsigned Lane = UINT_MAX);
+
   /// Generate a shuffle sequence that will reverse the vector Vec.
   virtual Value *reverseVector(Value *Vec);
 
@@ -1570,7 +1584,17 @@
   /// Returns the widest induction type.
   Type *getWidestInductionType() { return WidestIndTy; }
 
-  /// Returns True if V is an induction variable in this loop.
+  /// Returns True if V is a Phi node of an induction variable in this loop.
+  bool isInductionPhi(const Value *V);
+
+  /// Returns True if V is a cast that is part of an induction def-use chain,
+  /// and had been proven to be redundant under a runtime guard (in other
+  /// words, the cast has the same SCEV expression as the induction phi).
+  bool isCastedInductionVariable(const Value *V);
+
+  /// Returns True if V can be considered as an induction variable in this 
+  /// loop. V can be the induction phi, or some redundant cast in the def-use
+  /// chain of the induction phi.
   bool isInductionVariable(const Value *V);
 
   /// Returns True if PN is a reduction variable in this loop.
@@ -1781,6 +1805,12 @@
   /// variables can be pointers.
   InductionList Inductions;
 
+  /// Holds all the casts that participate in the update chain of the induction 
+  /// variables, and that have been proven to be redundant (possibly under a 
+  /// runtime guard). These casts can be ignored when creating the vectorized 
+  /// loop body.
+  SmallPtrSet<Instruction *, 4> InductionCastsToIgnore;
+
   /// Holds the phi nodes that are first-order recurrences.
   RecurrenceSet FirstOrderRecurrences;
 
@@ -2014,7 +2044,7 @@
       return false;
 
     // If the truncated value is not an induction variable, return false.
-    return Legal->isInductionVariable(Op);
+    return Legal->isInductionPhi(Op);
   }
 
   /// Collects the instructions to scalarize for each predicated instruction in
@@ -2600,6 +2630,7 @@
   Instruction *LastInduction = VecInd;
   for (unsigned Part = 0; Part < UF; ++Part) {
     VectorLoopValueMap.setVectorValue(EntryVal, Part, LastInduction);
+    recordVectorLoopValueForInductionCast(II, LastInduction, Part);
     if (isa<TruncInst>(EntryVal))
       addMetadata(LastInduction, EntryVal);
     LastInduction = cast<Instruction>(addFastMathFlag(
@@ -2633,6 +2664,22 @@
   return llvm::any_of(IV->users(), isScalarInst);
 }
 
+void InnerLoopVectorizer::recordVectorLoopValueForInductionCast(
+    const InductionDescriptor &ID, Value *VectorLoopVal, unsigned Part,
+    unsigned Lane) {
+  const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts();
+  if (Casts.empty())
+    return;
+  // Only the first Cast instruction in the Casts vector is of interest.
+  // The rest of the Casts (if any) have no uses outside the
+  // induction update chain itself.
+  Instruction *CastInst = *Casts.begin();
+  if (Lane < UINT_MAX)
+    VectorLoopValueMap.setScalarValue(CastInst, {Part, Lane}, VectorLoopVal);
+  else
+    VectorLoopValueMap.setVectorValue(CastInst, Part, VectorLoopVal);
+}
+
 void InnerLoopVectorizer::widenIntOrFpInduction(PHINode *IV, TruncInst *Trunc) {
   assert((IV->getType()->isIntegerTy() || IV != OldInduction) &&
          "Primary induction variable must have an integer type");
@@ -2713,6 +2760,7 @@
       Value *EntryPart =
           getStepVector(Broadcasted, VF * Part, Step, ID.getInductionOpcode());
       VectorLoopValueMap.setVectorValue(EntryVal, Part, EntryPart);
+      recordVectorLoopValueForInductionCast(ID, EntryPart, Part);
       if (Trunc)
         addMetadata(EntryPart, Trunc);
     }
@@ -2820,6 +2868,7 @@
       auto *Mul = addFastMathFlag(Builder.CreateBinOp(MulOp, StartIdx, Step));
       auto *Add = addFastMathFlag(Builder.CreateBinOp(AddOp, ScalarIV, Mul));
       VectorLoopValueMap.setScalarValue(EntryVal, {Part, Lane}, Add);
+      recordVectorLoopValueForInductionCast(ID, Add, Part, Lane);
     }
   }
 }
@@ -5169,6 +5218,15 @@
     PHINode *Phi, const InductionDescriptor &ID,
     SmallPtrSetImpl<Value *> &AllowedExit) {
   Inductions[Phi] = ID;
+
+  // In case this induction also comes with casts that we know we can ignore
+  // in the vectorized loop body, record them here. All casts could be recorded
+  // here for ignoring, but it suffices to record only the first (as it is the
+  // only one that may be used outside the cast sequence).
+  const SmallVectorImpl<Instruction *> &Casts = ID.getCastInsts();
+  if (!Casts.empty())
+    InductionCastsToIgnore.insert(*Casts.begin());
+
   Type *PhiTy = Phi->getType();
   const DataLayout &DL = Phi->getModule()->getDataLayout();
 
@@ -5798,7 +5856,7 @@
   return true;
 }
 
-bool LoopVectorizationLegality::isInductionVariable(const Value *V) {
+bool LoopVectorizationLegality::isInductionPhi(const Value *V) {
   Value *In0 = const_cast<Value *>(V);
   PHINode *PN = dyn_cast_or_null<PHINode>(In0);
   if (!PN)
@@ -5807,6 +5865,15 @@
   return Inductions.count(PN);
 }
 
+bool LoopVectorizationLegality::isCastedInductionVariable(const Value *V) {
+  auto *Inst = dyn_cast<Instruction>(V);
+  return (Inst && InductionCastsToIgnore.count(Inst));
+}
+
+bool LoopVectorizationLegality::isInductionVariable(const Value *V) {
+  return isInductionPhi(V) || isCastedInductionVariable(V);
+}
+
 bool LoopVectorizationLegality::isFirstOrderRecurrence(const PHINode *Phi) {
   return FirstOrderRecurrences.count(Phi);
 }
@@ -6917,14 +6984,16 @@
 static const SCEV *getAddressAccessSCEV(
               Value *Ptr,
               LoopVectorizationLegality *Legal,
-              ScalarEvolution *SE,
+              PredicatedScalarEvolution &PSE,
               const Loop *TheLoop) {
+
   auto *Gep = dyn_cast<GetElementPtrInst>(Ptr);
   if (!Gep)
     return nullptr;
 
   // We are looking for a gep with all loop invariant indices except for one
   // which should be an induction variable.
+  auto SE = PSE.getSE();
   unsigned NumOperands = Gep->getNumOperands();
   for (unsigned i = 1; i < NumOperands; ++i) {
     Value *Opd = Gep->getOperand(i);
@@ -6934,7 +7003,7 @@
   }
 
   // Now we know we have a GEP ptr, %inv, %ind, %inv. return the Ptr SCEV.
-  return SE->getSCEV(Ptr);
+  return PSE.getSCEV(Ptr);
 }
 
 static bool isStrideMul(Instruction *I, LoopVectorizationLegality *Legal) {
@@ -6954,7 +7023,7 @@
 
   // Figure out whether the access is strided and get the stride value
   // if it's known in compile time
-  const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, SE, TheLoop);
+  const SCEV *PtrSCEV = getAddressAccessSCEV(Ptr, Legal, PSE, TheLoop);
 
   // Get the cost of the scalar memory instruction and address computation.
   unsigned Cost = VF * TTI.getAddressComputationCost(PtrTy, SE, PtrSCEV);
@@ -7508,6 +7577,13 @@
     SmallPtrSetImpl<Instruction *> &Casts = RedDes.getCastInsts();
     VecValuesToIgnore.insert(Casts.begin(), Casts.end());
   }
+  // Ignore type-casting instructions we identified during induction
+  // detection.
+  for (auto &Induction : *Legal->getInductionVars()) {
+    InductionDescriptor &IndDes = Induction.second;
+    const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
+    VecValuesToIgnore.insert(Casts.begin(), Casts.end());
+  }
 }
 
 LoopVectorizationCostModel::VectorizationFactor
@@ -7613,6 +7689,18 @@
           return U == Ind || DeadInstructions.count(cast<Instruction>(U));
         }))
       DeadInstructions.insert(IndUpdate);
+
+    // We record as "Dead" also the type-casting instructions we had identified 
+    // during induction analysis. We don't need any handling for them in the
+    // vectorized loop because we have proven that, under a proper runtime 
+    // test guarding the vectorized loop, the value of the phi, and the casted 
+    // value of the phi, are the same. The last instruction in this casting chain
+    // will get its scalar/vector/widened def from the scalar/vector/widened def 
+    // of the respective phi node. Any other casts in the induction def-use chain
+    // have no other uses outside the phi update chain, and will be ignored.
+    InductionDescriptor &IndDes = Induction.second;
+    const SmallVectorImpl<Instruction *> &Casts = IndDes.getCastInsts();
+    DeadInstructions.insert(Casts.begin(), Casts.end());
   }
 }