[SCEV] Generalize folding of trunc(x)+n*trunc(y) into folding m*trunc(x)+n*trunc(y)
Summary:
A SCEV such as:
{%v2,+,((-1 * (trunc i64 (-1 * %v1) to i32)) + (-1 * (trunc i64 %v1 to i32)))}<%loop>
can be folded into, simply, {%v2,+,0}, since trunc(-1 * %v1) == -1 * trunc(%v1) makes the
two step terms cancel. However, the current code in ::getAddExpr()
will not try to apply the simplification m*trunc(x)+n*trunc(y) -> trunc(trunc(m)*x+trunc(n)*y)
because it only keys off the first term of the add being a bare (non-multiplied) trunc.
This patch generalizes this code to try a more generic fold of these trunc expressions:
the fold is now also attempted when the first term is a multiply whose last operand is a trunc.
Reviewers: sanjoy
Reviewed By: sanjoy
Subscribers: llvm-commits
Differential Revision: https://reviews.llvm.org/D37888
llvm-svn: 313988
diff --git a/llvm/lib/Analysis/ScalarEvolution.cpp b/llvm/lib/Analysis/ScalarEvolution.cpp
index 640d80a..974995b 100644
--- a/llvm/lib/Analysis/ScalarEvolution.cpp
+++ b/llvm/lib/Analysis/ScalarEvolution.cpp
@@ -2338,12 +2338,23 @@
// Check for truncates. If all the operands are truncated from the same
// type, see if factoring out the truncate would permit the result to be
- // folded. eg., trunc(x) + m*trunc(n) --> trunc(x + trunc(m)*n)
+ // folded. e.g., n*trunc(x) + m*trunc(y) --> trunc(trunc(n)*x + trunc(m)*y)
// if the contents of the resulting outer trunc fold to something simple.
- for (; Idx < Ops.size() && isa<SCEVTruncateExpr>(Ops[Idx]); ++Idx) {
- const SCEVTruncateExpr *Trunc = cast<SCEVTruncateExpr>(Ops[Idx]);
- Type *DstType = Trunc->getType();
- Type *SrcType = Trunc->getOperand()->getType();
+ auto FindTruncSrcType = [&]() -> Type * {
+ // We're ultimately looking to fold an add of truncs and of muls of only
+ // constants and truncs. This only inspects Ops[Idx] (a mul via its last
+ // operand): if that isn't a trunc we bail and return nullptr; otherwise
+ // we return the trunc's source type. The other operands are checked later.
+ if (auto *T = dyn_cast<SCEVTruncateExpr>(Ops[Idx]))
+ return T->getOperand()->getType();
+ if (const auto *Mul = dyn_cast<SCEVMulExpr>(Ops[Idx])) {
+ const auto *LastOp = Mul->getOperand(Mul->getNumOperands() - 1);
+ if (const auto *T = dyn_cast<SCEVTruncateExpr>(LastOp))
+ return T->getOperand()->getType();
+ }
+ return nullptr;
+ };
+ if (auto *SrcType = FindTruncSrcType()) {
SmallVector<const SCEV *, 8> LargeOps;
bool Ok = true;
// Check all the operands to see if they can be represented in the
@@ -2386,7 +2397,7 @@
const SCEV *Fold = getAddExpr(LargeOps, Flags, Depth + 1);
// If it folds to something simple, use it. Otherwise, don't.
if (isa<SCEVConstant>(Fold) || isa<SCEVUnknown>(Fold))
- return getTruncateExpr(Fold, DstType);
+ return getTruncateExpr(Fold, Ty);
}
}