Re-implement the main strength-reduction portion of LoopStrengthReduction.
This new version is much more aggressive about doing "full" reduction in
cases where it reduces register pressure, and also more aggressive about
rewriting induction variables to count down (or up) to zero when doing so
reduces register pressure.
It currently uses fairly simplistic algorithms for finding reuse
opportunities, but it introduces a new framework that allows it to combine
multiple strategies at once to form hybrid solutions, instead of doing
all full-reduction or all base+index.
git-svn-id: https://llvm.org/svn/llvm-project/llvm/trunk@94061 91177308-0d34-0410-b5e6-96231b3b80d8
diff --git a/lib/Analysis/ScalarEvolution.cpp b/lib/Analysis/ScalarEvolution.cpp
index 7389007..2f44913 100644
--- a/lib/Analysis/ScalarEvolution.cpp
+++ b/lib/Analysis/ScalarEvolution.cpp
@@ -1089,6 +1089,15 @@
if (!isa<SCEVSignExtendExpr>(SExt))
return SExt;
+ // Force the cast to be folded into the operands of an addrec.
+ if (const SCEVAddRecExpr *AR = dyn_cast<SCEVAddRecExpr>(Op)) {
+ SmallVector<const SCEV *, 4> Ops;
+ for (SCEVAddRecExpr::op_iterator I = AR->op_begin(), E = AR->op_end();
+ I != E; ++I)
+ Ops.push_back(getAnyExtendExpr(*I, Ty));
+ return getAddRecExpr(Ops, AR->getLoop());
+ }
+
// If the expression is obviously signed, use the sext cast value.
if (isa<SCEVSMaxExpr>(Op))
return SExt;
@@ -1204,6 +1213,17 @@
"SCEVAddExpr operand types don't match!");
#endif
+ // If HasNSW is true and all the operands are non-negative, infer HasNUW.
+ if (!HasNUW && HasNSW) {
+ bool All = true;
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ if (!isKnownNonNegative(Ops[i])) {
+ All = false;
+ break;
+ }
+ if (All) HasNUW = true;
+ }
+
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, LI);
@@ -1521,10 +1541,13 @@
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
ID.AddPointer(Ops[i]);
void *IP = 0;
- if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
- SCEVAddExpr *S = SCEVAllocator.Allocate<SCEVAddExpr>();
- new (S) SCEVAddExpr(ID, Ops);
- UniqueSCEVs.InsertNode(S, IP);
+ SCEVAddExpr *S =
+ static_cast<SCEVAddExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ S = SCEVAllocator.Allocate<SCEVAddExpr>();
+ new (S) SCEVAddExpr(ID, Ops);
+ UniqueSCEVs.InsertNode(S, IP);
+ }
if (HasNUW) S->setHasNoUnsignedWrap(true);
if (HasNSW) S->setHasNoSignedWrap(true);
return S;
@@ -1535,6 +1558,7 @@
const SCEV *ScalarEvolution::getMulExpr(SmallVectorImpl<const SCEV *> &Ops,
bool HasNUW, bool HasNSW) {
assert(!Ops.empty() && "Cannot get empty mul!");
+ if (Ops.size() == 1) return Ops[0];
#ifndef NDEBUG
for (unsigned i = 1, e = Ops.size(); i != e; ++i)
assert(getEffectiveSCEVType(Ops[i]->getType()) ==
@@ -1542,6 +1566,17 @@
"SCEVMulExpr operand types don't match!");
#endif
+ // If HasNSW is true and all the operands are non-negative, infer HasNUW.
+ if (!HasNUW && HasNSW) {
+ bool All = true;
+ for (unsigned i = 0, e = Ops.size(); i != e; ++i)
+ if (!isKnownNonNegative(Ops[i])) {
+ All = false;
+ break;
+ }
+ if (All) HasNUW = true;
+ }
+
// Sort by complexity, this groups all similar expression types together.
GroupByComplexity(Ops, LI);
@@ -1576,6 +1611,22 @@
} else if (cast<SCEVConstant>(Ops[0])->getValue()->isZero()) {
// If we have a multiply of zero, it will always be zero.
return Ops[0];
+ } else if (Ops[0]->isAllOnesValue()) {
+ // If we have a mul by -1 of an add, try distributing the -1 among the
+ // add operands.
+ if (Ops.size() == 2)
+ if (const SCEVAddExpr *Add = dyn_cast<SCEVAddExpr>(Ops[1])) {
+ SmallVector<const SCEV *, 4> NewOps;
+ bool AnyFolded = false;
+ for (SCEVAddRecExpr::op_iterator I = Add->op_begin(), E = Add->op_end();
+ I != E; ++I) {
+ const SCEV *Mul = getMulExpr(Ops[0], *I);
+ if (!isa<SCEVMulExpr>(Mul)) AnyFolded = true;
+ NewOps.push_back(Mul);
+ }
+ if (AnyFolded)
+ return getAddExpr(NewOps);
+ }
}
}
@@ -1642,7 +1693,9 @@
// It's tempting to propagate the NSW flag here, but nsw multiplication
// is not associative so this isn't necessarily safe.
- const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop());
+ const SCEV *NewRec = getAddRecExpr(NewOps, AddRec->getLoop(),
+ HasNUW && AddRec->hasNoUnsignedWrap(),
+ /*HasNSW=*/false);
// If all of the other operands were loop invariant, we are done.
if (Ops.size() == 1) return NewRec;
@@ -1696,10 +1749,13 @@
for (unsigned i = 0, e = Ops.size(); i != e; ++i)
ID.AddPointer(Ops[i]);
void *IP = 0;
- if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
- SCEVMulExpr *S = SCEVAllocator.Allocate<SCEVMulExpr>();
- new (S) SCEVMulExpr(ID, Ops);
- UniqueSCEVs.InsertNode(S, IP);
+ SCEVMulExpr *S =
+ static_cast<SCEVMulExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ S = SCEVAllocator.Allocate<SCEVMulExpr>();
+ new (S) SCEVMulExpr(ID, Ops);
+ UniqueSCEVs.InsertNode(S, IP);
+ }
if (HasNUW) S->setHasNoUnsignedWrap(true);
if (HasNSW) S->setHasNoSignedWrap(true);
return S;
@@ -1842,10 +1898,24 @@
return getAddRecExpr(Operands, L, HasNUW, HasNSW); // {X,+,0} --> X
}
+ // If HasNSW is true and all the operands are non-negative, infer HasNUW.
+ if (!HasNUW && HasNSW) {
+ bool All = true;
+ for (unsigned i = 0, e = Operands.size(); i != e; ++i)
+ if (!isKnownNonNegative(Operands[i])) {
+ All = false;
+ break;
+ }
+ if (All) HasNUW = true;
+ }
+
// Canonicalize nested AddRecs in by nesting them in order of loop depth.
if (const SCEVAddRecExpr *NestedAR = dyn_cast<SCEVAddRecExpr>(Operands[0])) {
const Loop *NestedLoop = NestedAR->getLoop();
- if (L->getLoopDepth() < NestedLoop->getLoopDepth()) {
+ if (L->contains(NestedLoop->getHeader()) ?
+ (L->getLoopDepth() < NestedLoop->getLoopDepth()) :
+ (!NestedLoop->contains(L->getHeader()) &&
+ DT->dominates(L->getHeader(), NestedLoop->getHeader()))) {
SmallVector<const SCEV *, 4> NestedOperands(NestedAR->op_begin(),
NestedAR->op_end());
Operands[0] = NestedAR->getStart();
@@ -1884,10 +1954,13 @@
ID.AddPointer(Operands[i]);
ID.AddPointer(L);
void *IP = 0;
- if (const SCEV *S = UniqueSCEVs.FindNodeOrInsertPos(ID, IP)) return S;
- SCEVAddRecExpr *S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
- new (S) SCEVAddRecExpr(ID, Operands, L);
- UniqueSCEVs.InsertNode(S, IP);
+ SCEVAddRecExpr *S =
+ static_cast<SCEVAddRecExpr *>(UniqueSCEVs.FindNodeOrInsertPos(ID, IP));
+ if (!S) {
+ S = SCEVAllocator.Allocate<SCEVAddRecExpr>();
+ new (S) SCEVAddRecExpr(ID, Operands, L);
+ UniqueSCEVs.InsertNode(S, IP);
+ }
if (HasNUW) S->setHasNoUnsignedWrap(true);
if (HasNSW) S->setHasNoSignedWrap(true);
return S;
@@ -2525,31 +2598,28 @@
if (Accum->isLoopInvariant(L) ||
(isa<SCEVAddRecExpr>(Accum) &&
cast<SCEVAddRecExpr>(Accum)->getLoop() == L)) {
+ bool HasNUW = false;
+ bool HasNSW = false;
+
+ // If the increment doesn't overflow, then neither the addrec nor
+ // the post-increment will overflow.
+ if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV)) {
+ if (OBO->hasNoUnsignedWrap())
+ HasNUW = true;
+ if (OBO->hasNoSignedWrap())
+ HasNSW = true;
+ }
+
const SCEV *StartVal =
getSCEV(PN->getIncomingValue(IncomingEdge));
- const SCEVAddRecExpr *PHISCEV =
- cast<SCEVAddRecExpr>(getAddRecExpr(StartVal, Accum, L));
+ const SCEV *PHISCEV =
+ getAddRecExpr(StartVal, Accum, L, HasNUW, HasNSW);
- // If the increment doesn't overflow, then neither the addrec nor the
- // post-increment will overflow.
- if (const AddOperator *OBO = dyn_cast<AddOperator>(BEValueV))
- if (OBO->getOperand(0) == PN &&
- getSCEV(OBO->getOperand(1)) ==
- PHISCEV->getStepRecurrence(*this)) {
- const SCEVAddRecExpr *PostInc = PHISCEV->getPostIncExpr(*this);
- if (OBO->hasNoUnsignedWrap()) {
- const_cast<SCEVAddRecExpr *>(PHISCEV)
- ->setHasNoUnsignedWrap(true);
- const_cast<SCEVAddRecExpr *>(PostInc)
- ->setHasNoUnsignedWrap(true);
- }
- if (OBO->hasNoSignedWrap()) {
- const_cast<SCEVAddRecExpr *>(PHISCEV)
- ->setHasNoSignedWrap(true);
- const_cast<SCEVAddRecExpr *>(PostInc)
- ->setHasNoSignedWrap(true);
- }
- }
+ // Since the no-wrap flags are on the increment, they apply to the
+ // post-incremented value as well.
+ if (Accum->isLoopInvariant(L))
+ (void)getAddRecExpr(getAddExpr(StartVal, Accum),
+ Accum, L, HasNUW, HasNSW);
// Okay, for the entire analysis of this edge we assumed the PHI
// to be symbolic. We now need to go back and purge all of the
@@ -2781,26 +2851,29 @@
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
const SCEV *T = getBackedgeTakenCount(AddRec->getLoop());
const SCEVConstant *Trip = dyn_cast<SCEVConstant>(T);
- if (!Trip) return FullSet;
+ ConstantRange ConservativeResult = FullSet;
+
+ // If there's no unsigned wrap, the value will never be less than its
+ // initial value.
+ if (AddRec->hasNoUnsignedWrap())
+ if (const SCEVConstant *C = dyn_cast<SCEVConstant>(AddRec->getStart()))
+ ConservativeResult =
+ ConstantRange(C->getValue()->getValue(),
+ APInt(getTypeSizeInBits(C->getType()), 0));
// TODO: non-affine addrec
- if (AddRec->isAffine()) {
+ if (Trip && AddRec->isAffine()) {
const Type *Ty = AddRec->getType();
const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) {
MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
const SCEV *Start = AddRec->getStart();
- const SCEV *Step = AddRec->getStepRecurrence(*this);
const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
// Check for overflow.
- // TODO: This is very conservative.
- if (!(Step->isOne() &&
- isKnownPredicate(ICmpInst::ICMP_ULT, Start, End)) &&
- !(Step->isAllOnesValue() &&
- isKnownPredicate(ICmpInst::ICMP_UGT, Start, End)))
- return FullSet;
+ if (!AddRec->hasNoUnsignedWrap())
+ return ConservativeResult;
ConstantRange StartRange = getUnsignedRange(Start);
ConstantRange EndRange = getUnsignedRange(End);
@@ -2809,10 +2882,12 @@
APInt Max = APIntOps::umax(StartRange.getUnsignedMax(),
EndRange.getUnsignedMax());
if (Min.isMinValue() && Max.isMaxValue())
- return FullSet;
+ return ConservativeResult;
return ConstantRange(Min, Max+1);
}
}
+
+ return ConservativeResult;
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
@@ -2891,26 +2966,39 @@
if (const SCEVAddRecExpr *AddRec = dyn_cast<SCEVAddRecExpr>(S)) {
const SCEV *T = getBackedgeTakenCount(AddRec->getLoop());
const SCEVConstant *Trip = dyn_cast<SCEVConstant>(T);
- if (!Trip) return FullSet;
+ ConstantRange ConservativeResult = FullSet;
+
+ // If there's no signed wrap, and all the operands have the same sign or
+ // zero, the value won't ever change sign.
+ if (AddRec->hasNoSignedWrap()) {
+ bool AllNonNeg = true;
+ bool AllNonPos = true;
+ for (unsigned i = 0, e = AddRec->getNumOperands(); i != e; ++i) {
+ if (!isKnownNonNegative(AddRec->getOperand(i))) AllNonNeg = false;
+ if (!isKnownNonPositive(AddRec->getOperand(i))) AllNonPos = false;
+ }
+ unsigned BitWidth = getTypeSizeInBits(AddRec->getType());
+ if (AllNonNeg)
+ ConservativeResult = ConstantRange(APInt(BitWidth, 0),
+ APInt::getSignedMinValue(BitWidth));
+ else if (AllNonPos)
+ ConservativeResult = ConstantRange(APInt::getSignedMinValue(BitWidth),
+ APInt(BitWidth, 1));
+ }
// TODO: non-affine addrec
- if (AddRec->isAffine()) {
+ if (Trip && AddRec->isAffine()) {
const Type *Ty = AddRec->getType();
const SCEV *MaxBECount = getMaxBackedgeTakenCount(AddRec->getLoop());
if (getTypeSizeInBits(MaxBECount->getType()) <= getTypeSizeInBits(Ty)) {
MaxBECount = getNoopOrZeroExtend(MaxBECount, Ty);
const SCEV *Start = AddRec->getStart();
- const SCEV *Step = AddRec->getStepRecurrence(*this);
const SCEV *End = AddRec->evaluateAtIteration(MaxBECount, *this);
// Check for overflow.
- // TODO: This is very conservative.
- if (!(Step->isOne() &&
- isKnownPredicate(ICmpInst::ICMP_SLT, Start, End)) &&
- !(Step->isAllOnesValue() &&
- isKnownPredicate(ICmpInst::ICMP_SGT, Start, End)))
- return FullSet;
+ if (!AddRec->hasNoSignedWrap())
+ return ConservativeResult;
ConstantRange StartRange = getSignedRange(Start);
ConstantRange EndRange = getSignedRange(End);
@@ -2919,15 +3007,19 @@
APInt Max = APIntOps::smax(StartRange.getSignedMax(),
EndRange.getSignedMax());
if (Min.isMinSignedValue() && Max.isMaxSignedValue())
- return FullSet;
+ return ConservativeResult;
return ConstantRange(Min, Max+1);
}
}
+
+ return ConservativeResult;
}
if (const SCEVUnknown *U = dyn_cast<SCEVUnknown>(S)) {
// For a SCEVUnknown, ask ValueTracking.
unsigned BitWidth = getTypeSizeInBits(U->getType());
+ if (!U->getValue()->getType()->isInteger() && !TD)
+ return FullSet;
unsigned NS = ComputeNumSignBits(U->getValue(), TD);
if (NS == 1)
return FullSet;