Affine expression analysis and simplification.
Outside of IR/
- simplify a MutableAffineMap by flattening the affine expressions
- add a simplify affine expression pass that uses this analysis
- update the FlatAffineConstraints API (to be used in the next CL)
In IR:
- add isMultipleOf and getKnownGcd for AffineExpr, and make the in-IR
simplification of simplifyMod simpler and more powerful.
- rename the AffineExpr visitor methods to distinguish between visiting and
walking, and to simplify API names based on context.
The next CL will use some of these for the loop unrolling/unroll-jam to make
the detection of whether a cleanup loop is needed more powerful/non-trivial.
A future CL will finally move this simplification to FlatAffineConstraints to
make it more powerful. For example, currently, even if a single mod expr
appearing in one part of the expression tree can't be simplified, the whole
expression won't be simplified.
PiperOrigin-RevId: 211012256
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 6d28260..ac6e728 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -41,23 +41,41 @@
/// A mutable affine map. Its affine expressions are however unique.
struct MutableAffineMap {
public:
- explicit MutableAffineMap(AffineMap *map);
+ MutableAffineMap(AffineMap *map, MLIRContext *context);
AffineExpr *getResult(unsigned idx) const { return results[idx]; }
unsigned getNumResults() const { return results.size(); }
+ unsigned getNumDims() const { return numDims; }
+ unsigned getNumSymbols() const { return numSymbols; }
+ /// Returns true if the idx'th result expression is a multiple of factor.
+ bool isMultipleOf(unsigned idx, int64_t factor) const;
+
+ /// Simplify the (result) expressions in this map using analysis (used by
+ /// -simplify-affine-expr pass).
+ void simplify();
+ /// Get the AffineMap corresponding to this MutableAffineMap. Note that an
+ /// AffineMap * will be uniqued and stored in context, while a mutable one
+ /// isn't.
+ AffineMap *getAffineMap();
private:
+ // Same meaning as AffineMap's fields.
SmallVector<AffineExpr *, 8> results;
SmallVector<AffineExpr *, 8> rangeSizes;
+ unsigned numDims;
+ unsigned numSymbols;
+ /// A pointer to the IR's context to store all newly created AffineExpr's.
+ MLIRContext *context;
};
/// A mutable integer set. Its affine expressions are however unique.
struct MutableIntegerSet {
public:
- explicit MutableIntegerSet(IntegerSet *set);
+ MutableIntegerSet(IntegerSet *set, MLIRContext *context);
/// Create a universal set (no constraints).
- explicit MutableIntegerSet(unsigned numDims, unsigned numSymbols);
+ MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+ MLIRContext *context);
unsigned getNumDims() const { return numDims; }
unsigned getNumSymbols() const { return numSymbols; }
@@ -74,6 +92,8 @@
SmallVector<AffineExpr *, 8> constraints;
SmallVector<bool, 8> eqFlags;
+ /// A pointer to the IR's context to store all newly created AffineExpr's.
+ MLIRContext *context;
};
/// An AffineValueMap is an affine map plus its ML value operands and
@@ -89,9 +109,9 @@
// TODO(bondhugula): Some of these classes could go into separate files.
class AffineValueMap {
public:
- explicit AffineValueMap(const AffineApplyOp &op);
- explicit AffineValueMap(const AffineBound &bound);
- explicit AffineValueMap(AffineMap *map);
+ AffineValueMap(const AffineApplyOp &op, MLIRContext *context);
+ AffineValueMap(const AffineBound &bound, MLIRContext *context);
+ AffineValueMap(AffineMap *map, MLIRContext *context);
~AffineValueMap();
@@ -110,7 +130,7 @@
/// Return true if the idx^th result can be proved to be a multiple of
/// 'factor', false otherwise.
- bool isMultipleOf(unsigned idx, int64_t factor) const;
+ inline bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Return true if the result at 'idx' is a constant, false
/// otherwise.
@@ -128,8 +148,6 @@
SmallVector<MLValue *, 4> operands;
/// The SSA results binding to the results of 'map'.
SmallVector<MLValue *, 4> results;
- /// A pointer to the IR's context to store all newly created AffineExpr's.
- MLIRContext *context;
};
/// An IntegerValueSet is an integer set plus its operands.
@@ -155,8 +173,6 @@
MutableIntegerSet set;
/// The SSA operands binding to the dim's and symbols of 'set'.
SmallVector<MLValue *, 4> operands;
- /// A pointer to the IR's context to store all newly created AffineExpr's.
- MLIRContext *context;
};
/// A flat list of affine equalities and inequalities in the form.
@@ -190,8 +206,7 @@
/// constraints and identifiers..
FlatAffineConstraints(unsigned numReservedInequalities,
unsigned numReservedEqualities, unsigned numReservedIds)
- : numEqualities(0), numInequalities(0),
- numReservedEqualities(numReservedEqualities),
+ : numReservedEqualities(numReservedEqualities),
numReservedInequalities(numReservedInequalities),
numReservedIds(numReservedIds) {
equalities.reserve(numReservedIds * numReservedEqualities);
@@ -208,23 +223,50 @@
/// Create an affine constraint system from an IntegerValueSet.
// TODO(bondhugula)
- FlatAffineConstraints(const IntegerValueSet &set);
+ explicit FlatAffineConstraints(const IntegerValueSet &set);
FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
const IntegerSet &set);
+ FlatAffineConstraints(const MutableAffineMap &map);
+
~FlatAffineConstraints() {}
- inline int64_t atEq(unsigned i, unsigned j) {
- return equalities[i * numIds + j];
+ inline int64_t atEq(unsigned i, unsigned j) const {
+ return equalities[i * (numIds + 1) + j];
}
- inline int64_t atIneq(unsigned i, unsigned j) {
- return inequalities[i * numIds + j];
+ inline int64_t &atEq(unsigned i, unsigned j) {
+ return equalities[i * (numIds + 1) + j];
}
- unsigned getNumEqualities() const { return equalities.size(); }
- unsigned getNumInequalities() const { return inequalities.size(); }
+ inline int64_t atIneq(unsigned i, unsigned j) const {
+ return inequalities[i * (numIds + 1) + j];
+ }
+
+ inline int64_t &atIneq(unsigned i, unsigned j) {
+ return inequalities[i * (numIds + 1) + j];
+ }
+
+ inline unsigned getNumCols() const { return numIds + 1; }
+
+ inline unsigned getNumEqualities() const {
+ return equalities.size() / getNumCols();
+ }
+
+ inline unsigned getNumInequalities() const {
+ return inequalities.size() / getNumCols();
+ }
+
+ ArrayRef<int64_t> getEquality(unsigned idx) {
+ return ArrayRef<int64_t>(&equalities[idx * getNumCols()], getNumCols());
+ }
+
+ ArrayRef<int64_t> getInequality(unsigned idx) {
+ return ArrayRef<int64_t>(&inequalities[idx * getNumCols()], getNumCols());
+ }
+
+ AffineExpr *toAffineExpr(unsigned idx, MLIRContext *context);
void addInequality(ArrayRef<int64_t> inEq);
void addEquality(ArrayRef<int64_t> eq);
@@ -239,11 +281,19 @@
void removeEquality(unsigned pos);
void removeInequality(unsigned pos);
- unsigned getNumConstraints() const { return numEqualities + numInequalities; }
- unsigned getNumIds() const { return numIds; }
- unsigned getNumDimIds() const { return numDims; }
- unsigned getNumSymbolIds() const { return numSymbols; }
- unsigned getNumLocalIds() const { return numDims - numSymbols - numDims; }
+ unsigned getNumConstraints() const {
+ return getNumEqualities() + getNumInequalities(); // rows, not coefficients
+ }
+ inline unsigned getNumIds() const { return numIds; }
+ inline unsigned getNumResultDimIds() const { return numResultDims; }
+ inline unsigned getNumDimIds() const { return numDims; }
+ inline unsigned getNumSymbolIds() const { return numSymbols; }
+ inline unsigned getNumLocalIds() const {
+ return numIds - numResultDims - numDims - numSymbols;
+ }
+
+ void print(raw_ostream &os) const;
+ void dump() const;
private:
/// Coefficients of affine equalities (in == 0 form).
@@ -252,12 +302,6 @@
/// Coefficients of affine inequalities (in >= 0 form).
SmallVector<int64_t, 64> inequalities;
- /// Number of equalities in this system.
- unsigned numEqualities;
-
- /// Number of inequalities in this system.
- unsigned numInequalities;
-
// Pre-allocated space.
unsigned numReservedEqualities;
unsigned numReservedInequalities;
@@ -267,6 +311,9 @@
unsigned numIds;
+ /// Number of identifiers corresponding to result dimensions.
+ unsigned numResultDims;
+
/// Number of identifiers corresponding to real dimensions.
unsigned numDims;
/// Number of identifiers corresponding to symbols (unknown but constant for
diff --git a/include/mlir/Analysis/HyperRectangularSet.h b/include/mlir/Analysis/HyperRectangularSet.h
index ba28b3e..8956aeb 100644
--- a/include/mlir/Analysis/HyperRectangularSet.h
+++ b/include/mlir/Analysis/HyperRectangularSet.h
@@ -95,6 +95,7 @@
HyperRectangularSet(unsigned numDims, unsigned numSymbols,
ArrayRef<ArrayRef<AffineExpr *>> lbs,
ArrayRef<ArrayRef<AffineExpr *>> ubs,
+ MLIRContext *context,
IntegerSet *symbolContext = nullptr);
unsigned getNumDims() const { return numDims; }
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index 825783d..6ef62dc 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -70,6 +70,12 @@
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool isPureAffine() const;
+ /// Returns the greatest known common divisor of this affine expression.
+ uint64_t getKnownGcd() const;
+
+ /// Return true if the affine expression is a multiple of 'factor'.
+ bool isMultipleOf(int64_t factor) const;
+
protected:
explicit AffineExpr(Kind kind) : kind(kind) {}
~AffineExpr() {}
diff --git a/include/mlir/IR/AffineExprVisitor.h b/include/mlir/IR/AffineExprVisitor.h
index 7f0bc15..5a2753d 100644
--- a/include/mlir/IR/AffineExprVisitor.h
+++ b/include/mlir/IR/AffineExprVisitor.h
@@ -26,7 +26,7 @@
namespace mlir {
-/// Base class for AffineExpr visitors.
+/// Base class for AffineExpr visitors/walkers.
///
/// AffineExpr visitors are used when you want to perform different actions
/// for different kinds of AffineExprs without having to use lots of casts
@@ -72,54 +72,86 @@
/// just as efficient as having your own switch statement over the statement
/// opcode.
-template <typename SubClass> class AffineExprVisitor {
+template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
public:
- // Function to visit an AffineExpr.
- void visit(AffineExpr *expr) {
+ // Function to walk an AffineExpr (in post order).
+ RetTy walkPostOrder(AffineExpr *expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr->getKind()) {
case AffineExpr::Kind::Add: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryAddOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExpr::Kind::Mul: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryMulOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExpr::Kind::Mod: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryModOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExpr::Kind::FloorDiv: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryFloorDivOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExpr::Kind::CeilDiv: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryCeilDivOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExpr::Kind::Constant:
- return static_cast<SubClass *>(this)->visitAffineConstantExpr(
+ return static_cast<SubClass *>(this)->visitConstantExpr(
cast<AffineConstantExpr>(expr));
case AffineExpr::Kind::DimId:
- return static_cast<SubClass *>(this)->visitAffineDimExpr(
+ return static_cast<SubClass *>(this)->visitDimExpr(
cast<AffineDimExpr>(expr));
case AffineExpr::Kind::SymbolId:
- return static_cast<SubClass *>(this)->visitAffineSymbolExpr(
+ return static_cast<SubClass *>(this)->visitSymbolExpr(
+ cast<AffineSymbolExpr>(expr));
+ }
+ }
+
+ // Function to visit an AffineExpr.
+ RetTy visit(AffineExpr *expr) {
+ static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
+ "Must instantiate with a derived type of AffineExprVisitor");
+ switch (expr->getKind()) {
+ case AffineExpr::Kind::Add: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::Mul: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::Mod: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::FloorDiv: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::CeilDiv: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::Constant:
+ return static_cast<SubClass *>(this)->visitConstantExpr(
+ cast<AffineConstantExpr>(expr));
+ case AffineExpr::Kind::DimId:
+ return static_cast<SubClass *>(this)->visitDimExpr(
+ cast<AffineDimExpr>(expr));
+ case AffineExpr::Kind::SymbolId:
+ return static_cast<SubClass *>(this)->visitSymbolExpr(
cast<AffineSymbolExpr>(expr));
}
}
@@ -135,29 +167,30 @@
// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
void visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) {}
- void visitAffineBinaryAddOpExpr(AffineBinaryOpExpr *expr) {
+ void visitAddExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryMulOpExpr(AffineBinaryOpExpr *expr) {
+ void visitMulExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryModOpExpr(AffineBinaryOpExpr *expr) {
+ void visitModExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryFloorDivOpExpr(AffineBinaryOpExpr *expr) {
+ void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryCeilDivOpExpr(AffineBinaryOpExpr *expr) {
+ void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineConstantExpr(AffineConstantExpr *expr) {}
+ void visitConstantExpr(AffineConstantExpr *expr) {}
- void visitAffineDimExpr(AffineDimExpr *expr) {}
- void visitAffineSymbolExpr(AffineSymbolExpr *expr) {}
+ void visitDimExpr(AffineDimExpr *expr) {}
+ void visitSymbolExpr(AffineSymbolExpr *expr) {}
private:
- void visitAffineBinaryOpOperands(AffineBinaryOpExpr *expr) {
- visit(expr->getLHS());
- visit(expr->getRHS());
+ // Walk the operands - each operand is itself walked in post order.
+ void walkOperandsPostOrder(AffineBinaryOpExpr *expr) {
+ walkPostOrder(expr->getLHS());
+ walkPostOrder(expr->getRHS());
}
};
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index b9320c8..30ece71 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -25,6 +25,7 @@
namespace mlir {
+class FunctionPass;
class MLFunctionPass;
class ModulePass;
@@ -38,6 +39,9 @@
/// line if provided.
MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
+/// Creates an affine expression simplification pass.
+FunctionPass *createSimplifyAffineExprPass();
+
/// Replaces all ML functions in the module with equivalent CFG functions.
/// Function references are appropriately patched to refer to the newly
/// generated CFG functions.