[SCEV] Do not visit nodes twice in containsConstantSomewhere
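This patch reworks the function that searches for constants in chains of SCEV
Add and Mul expressions so that it no longer visits a node more than once (it
now uses SCEVTraversal), and renames the function so that its name reflects
that it only looks through Add and Mul expressions. It also adds a TODO in
getMulExpr noting cases where distributing a constant over an Add is not
profitable.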
Differential Revision: https://reviews.llvm.org/D35931
llvm-svn: 309367
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 4d7b59c..ecd124a 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2676,20 +2676,23 @@
-/// Determine if any of the operands in this SCEV are a constant or if
-/// any of the add or multiply expressions in this SCEV contain a constant.
-static bool containsConstantSomewhere(const SCEV *StartExpr) {
- SmallVector<const SCEV *, 4> Ops;
- Ops.push_back(StartExpr);
- while (!Ops.empty()) {
- const SCEV *CurrentExpr = Ops.pop_back_val();
- if (isa<SCEVConstant>(*CurrentExpr))
- return true;
+/// Return true if \p StartExpr is a constant or if a constant is reachable
+/// from it by descending only through the operands of Add and Mul expressions.
+static bool containsConstantInAddMulChain(const SCEV *StartExpr) {
+ struct FindConstantInAddMulChain {
+ bool FoundConstant = false;
- if (isa<SCEVAddExpr>(*CurrentExpr) || isa<SCEVMulExpr>(*CurrentExpr)) {
- const auto *CurrentNAry = cast<SCEVNAryExpr>(CurrentExpr);
- Ops.append(CurrentNAry->op_begin(), CurrentNAry->op_end());
+ bool follow(const SCEV *S) {
+ FoundConstant |= isa<SCEVConstant>(S);
+ return isa<SCEVAddExpr>(S) || isa<SCEVMulExpr>(S);
}
- }
- return false;
+ bool isDone() const {
+ return FoundConstant;
+ }
+ };
+
+ FindConstantInAddMulChain F;
+ SCEVTraversal<FindConstantInAddMulChain> ST(F);
+ ST.visitAll(StartExpr);
+ return F.FoundConstant;
}
/// Get a canonical multiply expression, or something simpler if possible.
@@ -2726,7 +2729,11 @@
// If any of Add's ops are Adds or Muls with a constant,
// apply this transformation as well.
if (Add->getNumOperands() == 2)
- if (containsConstantSomewhere(Add))
+ // TODO: There are some cases where this transformation is not
+ // profitable, for example:
+ // Add = (C0 + X) * Y + Z.
+ // Maybe the scope of this transformation should be narrowed down.
+ if (containsConstantInAddMulChain(Add))
return getAddExpr(getMulExpr(LHSC, Add->getOperand(0),
SCEV::FlagAnyWrap, Depth + 1),
getMulExpr(LHSC, Add->getOperand(1),