Complete AffineExprFlattener based simplification for floordiv/ceildiv.
- handle floordiv/ceildiv in AffineExprFlattener; update the simplification to
work even if mod/floordiv/ceildiv expressions appearing in the tree can't be eliminated.
- refactor the flattening / analysis to move it out of lib/Transforms/
- fix MutableAffineMap::isMultipleOf
- add AffineBinaryOpExpr:getAdd/getMul/... utility methods
PiperOrigin-RevId: 211540536
diff --git a/lib/Transforms/SimplifyAffineExpr.cpp b/lib/Transforms/SimplifyAffineExpr.cpp
index 3c72887..3abc63a 100644
--- a/lib/Transforms/SimplifyAffineExpr.cpp
+++ b/lib/Transforms/SimplifyAffineExpr.cpp
@@ -20,7 +20,6 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
-#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/Attributes.h"
#include "mlir/IR/StmtVisitor.h"
@@ -33,8 +32,8 @@
namespace {
-/// Simplify all affine expressions appearing in the operation statements of the
-/// MLFunction.
+/// Simplifies all affine expressions appearing in the operation statements of
+/// the MLFunction. This is mainly to test the simplifyAffineExpr method.
// TODO(someone): Gradually, extend this to all affine map references found in
// ML functions and CFG functions.
struct SimplifyAffineExpr : public FunctionPass {
@@ -46,125 +45,6 @@
void runOnCFGFunction(CFGFunction *f) {}
};
-// This class is used to flatten a pure affine expression into a sum of products
-// (w.r.t constants) when possible, and in that process accumulating
-// contributions for each dimensional and symbolic identifier together. Note
-// that an affine expression may not always be expressible that way due to the
-// preesnce of modulo, floordiv, and ceildiv expressions. A simplification that
-// this flattening naturally performs is to fold a modulo expression to a zero,
-// if possible. Two examples are below:
-//
-// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
-// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
-//
-// For modulo and floordiv expressions, an additional variable is introduced to
-// rewrite it as a sum of products (w.r.t constants). For example, for the
-// second example above, d0 % 4 is replaced by d0 - 4*q with q being introduced:
-// the expression simplifies to:
-// (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to zero.
-//
-// This is a linear time post order walk for an affine expression that attempts
-// the above simplifications through visit methods, with partial results being
-// stored in 'operandExprStack'. When a parent expr is visited, the flattened
-// expressions corresponding to its two operands would already be on the stack -
-// the parent expr looks at the two flattened expressions and combines the two.
-// It pops off the operand expressions and pushes the combined result (although
-// this is done in-place on its LHS operand expr. When the walk is completed,
-// the flattened form of the top-level expression would be left on the stack.
-//
-class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
-public:
- std::vector<SmallVector<int64_t, 32>> operandExprStack;
-
- // The layout of the flattened expressions is dimensions, symbols, locals,
- // and constant term.
- unsigned getNumCols() const { return numDims + numSymbols + numLocals + 1; }
-
- AffineExprFlattener(unsigned numDims, unsigned numSymbols)
- : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
-
- void visitMulExpr(AffineBinaryOpExpr *expr) {
- assert(expr->isPureAffine());
- // Get the RHS constant.
- auto rhsConst = operandExprStack.back()[getNumCols() - 1];
- operandExprStack.pop_back();
- // Update the LHS in place instead of pop and push.
- auto &lhs = operandExprStack.back();
- for (unsigned i = 0, e = lhs.size(); i < e; i++) {
- lhs[i] *= rhsConst;
- }
- }
- void visitAddExpr(AffineBinaryOpExpr *expr) {
- const auto &rhs = operandExprStack.back();
- auto &lhs = operandExprStack[operandExprStack.size() - 2];
- assert(lhs.size() == rhs.size());
- // Update the LHS in place.
- for (unsigned i = 0; i < rhs.size(); i++) {
- lhs[i] += rhs[i];
- }
- // Pop off the RHS.
- operandExprStack.pop_back();
- }
- void visitModExpr(AffineBinaryOpExpr *expr) {
- assert(expr->isPureAffine());
- // This is a pure affine expr; the RHS is a constant.
- auto rhsConst = operandExprStack.back()[getNumCols() - 1];
- operandExprStack.pop_back();
- auto &lhs = operandExprStack.back();
- assert(rhsConst != 0 && "RHS constant can't be zero");
- unsigned i;
- for (i = 0; i < lhs.size(); i++)
- if (lhs[i] % rhsConst != 0)
- break;
- if (i == lhs.size()) {
- // The modulo expression here simplifies to zero.
- lhs.assign(lhs.size(), 0);
- return;
- }
- // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
- // q * expr2) where q is the existential quantifier introduced.
- addExistentialQuantifier();
- lhs = operandExprStack.back();
- lhs[numDims + numSymbols + numLocals - 1] = -rhsConst;
- }
- void visitConstantExpr(AffineConstantExpr *expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- eq[getNumCols() - 1] = expr->getValue();
- }
- void visitDimExpr(AffineDimExpr *expr) {
- SmallVector<int64_t, 32> eq(getNumCols(), 0);
- eq[expr->getPosition()] = 1;
- operandExprStack.push_back(eq);
- }
- void visitSymbolExpr(AffineSymbolExpr *expr) {
- SmallVector<int64_t, 32> eq(getNumCols(), 0);
- eq[numDims + expr->getPosition()] = 1;
- operandExprStack.push_back(eq);
- }
- void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
- // TODO(bondhugula): handle ceildiv as well; won't simplify further through
- // this analysis but will be handled (rest of the expr will simplify).
- report_fatal_error("ceildiv expr simplification not supported here");
- }
- void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
- // TODO(bondhugula): handle ceildiv as well; won't simplify further through
- // this analysis but will be handled (rest of the expr will simplify).
- report_fatal_error("floordiv expr simplification unimplemented");
- }
- // Add an existential quantifier (used to flatten a mod or a floordiv expr).
- void addExistentialQuantifier() {
- for (auto &subExpr : operandExprStack) {
- subExpr.insert(subExpr.begin() + numDims + numSymbols + numLocals, 0);
- }
- numLocals++;
- }
-
- unsigned numDims;
- unsigned numSymbols;
- unsigned numLocals;
-};
-
} // end anonymous namespace
FunctionPass *mlir::createSimplifyAffineExprPass() {
@@ -195,58 +75,3 @@
MapSimplifier v(f->getContext());
v.walkPostOrder(f);
}
-
-/// Get an affine expression from a flat ArrayRef. If there are local variables
-/// (existential quantifiers introduced during the flattening) that appear in
-/// the sum of products expression, we can't readily express it as an affine
-/// expression of dimension and symbol id's; return nullptr in such cases.
-static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
- unsigned numSymbols, MLIRContext *context) {
- // Check if any local variable has a non-zero coefficient.
- for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
- if (eq[j] != 0)
- return nullptr;
- }
-
- AffineExpr *expr = AffineConstantExpr::get(0, context);
- for (unsigned j = 0; j < numDims + numSymbols; j++) {
- if (eq[j] != 0) {
- AffineExpr *id =
- j < numDims
- ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
- : AffineSymbolExpr::get(j - numDims, context);
- expr = AffineBinaryOpExpr::get(
- AffineExpr::Kind::Add, expr,
- AffineBinaryOpExpr::get(AffineExpr::Kind::Mul,
- AffineConstantExpr::get(eq[j], context), id,
- context),
- context);
- }
- }
- unsigned constTerm = eq[eq.size() - 1];
- if (constTerm != 0)
- expr = AffineBinaryOpExpr::get(AffineExpr::Kind::Add, expr,
- AffineConstantExpr::get(constTerm, context),
- context);
- return expr;
-}
-
-// Simplify the result affine expressions of this map. The expressions have to
-// be pure for the simplification implemented.
-void MutableAffineMap::simplify() {
- // Simplify each of the results if possible.
- for (unsigned i = 0, e = getNumResults(); i < e; i++) {
- AffineExpr *result = getResult(i);
- if (!result->isPureAffine())
- continue;
-
- AffineExprFlattener flattener(numDims, numSymbols);
- flattener.walkPostOrder(result);
- const auto &flattenedExpr = flattener.operandExprStack.back();
- auto *expr = toAffineExpr(flattenedExpr, numDims, numSymbols, context);
- if (expr)
- results[i] = expr;
- flattener.operandExprStack.pop_back();
- assert(flattener.operandExprStack.empty());
- }
-}