[SCEV] Generalize folding of trunc(x)+n*trunc(y) into folding m*trunc(x)+n*trunc(y)
Summary:
A SCEV such as:
{%v2,+,((-1 * (trunc i64 (-1 * %v1) to i32)) + (-1 * (trunc i64 %v1 to i32)))}<%loop>
can be folded into simply {%v2,+,0}, because its step is
trunc((-1 * (-1 * %v1)) + (-1 * %v1)) = trunc(0) = 0. However, the current code in
ScalarEvolution::getAddExpr() will not try to apply the simplification
m*trunc(x)+n*trunc(y) -> trunc(trunc(m)*x+trunc(n)*y), because it only fires when
the first term of the sum is a bare (non-multiplied) trunc.
This patch generalizes that code so that the fold is also attempted when the
first term is itself multiplied, i.e. for sums of the form m*trunc(x)+n*trunc(y).
Reviewers: sanjoy
Reviewed By: sanjoy
Subscribers: llvm-commits
Differential Revision: https://reviews.llvm.org/D37888
llvm-svn: 313988
diff --git a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
index f1a5a1a..4bede8b 100644
--- a/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
+++ b/llvm/unittests/Analysis/ScalarEvolutionTest.cpp
@@ -1009,5 +1009,37 @@
auto Result = SE.createAddRecFromPHIWithCasts(cast<SCEVUnknown>(Expr));
}
+TEST_F(ScalarEvolutionsTest, SCEVFoldSumOfTruncs) {
+ // Verify that the following SCEV gets folded to a zero:
+ // (-1 * (trunc i64 (-1 * %0) to i32)) + (-1 * (trunc i64 %0 to i32))
+ Type *ArgTy = Type::getInt64Ty(Context);
+ Type *Int32Ty = Type::getInt32Ty(Context);
+ SmallVector<Type *, 1> Types;
+ Types.push_back(ArgTy);
+ FunctionType *FTy = FunctionType::get(Type::getVoidTy(Context), Types, false);
+ Function *F = cast<Function>(M.getOrInsertFunction("f", FTy));
+ BasicBlock *BB = BasicBlock::Create(Context, "entry", F);
+ ReturnInst::Create(Context, nullptr, BB);
+
+ ScalarEvolution SE = buildSE(*F);
+
+ auto *Arg = &*(F->arg_begin());
+ const auto *ArgSCEV = SE.getSCEV(Arg);
+
+ // Build the two addends: A = -1 * trunc(-1 * %0) and B = -1 * trunc(%0).
+ const auto *A0 = SE.getNegativeSCEV(ArgSCEV);
+ const auto *A1 = SE.getTruncateExpr(A0, Int32Ty);
+ const auto *A = SE.getNegativeSCEV(A1);
+
+ const auto *B0 = SE.getTruncateExpr(ArgSCEV, Int32Ty);
+ const auto *B = SE.getNegativeSCEV(B0);
+
+ const auto *Expr = SE.getAddExpr(A, B);
+ // Verify that the SCEV was folded to 0: with m = n = -1 the fold gives
+ // trunc((-1 * (-1 * %0)) + (-1 * %0)) = trunc(0), i.e. the i32 constant 0.
+ const auto *ZeroConst = SE.getConstant(Int32Ty, 0);
+ EXPECT_EQ(Expr, ZeroConst);
+}
+
} // end anonymous namespace
} // end namespace llvm