Affine expression analysis and simplification.
Outside of IR/
- simplify a MutableAffineMap by flattening the affine expressions
- add a simplify affine expression pass that uses this analysis
- update the FlatAffineConstraints API (to be used in the next CL)
In IR:
- add isMultipleOf and getKnownGCD for AffineExpr, and make the in-IR
simplication of simplifyMod simpler and more powerful.
- rename the AffineExpr visitor methods to distinguish b/w visiting and
walking, and to simplify API names based on context.
The next CL will use some of these for the loop unrolling/unroll-jam to make
the detection for the need of cleanup loop powerful/non-trivial.
A future CL will finally move this simplification to FlatAffineConstraints to
make it more powerful. For eg., currently, even if a mod expr appearing in a
part of the expression tree can't be simplified, the whole thing won't be
simplified.
PiperOrigin-RevId: 211012256
diff --git a/include/mlir/Analysis/AffineStructures.h b/include/mlir/Analysis/AffineStructures.h
index 6d28260..ac6e728 100644
--- a/include/mlir/Analysis/AffineStructures.h
+++ b/include/mlir/Analysis/AffineStructures.h
@@ -41,23 +41,41 @@
/// A mutable affine map. Its affine expressions are however unique.
struct MutableAffineMap {
public:
- explicit MutableAffineMap(AffineMap *map);
+ MutableAffineMap(AffineMap *map, MLIRContext *context);
AffineExpr *getResult(unsigned idx) const { return results[idx]; }
unsigned getNumResults() const { return results.size(); }
+ unsigned getNumDims() const { return numDims; }
+ unsigned getNumSymbols() const { return numSymbols; }
+ /// Returns true if the idx'th result expression is a multiple of factor.
+ bool isMultipleOf(unsigned idx, int64_t factor) const;
+
+ /// Simplify the (result) expressions in this map using analysis (used by
+ //-simplify-affine-expr pass).
+ void simplify();
+ /// Get the AffineMap corresponding to this MutableAffineMap. Note that an
+ /// AffineMap * will be uniqued and stored in context, while a mutable one
+ /// isn't.
+ AffineMap *getAffineMap();
private:
+ // Same meaning as AffineMap's fields.
SmallVector<AffineExpr *, 8> results;
SmallVector<AffineExpr *, 8> rangeSizes;
+ unsigned numDims;
+ unsigned numSymbols;
+ /// A pointer to the IR's context to store all newly created AffineExpr's.
+ MLIRContext *context;
};
/// A mutable integer set. Its affine expressions are however unique.
struct MutableIntegerSet {
public:
- explicit MutableIntegerSet(IntegerSet *set);
+ MutableIntegerSet(IntegerSet *set, MLIRContext *context);
/// Create a universal set (no constraints).
- explicit MutableIntegerSet(unsigned numDims, unsigned numSymbols);
+ MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+ MLIRContext *context);
unsigned getNumDims() const { return numDims; }
unsigned getNumSymbols() const { return numSymbols; }
@@ -74,6 +92,8 @@
SmallVector<AffineExpr *, 8> constraints;
SmallVector<bool, 8> eqFlags;
+ /// A pointer to the IR's context to store all newly created AffineExpr's.
+ MLIRContext *context;
};
/// An AffineValueMap is an affine map plus its ML value operands and
@@ -89,9 +109,9 @@
// TODO(bondhugula): Some of these classes could go into separate files.
class AffineValueMap {
public:
- explicit AffineValueMap(const AffineApplyOp &op);
- explicit AffineValueMap(const AffineBound &bound);
- explicit AffineValueMap(AffineMap *map);
+ AffineValueMap(const AffineApplyOp &op, MLIRContext *context);
+ AffineValueMap(const AffineBound &bound, MLIRContext *context);
+ AffineValueMap(AffineMap *map, MLIRContext *context);
~AffineValueMap();
@@ -110,7 +130,7 @@
/// Return true if the idx^th result can be proved to be a multiple of
/// 'factor', false otherwise.
- bool isMultipleOf(unsigned idx, int64_t factor) const;
+ inline bool isMultipleOf(unsigned idx, int64_t factor) const;
/// Return true if the result at 'idx' is a constant, false
/// otherwise.
@@ -128,8 +148,6 @@
SmallVector<MLValue *, 4> operands;
/// The SSA results binding to the results of 'map'.
SmallVector<MLValue *, 4> results;
- /// A pointer to the IR's context to store all newly created AffineExpr's.
- MLIRContext *context;
};
/// An IntegerValueSet is an integer set plus its operands.
@@ -155,8 +173,6 @@
MutableIntegerSet set;
/// The SSA operands binding to the dim's and symbols of 'set'.
SmallVector<MLValue *, 4> operands;
- /// A pointer to the IR's context to store all newly created AffineExpr's.
- MLIRContext *context;
};
/// A flat list of affine equalities and inequalities in the form.
@@ -190,8 +206,7 @@
/// constraints and identifiers..
FlatAffineConstraints(unsigned numReservedInequalities,
unsigned numReservedEqualities, unsigned numReservedIds)
- : numEqualities(0), numInequalities(0),
- numReservedEqualities(numReservedEqualities),
+ : numReservedEqualities(numReservedEqualities),
numReservedInequalities(numReservedInequalities),
numReservedIds(numReservedIds) {
equalities.reserve(numReservedIds * numReservedEqualities);
@@ -208,23 +223,50 @@
/// Create an affine constraint system from an IntegerValueSet.
// TODO(bondhugula)
- FlatAffineConstraints(const IntegerValueSet &set);
+ explicit FlatAffineConstraints(const IntegerValueSet &set);
FlatAffineConstraints(ArrayRef<const AffineValueMap *> avmRef,
const IntegerSet &set);
+ FlatAffineConstraints(const MutableAffineMap &map);
+
~FlatAffineConstraints() {}
- inline int64_t atEq(unsigned i, unsigned j) {
- return equalities[i * numIds + j];
+ inline int64_t atEq(unsigned i, unsigned j) const {
+ return equalities[i * (numIds + 1) + j];
}
- inline int64_t atIneq(unsigned i, unsigned j) {
- return inequalities[i * numIds + j];
+ inline int64_t &atEq(unsigned i, unsigned j) {
+ return equalities[i * (numIds + 1) + j];
}
- unsigned getNumEqualities() const { return equalities.size(); }
- unsigned getNumInequalities() const { return inequalities.size(); }
+ inline int64_t atIneq(unsigned i, unsigned j) const {
+ return inequalities[i * (numIds + 1) + j];
+ }
+
+ inline int64_t &atIneq(unsigned i, unsigned j) {
+ return inequalities[i * (numIds + 1) + j];
+ }
+
+ inline unsigned getNumCols() const { return numIds + 1; }
+
+ inline unsigned getNumEqualities() const {
+ return equalities.size() / getNumCols();
+ }
+
+ inline unsigned getNumInequalities() const {
+ return inequalities.size() / getNumCols();
+ }
+
+ ArrayRef<int64_t> getEquality(unsigned idx) {
+ return ArrayRef<int64_t>(&equalities[idx * getNumCols()], getNumCols());
+ }
+
+ ArrayRef<int64_t> getInequality(unsigned idx) {
+ return ArrayRef<int64_t>(&inequalities[idx * getNumCols()], getNumCols());
+ }
+
+ AffineExpr *toAffineExpr(unsigned idx, MLIRContext *context);
void addInequality(ArrayRef<int64_t> inEq);
void addEquality(ArrayRef<int64_t> eq);
@@ -239,11 +281,19 @@
void removeEquality(unsigned pos);
void removeInequality(unsigned pos);
- unsigned getNumConstraints() const { return numEqualities + numInequalities; }
- unsigned getNumIds() const { return numIds; }
- unsigned getNumDimIds() const { return numDims; }
- unsigned getNumSymbolIds() const { return numSymbols; }
- unsigned getNumLocalIds() const { return numDims - numSymbols - numDims; }
+ unsigned getNumConstraints() const {
+ return equalities.size() + inequalities.size();
+ }
+ inline unsigned getNumIds() const { return numIds; }
+ inline unsigned getNumResultDimIds() const { return numResultDims; }
+ inline unsigned getNumDimIds() const { return numDims; }
+ inline unsigned getNumSymbolIds() const { return numSymbols; }
+ inline unsigned getNumLocalIds() const {
+ return numIds - numResultDims - numDims - numSymbols;
+ }
+
+ void print(raw_ostream &os) const;
+ void dump() const;
private:
/// Coefficients of affine equalities (in == 0 form).
@@ -252,12 +302,6 @@
/// Coefficients of affine inequalities (in >= 0 form).
SmallVector<int64_t, 64> inequalities;
- /// Number of equalities in this system.
- unsigned numEqualities;
-
- /// Number of inequalities in this system.
- unsigned numInequalities;
-
// Pre-allocated space.
unsigned numReservedEqualities;
unsigned numReservedInequalities;
@@ -267,6 +311,9 @@
unsigned numIds;
/// Number of identifiers corresponding to real dimensions.
+ unsigned numResultDims;
+
+ /// Number of identifiers corresponding to real dimensions.
unsigned numDims;
/// Number of identifiers corresponding to symbols (unknown but constant for
diff --git a/include/mlir/Analysis/HyperRectangularSet.h b/include/mlir/Analysis/HyperRectangularSet.h
index ba28b3e..8956aeb 100644
--- a/include/mlir/Analysis/HyperRectangularSet.h
+++ b/include/mlir/Analysis/HyperRectangularSet.h
@@ -95,6 +95,7 @@
HyperRectangularSet(unsigned numDims, unsigned numSymbols,
ArrayRef<ArrayRef<AffineExpr *>> lbs,
ArrayRef<ArrayRef<AffineExpr *>> ubs,
+ MLIRContext *context,
IntegerSet *symbolContext = nullptr);
unsigned getNumDims() const { return numDims; }
diff --git a/include/mlir/IR/AffineExpr.h b/include/mlir/IR/AffineExpr.h
index 825783d..6ef62dc 100644
--- a/include/mlir/IR/AffineExpr.h
+++ b/include/mlir/IR/AffineExpr.h
@@ -70,6 +70,12 @@
/// floordiv, ceildiv, and mod is only allowed w.r.t constants.
bool isPureAffine() const;
+ /// Returns the greatest known common divisor of this affine expression.
+ uint64_t getKnownGcd() const;
+
+ /// Return true if the affine expression is a multiple of 'factor'.
+ bool isMultipleOf(int64_t factor) const;
+
protected:
explicit AffineExpr(Kind kind) : kind(kind) {}
~AffineExpr() {}
diff --git a/include/mlir/IR/AffineExprVisitor.h b/include/mlir/IR/AffineExprVisitor.h
index 7f0bc15..5a2753d 100644
--- a/include/mlir/IR/AffineExprVisitor.h
+++ b/include/mlir/IR/AffineExprVisitor.h
@@ -26,7 +26,7 @@
namespace mlir {
-/// Base class for AffineExpr visitors.
+/// Base class for AffineExpr visitors/walkers.
///
/// AffineExpr visitors are used when you want to perform different actions
/// for different kinds of AffineExprs without having to use lots of casts
@@ -72,54 +72,86 @@
/// just as efficient as having your own switch statement over the statement
/// opcode.
-template <typename SubClass> class AffineExprVisitor {
+template <typename SubClass, typename RetTy = void> class AffineExprVisitor {
//===--------------------------------------------------------------------===//
// Interface code - This is the public interface of the AffineExprVisitor
// that you use to visit affine expressions...
public:
- // Function to visit an AffineExpr.
- void visit(AffineExpr *expr) {
+ // Function to walk an AffineExpr (in post order).
+ RetTy walkPostOrder(AffineExpr *expr) {
static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
"Must instantiate with a derived type of AffineExprVisitor");
switch (expr->getKind()) {
case AffineExpr::Kind::Add: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryAddOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
}
case AffineExpr::Kind::Mul: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryMulOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
}
case AffineExpr::Kind::Mod: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryModOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
}
case AffineExpr::Kind::FloorDiv: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryFloorDivOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
}
case AffineExpr::Kind::CeilDiv: {
auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
- visitAffineBinaryOpOperands(binOpExpr);
- return static_cast<SubClass *>(this)->visitAffineBinaryCeilDivOpExpr(
- binOpExpr);
+ walkOperandsPostOrder(binOpExpr);
+ return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
}
case AffineExpr::Kind::Constant:
- return static_cast<SubClass *>(this)->visitAffineConstantExpr(
+ return static_cast<SubClass *>(this)->visitConstantExpr(
cast<AffineConstantExpr>(expr));
case AffineExpr::Kind::DimId:
- return static_cast<SubClass *>(this)->visitAffineDimExpr(
+ return static_cast<SubClass *>(this)->visitDimExpr(
cast<AffineDimExpr>(expr));
case AffineExpr::Kind::SymbolId:
- return static_cast<SubClass *>(this)->visitAffineSymbolExpr(
+ return static_cast<SubClass *>(this)->visitSymbolExpr(
+ cast<AffineSymbolExpr>(expr));
+ }
+ }
+
+ // Function to visit an AffineExpr.
+ RetTy visit(AffineExpr *expr) {
+ static_assert(std::is_base_of<AffineExprVisitor, SubClass>::value,
+ "Must instantiate with a derived type of AffineExprVisitor");
+ switch (expr->getKind()) {
+ case AffineExpr::Kind::Add: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitAddExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::Mul: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitMulExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::Mod: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitModExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::FloorDiv: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitFloorDivExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::CeilDiv: {
+ auto *binOpExpr = cast<AffineBinaryOpExpr>(expr);
+ return static_cast<SubClass *>(this)->visitCeilDivExpr(binOpExpr);
+ }
+ case AffineExpr::Kind::Constant:
+ return static_cast<SubClass *>(this)->visitConstantExpr(
+ cast<AffineConstantExpr>(expr));
+ case AffineExpr::Kind::DimId:
+ return static_cast<SubClass *>(this)->visitDimExpr(
+ cast<AffineDimExpr>(expr));
+ case AffineExpr::Kind::SymbolId:
+ return static_cast<SubClass *>(this)->visitSymbolExpr(
cast<AffineSymbolExpr>(expr));
}
}
@@ -135,29 +167,30 @@
// Default visit methods. Note that the default op-specific binary op visit
// methods call the general visitAffineBinaryOpExpr visit method.
void visitAffineBinaryOpExpr(AffineBinaryOpExpr *expr) {}
- void visitAffineBinaryAddOpExpr(AffineBinaryOpExpr *expr) {
+ void visitAddExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryMulOpExpr(AffineBinaryOpExpr *expr) {
+ void visitMulExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryModOpExpr(AffineBinaryOpExpr *expr) {
+ void visitModExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryFloorDivOpExpr(AffineBinaryOpExpr *expr) {
+ void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineBinaryCeilDivOpExpr(AffineBinaryOpExpr *expr) {
+ void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
static_cast<SubClass *>(this)->visitAffineBinaryOpExpr(expr);
}
- void visitAffineConstantExpr(AffineConstantExpr *expr) {}
+ void visitConstantExpr(AffineConstantExpr *expr) {}
void visitAffineDimExpr(AffineDimExpr *expr) {}
void visitAffineSymbolExpr(AffineSymbolExpr *expr) {}
private:
- void visitAffineBinaryOpOperands(AffineBinaryOpExpr *expr) {
- visit(expr->getLHS());
- visit(expr->getRHS());
+ // Walk the operands - each operand is itself walked in post order.
+ void walkOperandsPostOrder(AffineBinaryOpExpr *expr) {
+ walkPostOrder(expr->getLHS());
+ walkPostOrder(expr->getRHS());
}
};
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index b9320c8..30ece71 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -25,6 +25,7 @@
namespace mlir {
+class FunctionPass;
class MLFunctionPass;
class ModulePass;
@@ -38,6 +39,9 @@
/// line if provided.
MLFunctionPass *createLoopUnrollAndJamPass(int unrollJamFactor = -1);
+/// Creates an affine expression simplification pass.
+FunctionPass *createSimplifyAffineExprPass();
+
/// Replaces all ML functions in the module with equivalent CFG functions.
/// Function references are appropriately patched to refer to the newly
/// generated CFG functions.
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index 782257e..965004b 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -20,47 +20,66 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
+
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
-MutableAffineMap::MutableAffineMap(AffineMap *map) {
+MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
+ : numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
+ context(context) {
for (auto *result : map->getResults())
results.push_back(result);
for (auto *rangeSize : map->getRangeSizes())
results.push_back(rangeSize);
}
-MutableIntegerSet::MutableIntegerSet(IntegerSet *set)
- : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()) {
- // TODO(bondhugula)
-}
+bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
+ if (results[idx]->isMultipleOf(factor))
+ return true;
-// Universal set.
-MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols)
- : numDims(numDims), numSymbols(numSymbols) {}
-
-AffineValueMap::AffineValueMap(const AffineApplyOp &op)
- : map(op.getAffineMap()) {
- // TODO: pull operands and results in.
-}
-
-bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
- // Check if the (first result expr) % factor becomes 0.
- if (auto *expr = dyn_cast<AffineConstantExpr>(AffineBinaryOpExpr::get(
- AffineExpr::Kind::Mod, map.getResult(idx),
- AffineConstantExpr::get(factor, context), context)))
- return expr->getValue() == 0;
-
- // TODO(bondhugula): use FlatAffineConstraints to complete this.
+ // TODO(bondhugula): use FlatAffineConstraints to complete this (for a more
+ // powerful analysis).
assert(0 && "isMultipleOf implementation incomplete");
return false;
}
+MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
+ : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()),
+ context(context) {
+ // TODO(bondhugula)
+}
+
+// Universal set.
+MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+ MLIRContext *context)
+ : numDims(numDims), numSymbols(numSymbols), context(context) {}
+
+AffineValueMap::AffineValueMap(const AffineApplyOp &op, MLIRContext *context)
+ : map(op.getAffineMap(), context) {
+ // TODO: pull operands and results in.
+}
+
+inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
+ return map.isMultipleOf(idx, factor);
+}
+
AffineValueMap::~AffineValueMap() {}
+void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
+ assert(eq.size() == getNumCols());
+ unsigned offset = equalities.size();
+ equalities.resize(equalities.size() + eq.size());
+ for (unsigned i = 0, e = eq.size(); i < e; i++) {
+ equalities[offset + i] = eq[i];
+ }
+}
+
} // end namespace mlir
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 1a06515..14e180b 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -110,9 +110,10 @@
HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
ArrayRef<ArrayRef<AffineExpr *>> lbs,
ArrayRef<ArrayRef<AffineExpr *>> ubs,
+ MLIRContext *context,
IntegerSet *symbolContext)
- : context(symbolContext ? MutableIntegerSet(symbolContext)
- : MutableIntegerSet(numDims, numSymbols)) {
+ : context(symbolContext ? MutableIntegerSet(symbolContext, context)
+ : MutableIntegerSet(numDims, numSymbols, context)) {
unsigned d = 0;
for (auto boundList : lbs) {
AffineBoundExprList lb;
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 6bfbaf5..6393b3e 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -100,3 +100,56 @@
}
}
}
+
+uint64_t AffineExpr::getKnownGcd() const {
+ AffineBinaryOpExpr *binExpr = nullptr;
+ switch (kind) {
+ case Kind::SymbolId:
+ LLVM_FALLTHROUGH;
+ case Kind::DimId:
+ return 1;
+ case Kind::Constant:
+ return std::abs(cast<AffineConstantExpr>(this)->getValue());
+ case Kind::Mul:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ return binExpr->getLHS()->getKnownGcd() * binExpr->getRHS()->getKnownGcd();
+ case Kind::Add:
+ LLVM_FALLTHROUGH;
+ case Kind::FloorDiv:
+ case Kind::CeilDiv:
+ case Kind::Mod:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ return llvm::GreatestCommonDivisor64(binExpr->getLHS()->getKnownGcd(),
+ binExpr->getRHS()->getKnownGcd());
+ }
+}
+
+bool AffineExpr::isMultipleOf(int64_t factor) const {
+ AffineBinaryOpExpr *binExpr = nullptr;
+ uint64_t l, u;
+ switch (kind) {
+ case Kind::SymbolId:
+ LLVM_FALLTHROUGH;
+ case Kind::DimId:
+ return factor * factor == 1;
+ case Kind::Constant:
+ return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
+ case Kind::Mul:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ // It's probably not worth optimizing this further (to not traverse the
+ // whole sub-tree under - it that would require a version of isMultipleOf
+ // that on a 'false' return also returns the known GCD).
+ return (l = binExpr->getLHS()->getKnownGcd()) % factor == 0 ||
+ (u = binExpr->getRHS()->getKnownGcd()) % factor == 0 ||
+ (l * u) % factor == 0;
+ case Kind::Add:
+ case Kind::FloorDiv:
+ case Kind::CeilDiv:
+ case Kind::Mod:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ return llvm::GreatestCommonDivisor64(binExpr->getLHS()->getKnownGcd(),
+ binExpr->getRHS()->getKnownGcd()) %
+ factor ==
+ 0;
+ }
+}
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 2b6d680..37a3b45 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -225,17 +225,13 @@
return AffineConstantExpr::get(lhsConst->getValue() % rhsConst->getValue(),
context);
- // Fold modulo of a multiply with a constant that is a multiple of the
- // modulo factor to zero. Eg: (i * 128) mod 64 = 0.
+ // Fold modulo of an expression that is known to be a multiple of a constant
+ // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
+ // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
if (rhsConst) {
- auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
- if (lBin && lBin->getKind() == Kind::Mul) {
- if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
- // rhsConst is known to be positive if a constant.
- if (lrhs->getValue() % rhsConst->getValue() == 0)
- return AffineConstantExpr::get(0, context);
- }
- }
+ // rhsConst is known to be positive if a constant.
+ if (lhs->getKnownGcd() % rhsConst->getValue() == 0)
+ return AffineConstantExpr::get(0, context);
}
return nullptr;
diff --git a/lib/Transforms/SimplifyAffineExpr.cpp b/lib/Transforms/SimplifyAffineExpr.cpp
new file mode 100644
index 0000000..3c72887
--- /dev/null
+++ b/lib/Transforms/SimplifyAffineExpr.cpp
@@ -0,0 +1,252 @@
+//===- SimplifyAffineExpr.cpp - MLIR Affine Structures Class-----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to simplify affine expressions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StmtVisitor.h"
+
+#include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using llvm::report_fatal_error;
+
+namespace {
+
+/// Simplify all affine expressions appearing in the operation statements of the
+/// MLFunction.
+// TODO(someone): Gradually, extend this to all affine map references found in
+// ML functions and CFG functions.
+struct SimplifyAffineExpr : public FunctionPass {
+ explicit SimplifyAffineExpr() {}
+
+ void runOnMLFunction(MLFunction *f);
+ // Does nothing on CFG functions for now. No reusable walkers/visitors exist
+ // for this yet? TODO(someone).
+ void runOnCFGFunction(CFGFunction *f) {}
+};
+
+// This class is used to flatten a pure affine expression into a sum of products
+// (w.r.t constants) when possible, and in that process accumulating
+// contributions for each dimensional and symbolic identifier together. Note
+// that an affine expression may not always be expressible that way due to the
+// preesnce of modulo, floordiv, and ceildiv expressions. A simplification that
+// this flattening naturally performs is to fold a modulo expression to a zero,
+// if possible. Two examples are below:
+//
+// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
+// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
+//
+// For modulo and floordiv expressions, an additional variable is introduced to
+// rewrite it as a sum of products (w.r.t constants). For example, for the
+// second example above, d0 % 4 is replaced by d0 - 4*q with q being introduced:
+// the expression simplifies to:
+// (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to zero.
+//
+// This is a linear time post order walk for an affine expression that attempts
+// the above simplifications through visit methods, with partial results being
+// stored in 'operandExprStack'. When a parent expr is visited, the flattened
+// expressions corresponding to its two operands would already be on the stack -
+// the parent expr looks at the two flattened expressions and combines the two.
+// It pops off the operand expressions and pushes the combined result (although
+// this is done in-place on its LHS operand expr. When the walk is completed,
+// the flattened form of the top-level expression would be left on the stack.
+//
+class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
+public:
+ std::vector<SmallVector<int64_t, 32>> operandExprStack;
+
+ // The layout of the flattened expressions is dimensions, symbols, locals,
+ // and constant term.
+ unsigned getNumCols() const { return numDims + numSymbols + numLocals + 1; }
+
+ AffineExprFlattener(unsigned numDims, unsigned numSymbols)
+ : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
+
+ void visitMulExpr(AffineBinaryOpExpr *expr) {
+ assert(expr->isPureAffine());
+ // Get the RHS constant.
+ auto rhsConst = operandExprStack.back()[getNumCols() - 1];
+ operandExprStack.pop_back();
+ // Update the LHS in place instead of pop and push.
+ auto &lhs = operandExprStack.back();
+ for (unsigned i = 0, e = lhs.size(); i < e; i++) {
+ lhs[i] *= rhsConst;
+ }
+ }
+ void visitAddExpr(AffineBinaryOpExpr *expr) {
+ const auto &rhs = operandExprStack.back();
+ auto &lhs = operandExprStack[operandExprStack.size() - 2];
+ assert(lhs.size() == rhs.size());
+ // Update the LHS in place.
+ for (unsigned i = 0; i < rhs.size(); i++) {
+ lhs[i] += rhs[i];
+ }
+ // Pop off the RHS.
+ operandExprStack.pop_back();
+ }
+ void visitModExpr(AffineBinaryOpExpr *expr) {
+ assert(expr->isPureAffine());
+ // This is a pure affine expr; the RHS is a constant.
+ auto rhsConst = operandExprStack.back()[getNumCols() - 1];
+ operandExprStack.pop_back();
+ auto &lhs = operandExprStack.back();
+ assert(rhsConst != 0 && "RHS constant can't be zero");
+ unsigned i;
+ for (i = 0; i < lhs.size(); i++)
+ if (lhs[i] % rhsConst != 0)
+ break;
+ if (i == lhs.size()) {
+ // The modulo expression here simplifies to zero.
+ lhs.assign(lhs.size(), 0);
+ return;
+ }
+ // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
+ // q * expr2) where q is the existential quantifier introduced.
+ addExistentialQuantifier();
+ lhs = operandExprStack.back();
+ lhs[numDims + numSymbols + numLocals - 1] = -rhsConst;
+ }
+ void visitConstantExpr(AffineConstantExpr *expr) {
+ operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
+ auto &eq = operandExprStack.back();
+ eq[getNumCols() - 1] = expr->getValue();
+ }
+ void visitDimExpr(AffineDimExpr *expr) {
+ SmallVector<int64_t, 32> eq(getNumCols(), 0);
+ eq[expr->getPosition()] = 1;
+ operandExprStack.push_back(eq);
+ }
+ void visitSymbolExpr(AffineSymbolExpr *expr) {
+ SmallVector<int64_t, 32> eq(getNumCols(), 0);
+ eq[numDims + expr->getPosition()] = 1;
+ operandExprStack.push_back(eq);
+ }
+ void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+ // TODO(bondhugula): handle ceildiv as well; won't simplify further through
+ // this analysis but will be handled (rest of the expr will simplify).
+ report_fatal_error("ceildiv expr simplification not supported here");
+ }
+ void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+ // TODO(bondhugula): handle ceildiv as well; won't simplify further through
+ // this analysis but will be handled (rest of the expr will simplify).
+ report_fatal_error("floordiv expr simplification unimplemented");
+ }
+ // Add an existential quantifier (used to flatten a mod or a floordiv expr).
+ void addExistentialQuantifier() {
+ for (auto &subExpr : operandExprStack) {
+ subExpr.insert(subExpr.begin() + numDims + numSymbols + numLocals, 0);
+ }
+ numLocals++;
+ }
+
+ unsigned numDims;
+ unsigned numSymbols;
+ unsigned numLocals;
+};
+
+} // end anonymous namespace
+
+FunctionPass *mlir::createSimplifyAffineExprPass() {
+ return new SimplifyAffineExpr();
+}
+
+AffineMap *MutableAffineMap::getAffineMap() {
+ return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
+}
+
+void SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
+ struct MapSimplifier : public StmtWalker<MapSimplifier> {
+ MLIRContext *context;
+ MapSimplifier(MLIRContext *context) : context(context) {}
+
+ void visitOperationStmt(OperationStmt *opStmt) {
+ for (auto attr : opStmt->getAttrs()) {
+ if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
+ MutableAffineMap mMap(mapAttr->getValue(), context);
+ mMap.simplify();
+ auto *map = mMap.getAffineMap();
+ opStmt->setAttr(attr.first, AffineMapAttr::get(map, context));
+ }
+ }
+ }
+ };
+
+ MapSimplifier v(f->getContext());
+ v.walkPostOrder(f);
+}
+
+/// Get an affine expression from a flat ArrayRef. If there are local variables
+/// (existential quantifiers introduced during the flattening) that appear in
+/// the sum of products expression, we can't readily express it as an affine
+/// expression of dimension and symbol id's; return nullptr in such cases.
+static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+ unsigned numSymbols, MLIRContext *context) {
+ // Check if any local variable has a non-zero coefficient.
+ for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
+ if (eq[j] != 0)
+ return nullptr;
+ }
+
+ AffineExpr *expr = AffineConstantExpr::get(0, context);
+ for (unsigned j = 0; j < numDims + numSymbols; j++) {
+ if (eq[j] != 0) {
+ AffineExpr *id =
+ j < numDims
+ ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
+ : AffineSymbolExpr::get(j - numDims, context);
+ expr = AffineBinaryOpExpr::get(
+ AffineExpr::Kind::Add, expr,
+ AffineBinaryOpExpr::get(AffineExpr::Kind::Mul,
+ AffineConstantExpr::get(eq[j], context), id,
+ context),
+ context);
+ }
+ }
+ unsigned constTerm = eq[eq.size() - 1];
+ if (constTerm != 0)
+ expr = AffineBinaryOpExpr::get(AffineExpr::Kind::Add, expr,
+ AffineConstantExpr::get(constTerm, context),
+ context);
+ return expr;
+}
+
+// Simplify the result affine expressions of this map. The expressions have to
+// be pure for the simplification implemented.
+void MutableAffineMap::simplify() {
+ // Simplify each of the results if possible.
+ for (unsigned i = 0, e = getNumResults(); i < e; i++) {
+ AffineExpr *result = getResult(i);
+ if (!result->isPureAffine())
+ continue;
+
+ AffineExprFlattener flattener(numDims, numSymbols);
+ flattener.walkPostOrder(result);
+ const auto &flattenedExpr = flattener.operandExprStack.back();
+ auto *expr = toAffineExpr(flattenedExpr, numDims, numSymbols, context);
+ if (expr)
+ results[i] = expr;
+ flattener.operandExprStack.pop_back();
+ assert(flattener.operandExprStack.empty());
+ }
+}
diff --git a/test/IR/affine-map.mlir b/test/IR/affine-map.mlir
index 4aacbd3..5d16f13 100644
--- a/test/IR/affine-map.mlir
+++ b/test/IR/affine-map.mlir
@@ -29,8 +29,9 @@
#map3j = (i, j) -> (i + 1, j*1*4 + 2)
#map3k = (i, j) -> (i + 1, j*4*1 + 2)
-// The following reduction should be unique'd out too but the expression
-// simplifier is not powerful enough.
+// The following reduction should be unique'd out too but such expression
+// simplification is not performed for IR parsing, but only through analyses
+// and transforms.
// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, d1 + d1 + d1 + d1 + 2)
#map3l = (i, j) -> ((j - i) + 2*(i - j + 1) + j - 1 + 0, j + j + 1 + j + j + 1)
@@ -155,13 +156,17 @@
// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (0, d1, d0 * 2, 0)
#map46 = (i, j, k) -> (i*0, 1*j, i * 128 floordiv 64, j * 0 floordiv 64)
-// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (d0, d0 * 4, 0, 0)
-#map47 = (i, j, k) -> (i * 64 ceildiv 64, i * 512 ceildiv 128, 4 * j mod 4, 4*j*4 mod 8)
+// CHECK: #map{{[0-9]+}} = (d0, d1, d2) -> (d0, d0 * 4, 0, 0, 0)
+#map47 = (i, j, k) -> (i * 64 ceildiv 64, i * 512 ceildiv 128, 4 * j mod 4, 4*j*4 mod 8, k mod 1)
-// floordiv should resolve similarly to ceildiv and be unique'd out
+// floordiv should resolve similarly to ceildiv and be unique'd out.
// CHECK-NOT: #map48{{[a-z]}}
#map48 = (i, j, k) -> (i * 64 floordiv 64, i * 512 floordiv 128, 4 * j mod 4, 4*j*4 mod 8)
+// Simplifications for mod using known GCD's of the LHS expr.
+// CHECK: #map{{[0-9]+}} = (d0, d1)[s0] -> (0, 0, 0, (d0 * 4 + 3) mod 2)
+#map49 = (i, j)[s0] -> ( (i * 4 + 8) mod 4, 32 * j * s0 * 8 mod 256, (4*i + (j * (s0 * 2))) mod 2, (4*i + 3) mod 2)
+
// CHECK: extfunc @f0(memref<2x4xi8, #map{{[0-9]+}}, 1>)
extfunc @f0(memref<2x4xi8, #map0, 1>)
@@ -323,3 +328,6 @@
// CHECK: extfunc @f48(memref<100x100x100xi8, #map{{[0-9]+}}>)
extfunc @f48(memref<100x100x100xi8, #map48>)
+
+// CHECK: extfunc @f49(memref<100x100xi8, #map{{[0-9]+}}>)
+extfunc @f49(memref<100x100xi8, #map49>)
diff --git a/test/Transforms/simplify.mlir b/test/Transforms/simplify.mlir
new file mode 100644
index 0000000..19d9cb1
--- /dev/null
+++ b/test/Transforms/simplify.mlir
@@ -0,0 +1,20 @@
+// RUN: mlir-opt %s -o - -simplify-affine-expr | FileCheck %s
+
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (0, 0)
+#map0 = (d0, d1) -> ((d0 - d0 mod 4) mod 4, (d0 - d0 mod 128 - 64) mod 64)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (d0 + 1, d1 * 5 + 3)
+#map1 = (d0, d1) -> (d1 - d0 + (d0 - d1 + 1) * 2 + d1 - 1, 1 + 2*d1 + d1 + d1 + d1 + 2)
+// CHECK: #map{{[0-9]+}} = (d0, d1) -> (0, 0, 0)
+#map2 = (d0, d1) -> (((d0 - d0 mod 2) * 2) mod 4, (5*d1 + 8 - (5*d1 + 4) mod 4) mod 4, 0)
+
+mlfunc @test() {
+ for %n0 = 0 to 127 {
+ for %n1 = 0 to 7 {
+ %x = affine_apply #map0(%n0, %n1)
+ %y = affine_apply #map1(%n0, %n1)
+ %z = affine_apply #map2(%n0, %n1)
+ }
+ }
+ return
+}
+
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index e1cba40..5bbf6b7 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -37,6 +37,7 @@
#include "llvm/Support/Regex.h"
#include "llvm/Support/SourceMgr.h"
#include "llvm/Support/ToolOutputFile.h"
+
using namespace mlir;
using namespace llvm;
@@ -55,6 +56,7 @@
ConvertToCFG,
LoopUnroll,
LoopUnrollAndJam,
+ SimplifyAffineExpr,
TFRaiseControlFlow,
};
@@ -65,6 +67,8 @@
clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
clEnumValN(LoopUnrollAndJam, "loop-unroll-jam",
"Unroll and jam loops"),
+ clEnumValN(SimplifyAffineExpr, "simplify-affine-expr",
+ "Simplify affine expressions"),
clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
"Dynamic TensorFlow Switch/Match nodes to a CFG")));
@@ -117,6 +121,9 @@
case LoopUnrollAndJam:
pass = createLoopUnrollAndJamPass();
break;
+ case SimplifyAffineExpr:
+ pass = createSimplifyAffineExprPass();
+ break;
case TFRaiseControlFlow:
pass = createRaiseTFControlFlowPass();
break;