Extend getConstantTripCount to deal with a larger subset of loop bounds; make loop
unroll/unroll-and-jam more powerful; add additional affine expr builder methods
- use previously added analysis/simplification to infer trip counts that are a
multiple of the unroll factor, making loop unroll/unroll-and-jam more general.
- for loop unroll, support bounds that are single-result affine maps with the
same set of operands. For unknown loop bounds, loop unroll will now work as
long as the trip count can be determined to be a multiple of the unroll factor.
- extend getConstantTripCount to deal with single-result affine maps with the
same operands. Move it to mlir/Analysis/LoopAnalysis.cpp.
- add additional builder utility methods for affine expr arithmetic
(difference, mod/floordiv/ceildiv w.r.t. a positive constant). Simplify code to
use the utility methods.
- move affine analysis routines to AffineAnalysis.cpp/.h from
AffineStructures.cpp/.h.
- Rename LoopUnrollJam to LoopUnrollAndJam to match class name.
- add an additional simplification for simplifyFloorDiv, simplifyCeilDiv
- Rename AffineMap::getNumOperands() to getNumInputs(): an affine map by itself
does not have operands. Operands are passed to it through affine_apply, loop
bounds, if conditions, etc., and are stored in the latter.
This should be sufficiently powerful for now as far as unroll/unroll-and-jam go
for TPU code generation, and we can move on to other analyses/transformations.
Loop nests like these are now unrolled without any cleanup loop being generated.
for %i = 1 to 100 {
// unroll factor 4: no cleanup loop will be generated.
for %j = (d0) -> (d0) (%i) to (d0) -> (5*d0 + 3) (%i) {
%x = "foo"(%j) : (affineint) -> i32
}
}
for %i = 1 to 100 {
// unroll factor 4: no cleanup loop will be generated.
for %j = (d0) -> (d0) (%i) to (d0) -> (d0 - d0 mod 4 - 1) (%i) {
%y = "foo"(%j) : (affineint) -> i32
}
}
for %i = 1 to 100 {
for %j = (d0) -> (d0) (%i) to (d0) -> (d0 + 128) (%i) {
%x = "foo"() : () -> i32
}
}
TODO(bondhugula): extend this to LoopUnrollAndJam as well in the next CL (with minor
changes).
PiperOrigin-RevId: 212661212
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index bff76a8..6965137 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -20,291 +20,13 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
-
-#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/StandardOps.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
-/// Constructs an affine expression from a flat ArrayRef. If there are local
-/// identifiers (neither dimensional nor symbolic) that appear in the sum of
-/// products expression, 'localExprs' is expected to have the AffineExpr for it,
-/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
-/// [dims, symbols, locals, constant term].
-static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
- unsigned numSymbols,
- ArrayRef<AffineExpr *> localExprs,
- MLIRContext *context) {
- // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
- assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
- "unexpected number of local expressions");
-
- AffineExpr *expr = AffineConstantExpr::get(0, context);
- // Dimensions and symbols.
- for (unsigned j = 0; j < numDims + numSymbols; j++) {
- if (eq[j] != 0) {
- AffineExpr *id =
- j < numDims
- ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
- : AffineSymbolExpr::get(j - numDims, context);
- auto *term = AffineBinaryOpExpr::getMul(
- AffineConstantExpr::get(eq[j], context), id, context);
- expr = AffineBinaryOpExpr::getAdd(expr, term, context);
- }
- }
-
- // Local identifiers.
- for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
- if (eq[j] != 0) {
- auto *term = AffineBinaryOpExpr::getMul(
- AffineConstantExpr::get(eq[j], context),
- localExprs[j - numDims - numSymbols], context);
- expr = AffineBinaryOpExpr::getAdd(expr, term, context);
- }
- }
-
- // Constant term.
- unsigned constTerm = eq[eq.size() - 1];
- if (constTerm != 0)
- expr = AffineBinaryOpExpr::getAdd(
- expr, AffineConstantExpr::get(constTerm, context), context);
- return expr;
-}
-
-namespace {
-
-// This class is used to flatten a pure affine expression (AffineExpr *, which
-// is in a tree form) into a sum of products (w.r.t constants) when possible,
-// and in that process simplifying the expression. The simplification performed
-// includes the accumulation of contributions for each dimensional and symbolic
-// identifier together, the simplification of floordiv/ceildiv/mod exprssions
-// and other simplifications that in turn happen as a result. A simplification
-// that this flattening naturally performs is of simplifying the numerator and
-// denominator of floordiv/ceildiv, and folding a modulo expression to a zero,
-// if possible. Three examples are below:
-//
-// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
-// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
-// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
-//
-// For a modulo, floordiv, or a ceildiv expression, an additional identifier
-// (called a local identifier) is introduced to rewrite it as a sum of products
-// (w.r.t constants). For example, for the second example above, d0 % 4 is
-// replaced by d0 - 4*q with q being introduced: the expression then simplifies
-// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
-// zero. Note that an affine expression may not always be expressible in a sum
-// of products form due to the presence of modulo/floordiv/ceildiv expressions
-// that may not be eliminated after simplification; in such cases, the final
-// expression can be reconstructed by replacing the local identifier with its
-// explicit form stored in localExprs (note that the explicit form itself would
-// have been simplified and not necessarily the original form).
-//
-// This is a linear time post order walk for an affine expression that attempts
-// the above simplifications through visit methods, with partial results being
-// stored in 'operandExprStack'. When a parent expr is visited, the flattened
-// expressions corresponding to its two operands would already be on the stack -
-// the parent expr looks at the two flattened expressions and combines the two.
-// It pops off the operand expressions and pushes the combined result (although
-// this is done in-place on its LHS operand expr. When the walk is completed,
-// the flattened form of the top-level expression would be left on the stack.
-//
-class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
-public:
- // Flattend expression layout: [dims, symbols, locals, constant]
- // Stack that holds the LHS and RHS operands while visiting a binary op expr.
- // In future, consider adding a prepass to determine how big the SmallVector's
- // will be, and linearize this to std::vector<int64_t> to prevent
- // SmallVector moves on re-allocation.
- std::vector<SmallVector<int64_t, 32>> operandExprStack;
-
- inline unsigned getNumCols() const {
- return numDims + numSymbols + numLocals + 1;
- }
-
- unsigned numDims;
- unsigned numSymbols;
- // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
- // expressions that could not be simplified.
- unsigned numLocals;
- // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
- // which new identifiers were introduced; if the latter do not get canceled
- // out, these expressions are needed to reconstruct the AffineExpr * / tree
- // form. Note that these expressions themselves would have been simplified
- // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
- // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
- // would be the local expression stored for q.
- SmallVector<AffineExpr *, 4> localExprs;
- MLIRContext *context;
-
- AffineExprFlattener(unsigned numDims, unsigned numSymbols,
- MLIRContext *context)
- : numDims(numDims), numSymbols(numSymbols), numLocals(0),
- context(context) {
- operandExprStack.reserve(8);
- }
-
- void visitMulExpr(AffineBinaryOpExpr *expr) {
- assert(operandExprStack.size() >= 2);
- // This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS()));
- // Get the RHS constant.
- auto rhsConst = operandExprStack.back()[getConstantIndex()];
- operandExprStack.pop_back();
- // Update the LHS in place instead of pop and push.
- auto &lhs = operandExprStack.back();
- for (unsigned i = 0, e = lhs.size(); i < e; i++) {
- lhs[i] *= rhsConst;
- }
- }
-
- void visitAddExpr(AffineBinaryOpExpr *expr) {
- assert(operandExprStack.size() >= 2);
- const auto &rhs = operandExprStack.back();
- auto &lhs = operandExprStack[operandExprStack.size() - 2];
- assert(lhs.size() == rhs.size());
- // Update the LHS in place.
- for (unsigned i = 0; i < rhs.size(); i++) {
- lhs[i] += rhs[i];
- }
- // Pop off the RHS.
- operandExprStack.pop_back();
- }
-
- void visitModExpr(AffineBinaryOpExpr *expr) {
- assert(operandExprStack.size() >= 2);
- // This is a pure affine expr; the RHS will be a constant.
- assert(isa<AffineConstantExpr>(expr->getRHS()));
- auto rhsConst = operandExprStack.back()[getConstantIndex()];
- operandExprStack.pop_back();
- auto &lhs = operandExprStack.back();
- // TODO(bondhugula): handle modulo by zero case when this issue is fixed
- // at the other places in the IR.
- assert(rhsConst != 0 && "RHS constant can't be zero");
-
- // Check if the LHS expression is a multiple of modulo factor.
- unsigned i;
- for (i = 0; i < lhs.size(); i++)
- if (lhs[i] % rhsConst != 0)
- break;
- // If yes, modulo expression here simplifies to zero.
- if (i == lhs.size()) {
- lhs.assign(lhs.size(), 0);
- return;
- }
-
- // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
- // q * expr2) where q is the existential quantifier introduced.
- addLocalId(AffineBinaryOpExpr::get(
- AffineExpr::Kind::FloorDiv,
- toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
- AffineConstantExpr::get(rhsConst, context), context));
- lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
- }
- void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
- visitDivExpr(expr, /*isCeil=*/true);
- }
- void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
- visitDivExpr(expr, /*isCeil=*/false);
- }
- void visitDimExpr(AffineDimExpr *expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- eq[getDimStartIndex() + expr->getPosition()] = 1;
- }
- void visitSymbolExpr(AffineSymbolExpr *expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- eq[getSymbolStartIndex() + expr->getPosition()] = 1;
- }
- void visitConstantExpr(AffineConstantExpr *expr) {
- operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
- auto &eq = operandExprStack.back();
- eq[getConstantIndex()] = expr->getValue();
- }
-
-private:
- void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
- assert(operandExprStack.size() >= 2);
- assert(isa<AffineConstantExpr>(expr->getRHS()));
- // This is a pure affine expr; the RHS is a positive constant.
- auto rhsConst = operandExprStack.back()[getConstantIndex()];
- // TODO(bondhugula): handle division by zero at the same time the issue is
- // fixed at other places.
- assert(rhsConst != 0 && "RHS constant can't be zero");
- operandExprStack.pop_back();
- auto &lhs = operandExprStack.back();
-
- // Simplify the floordiv, ceildiv if possible by canceling out the greatest
- // common divisors of the numerator and denominator.
- uint64_t gcd = std::abs(rhsConst);
- for (unsigned i = 0; i < lhs.size(); i++)
- gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
- // Simplify the numerator and the denominator.
- if (gcd != 1) {
- for (unsigned i = 0; i < lhs.size(); i++)
- lhs[i] = lhs[i] / gcd;
- }
- int64_t denominator = rhsConst / gcd;
- // If the denominator becomes 1, the updated LHS is the result. (The
- // denominator can't be negative since rhsConst is positive).
- if (denominator == 1)
- return;
-
- // If the denominator cannot be simplified to one, we will have to retain
- // the ceil/floor expr (simplified up until here). Add an existential
- // quantifier to express its result, i.e., expr1 div expr2 is replaced
- // by a new identifier, q.
- auto divKind =
- isCeil ? AffineExpr::Kind::CeilDiv : AffineExpr::Kind::FloorDiv;
- addLocalId(AffineBinaryOpExpr::get(
- divKind, toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
- AffineConstantExpr::get(denominator, context), context));
- lhs.assign(lhs.size(), 0);
- lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
- }
-
- // Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
- // expr). localExpr is the simplified tree expression (AffineExpr *)
- // corresponding to the quantifier.
- void addLocalId(AffineExpr *localExpr) {
- for (auto &subExpr : operandExprStack) {
- subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
- }
- localExprs.push_back(localExpr);
- numLocals++;
- }
-
- inline unsigned getConstantIndex() const { return getNumCols() - 1; }
- inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
- inline unsigned getSymbolStartIndex() const { return numDims; }
- inline unsigned getDimStartIndex() const { return 0; }
-};
-
-} // end anonymous namespace
-
-AffineExpr *mlir::simplifyAffineExpr(AffineExpr *expr, unsigned numDims,
- unsigned numSymbols,
- MLIRContext *context) {
- // TODO(bondhugula): only pure affine for now. The simplification here can be
- // extended to semi-affine maps as well.
- if (!expr->isPureAffine())
- return nullptr;
-
- AffineExprFlattener flattener(numDims, numSymbols, context);
- flattener.walkPostOrder(expr);
- ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
- auto *simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
- flattener.localExprs, context);
- flattener.operandExprStack.pop_back();
- assert(flattener.operandExprStack.empty());
- if (simplifiedExpr == expr)
- return nullptr;
- return simplifiedExpr;
-}
-
MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
: numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
context(context) {