Separate more constant factors of parameters
So far, we separated constant factors from multiplications only when
they appeared at the outermost level of a parameter SCEV. Now, we also
separate constant factors from a parameter SCEV whose outermost
expression is a SCEVAddRecExpr with a zero start value; in that case,
the constant factor is extracted from the AddRec's step recurrence.
In addition, a parameter SCEV that is itself a constant is now returned
entirely as the constant factor. With the changes to the SCEVAffinator,
we can now improve the extractConstantFactor(...) function freely,
without worrying about any other part of the code. Thus, if needed, we
can implement a more comprehensive extractConstantFactor(...) function
that traverses the whole SCEV instead of looking only at the outermost
level.
Four test cases were affected: one changed only slightly, and the
other three were simplified.
llvm-svn: 260859
diff --git a/polly/lib/Support/SCEVValidator.cpp b/polly/lib/Support/SCEVValidator.cpp
index e22d3df..c32b2ee 100644
--- a/polly/lib/Support/SCEVValidator.cpp
+++ b/polly/lib/Support/SCEVValidator.cpp
@@ -640,19 +640,35 @@
return Result.getParameters();
}
-std::pair<const SCEV *, const SCEV *>
+std::pair<const SCEVConstant *, const SCEV *>
extractConstantFactor(const SCEV *S, ScalarEvolution &SE) {
- const SCEV *LeftOver = SE.getConstant(S->getType(), 1);
- const SCEV *ConstPart = SE.getConstant(S->getType(), 1);
+ auto *LeftOver = SE.getConstant(S->getType(), 1);
+ auto *ConstPart = cast<SCEVConstant>(SE.getConstant(S->getType(), 1));
- const SCEVMulExpr *M = dyn_cast<SCEVMulExpr>(S);
- if (!M)
+ if (auto *Constant = dyn_cast<SCEVConstant>(S))
+ return std::make_pair(Constant, LeftOver);
+
+ auto *AddRec = dyn_cast<SCEVAddRecExpr>(S);
+ if (AddRec) {
+ auto *StartExpr = AddRec->getStart();
+ if (StartExpr->isZero()) {
+ auto StepPair = extractConstantFactor(AddRec->getStepRecurrence(SE), SE);
+ auto *LeftOverAddRec =
+ SE.getAddRecExpr(StartExpr, StepPair.second, AddRec->getLoop(),
+ AddRec->getNoWrapFlags());
+ return std::make_pair(StepPair.first, LeftOverAddRec);
+ }
+ return std::make_pair(ConstPart, S);
+ }
+
+ auto *Mul = dyn_cast<SCEVMulExpr>(S);
+ if (!Mul)
return std::make_pair(ConstPart, S);
- for (const SCEV *Op : M->operands())
+ for (auto *Op : Mul->operands())
if (isa<SCEVConstant>(Op))
- ConstPart = SE.getMulExpr(ConstPart, Op);
+ ConstPart = cast<SCEVConstant>(SE.getMulExpr(ConstPart, Op));
else
LeftOver = SE.getMulExpr(LeftOver, Op);