Affine expression analysis and simplification.
Outside of IR/
- simplify a MutableAffineMap by flattening the affine expressions
- add a simplify affine expression pass that uses this analysis
- update the FlatAffineConstraints API (to be used in the next CL)
In IR:
- add isMultipleOf and getKnownGCD for AffineExpr, and make the in-IR
simplification in simplifyMod simpler and more powerful.
- rename the AffineExpr visitor methods to distinguish between visiting and
walking, and to simplify API names based on context.
The next CL will use some of these for loop unrolling/unroll-jam to make
the detection of the need for a cleanup loop more powerful/non-trivial.
A future CL will finally move this simplification to FlatAffineConstraints to
make it more powerful. For example, currently, even if a mod expr appearing in a
part of the expression tree can't be simplified, the whole thing won't be
simplified.
PiperOrigin-RevId: 211012256
diff --git a/lib/Analysis/AffineStructures.cpp b/lib/Analysis/AffineStructures.cpp
index 782257e..965004b 100644
--- a/lib/Analysis/AffineStructures.cpp
+++ b/lib/Analysis/AffineStructures.cpp
@@ -20,47 +20,66 @@
//===----------------------------------------------------------------------===//
#include "mlir/Analysis/AffineStructures.h"
+
#include "mlir/IR/AffineExpr.h"
+#include "mlir/IR/AffineExprVisitor.h"
#include "mlir/IR/AffineMap.h"
#include "mlir/IR/IntegerSet.h"
#include "mlir/IR/MLIRContext.h"
#include "mlir/IR/StandardOps.h"
+#include "llvm/ADT/ArrayRef.h"
+#include "llvm/Support/raw_ostream.h"
namespace mlir {
-MutableAffineMap::MutableAffineMap(AffineMap *map) {
+MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
+ : numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
+ context(context) {
for (auto *result : map->getResults())
results.push_back(result);
for (auto *rangeSize : map->getRangeSizes())
results.push_back(rangeSize);
}
-MutableIntegerSet::MutableIntegerSet(IntegerSet *set)
- : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()) {
- // TODO(bondhugula)
-}
+bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
+ if (results[idx]->isMultipleOf(factor))
+ return true;
-// Universal set.
-MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols)
- : numDims(numDims), numSymbols(numSymbols) {}
-
-AffineValueMap::AffineValueMap(const AffineApplyOp &op)
- : map(op.getAffineMap()) {
- // TODO: pull operands and results in.
-}
-
-bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
- // Check if the (first result expr) % factor becomes 0.
- if (auto *expr = dyn_cast<AffineConstantExpr>(AffineBinaryOpExpr::get(
- AffineExpr::Kind::Mod, map.getResult(idx),
- AffineConstantExpr::get(factor, context), context)))
- return expr->getValue() == 0;
-
- // TODO(bondhugula): use FlatAffineConstraints to complete this.
+ // TODO(bondhugula): use FlatAffineConstraints to complete this (for a more
+ // powerful analysis).
assert(0 && "isMultipleOf implementation incomplete");
return false;
}
+MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
+ : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()),
+ context(context) {
+ // TODO(bondhugula)
+}
+
+// Universal set.
+MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
+ MLIRContext *context)
+ : numDims(numDims), numSymbols(numSymbols), context(context) {}
+
+AffineValueMap::AffineValueMap(const AffineApplyOp &op, MLIRContext *context)
+ : map(op.getAffineMap(), context) {
+ // TODO: pull operands and results in.
+}
+
+inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
+ return map.isMultipleOf(idx, factor);
+}
+
// Defined out of line, with no custom cleanup.
AffineValueMap::~AffineValueMap() = default;
+void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
+ assert(eq.size() == getNumCols());
+ unsigned offset = equalities.size();
+ equalities.resize(equalities.size() + eq.size());
+ for (unsigned i = 0, e = eq.size(); i < e; i++) {
+ equalities[offset + i] = eq[i];
+ }
+}
+
} // end namespace mlir
diff --git a/lib/Analysis/HyperRectangularSet.cpp b/lib/Analysis/HyperRectangularSet.cpp
index 1a06515..14e180b 100644
--- a/lib/Analysis/HyperRectangularSet.cpp
+++ b/lib/Analysis/HyperRectangularSet.cpp
@@ -110,9 +110,10 @@
HyperRectangularSet::HyperRectangularSet(unsigned numDims, unsigned numSymbols,
ArrayRef<ArrayRef<AffineExpr *>> lbs,
ArrayRef<ArrayRef<AffineExpr *>> ubs,
+ MLIRContext *context,
IntegerSet *symbolContext)
- : context(symbolContext ? MutableIntegerSet(symbolContext)
- : MutableIntegerSet(numDims, numSymbols)) {
+ : context(symbolContext ? MutableIntegerSet(symbolContext, context)
+ : MutableIntegerSet(numDims, numSymbols, context)) {
unsigned d = 0;
for (auto boundList : lbs) {
AffineBoundExprList lb;
diff --git a/lib/IR/AffineExpr.cpp b/lib/IR/AffineExpr.cpp
index 6bfbaf5..6393b3e 100644
--- a/lib/IR/AffineExpr.cpp
+++ b/lib/IR/AffineExpr.cpp
@@ -100,3 +100,56 @@
}
}
}
+
+uint64_t AffineExpr::getKnownGcd() const {
+ AffineBinaryOpExpr *binExpr = nullptr;
+ switch (kind) {
+ case Kind::SymbolId:
+ LLVM_FALLTHROUGH;
+ case Kind::DimId:
+ return 1;
+ case Kind::Constant:
+ return std::abs(cast<AffineConstantExpr>(this)->getValue());
+ case Kind::Mul:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ return binExpr->getLHS()->getKnownGcd() * binExpr->getRHS()->getKnownGcd();
+ case Kind::Add:
+ LLVM_FALLTHROUGH;
+ case Kind::FloorDiv:
+ case Kind::CeilDiv:
+ case Kind::Mod:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ return llvm::GreatestCommonDivisor64(binExpr->getLHS()->getKnownGcd(),
+ binExpr->getRHS()->getKnownGcd());
+ }
+}
+
+bool AffineExpr::isMultipleOf(int64_t factor) const {
+ AffineBinaryOpExpr *binExpr = nullptr;
+ uint64_t l, u;
+ switch (kind) {
+ case Kind::SymbolId:
+ LLVM_FALLTHROUGH;
+ case Kind::DimId:
+ return factor * factor == 1;
+ case Kind::Constant:
+ return cast<AffineConstantExpr>(this)->getValue() % factor == 0;
+ case Kind::Mul:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ // It's probably not worth optimizing this further (to not traverse the
+ // whole sub-tree under - it that would require a version of isMultipleOf
+ // that on a 'false' return also returns the known GCD).
+ return (l = binExpr->getLHS()->getKnownGcd()) % factor == 0 ||
+ (u = binExpr->getRHS()->getKnownGcd()) % factor == 0 ||
+ (l * u) % factor == 0;
+ case Kind::Add:
+ case Kind::FloorDiv:
+ case Kind::CeilDiv:
+ case Kind::Mod:
+ binExpr = cast<AffineBinaryOpExpr>(const_cast<AffineExpr *>(this));
+ return llvm::GreatestCommonDivisor64(binExpr->getLHS()->getKnownGcd(),
+ binExpr->getRHS()->getKnownGcd()) %
+ factor ==
+ 0;
+ }
+}
diff --git a/lib/IR/AffineMap.cpp b/lib/IR/AffineMap.cpp
index 2b6d680..37a3b45 100644
--- a/lib/IR/AffineMap.cpp
+++ b/lib/IR/AffineMap.cpp
@@ -225,17 +225,13 @@
return AffineConstantExpr::get(lhsConst->getValue() % rhsConst->getValue(),
context);
- // Fold modulo of a multiply with a constant that is a multiple of the
- // modulo factor to zero. Eg: (i * 128) mod 64 = 0.
+ // Fold modulo of an expression that is known to be a multiple of a constant
+ // to zero if that constant is a multiple of the modulo factor. Eg: (i * 128)
+ // mod 64 is folded to 0, and less trivially, (i*(j*4*(k*32))) mod 128 = 0.
if (rhsConst) {
- auto *lBin = dyn_cast<AffineBinaryOpExpr>(lhs);
- if (lBin && lBin->getKind() == Kind::Mul) {
- if (auto *lrhs = dyn_cast<AffineConstantExpr>(lBin->getRHS())) {
- // rhsConst is known to be positive if a constant.
- if (lrhs->getValue() % rhsConst->getValue() == 0)
- return AffineConstantExpr::get(0, context);
- }
- }
+ // rhsConst is known to be positive if a constant.
+ if (lhs->getKnownGcd() % rhsConst->getValue() == 0)
+ return AffineConstantExpr::get(0, context);
}
return nullptr;
diff --git a/lib/Transforms/SimplifyAffineExpr.cpp b/lib/Transforms/SimplifyAffineExpr.cpp
new file mode 100644
index 0000000..3c72887
--- /dev/null
+++ b/lib/Transforms/SimplifyAffineExpr.cpp
@@ -0,0 +1,252 @@
+//===- SimplifyAffineExpr.cpp - MLIR Affine Structures Class-----*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements a pass to simplify affine expressions.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/AffineStructures.h"
+#include "mlir/IR/AffineExprVisitor.h"
+#include "mlir/IR/AffineMap.h"
+#include "mlir/IR/Attributes.h"
+#include "mlir/IR/StmtVisitor.h"
+
+#include "mlir/Transforms/Pass.h"
+#include "mlir/Transforms/Passes.h"
+
+using namespace mlir;
+using llvm::report_fatal_error;
+
+namespace {
+
+/// Simplify all affine expressions appearing in the operation statements of the
+/// MLFunction.
+// TODO(someone): Gradually, extend this to all affine map references found in
+// ML functions and CFG functions.
+struct SimplifyAffineExpr : public FunctionPass {
+ explicit SimplifyAffineExpr() {}
+
+ void runOnMLFunction(MLFunction *f);
+ // Does nothing on CFG functions for now. No reusable walkers/visitors exist
+ // for this yet? TODO(someone).
+ void runOnCFGFunction(CFGFunction *f) {}
+};
+
+// This class is used to flatten a pure affine expression into a sum of products
+// (w.r.t constants) when possible, and in that process accumulating
+// contributions for each dimensional and symbolic identifier together. Note
+// that an affine expression may not always be expressible that way due to the
+// preesnce of modulo, floordiv, and ceildiv expressions. A simplification that
+// this flattening naturally performs is to fold a modulo expression to a zero,
+// if possible. Two examples are below:
+//
+// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
+// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
+//
+// For modulo and floordiv expressions, an additional variable is introduced to
+// rewrite it as a sum of products (w.r.t constants). For example, for the
+// second example above, d0 % 4 is replaced by d0 - 4*q with q being introduced:
+// the expression simplifies to:
+// (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to zero.
+//
+// This is a linear time post order walk for an affine expression that attempts
+// the above simplifications through visit methods, with partial results being
+// stored in 'operandExprStack'. When a parent expr is visited, the flattened
+// expressions corresponding to its two operands would already be on the stack -
+// the parent expr looks at the two flattened expressions and combines the two.
+// It pops off the operand expressions and pushes the combined result (although
+// this is done in-place on its LHS operand expr. When the walk is completed,
+// the flattened form of the top-level expression would be left on the stack.
+//
+class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
+public:
+ std::vector<SmallVector<int64_t, 32>> operandExprStack;
+
+ // The layout of the flattened expressions is dimensions, symbols, locals,
+ // and constant term.
+ unsigned getNumCols() const { return numDims + numSymbols + numLocals + 1; }
+
+ AffineExprFlattener(unsigned numDims, unsigned numSymbols)
+ : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
+
+ void visitMulExpr(AffineBinaryOpExpr *expr) {
+ assert(expr->isPureAffine());
+ // Get the RHS constant.
+ auto rhsConst = operandExprStack.back()[getNumCols() - 1];
+ operandExprStack.pop_back();
+ // Update the LHS in place instead of pop and push.
+ auto &lhs = operandExprStack.back();
+ for (unsigned i = 0, e = lhs.size(); i < e; i++) {
+ lhs[i] *= rhsConst;
+ }
+ }
+ void visitAddExpr(AffineBinaryOpExpr *expr) {
+ const auto &rhs = operandExprStack.back();
+ auto &lhs = operandExprStack[operandExprStack.size() - 2];
+ assert(lhs.size() == rhs.size());
+ // Update the LHS in place.
+ for (unsigned i = 0; i < rhs.size(); i++) {
+ lhs[i] += rhs[i];
+ }
+ // Pop off the RHS.
+ operandExprStack.pop_back();
+ }
+ void visitModExpr(AffineBinaryOpExpr *expr) {
+ assert(expr->isPureAffine());
+ // This is a pure affine expr; the RHS is a constant.
+ auto rhsConst = operandExprStack.back()[getNumCols() - 1];
+ operandExprStack.pop_back();
+ auto &lhs = operandExprStack.back();
+ assert(rhsConst != 0 && "RHS constant can't be zero");
+ unsigned i;
+ for (i = 0; i < lhs.size(); i++)
+ if (lhs[i] % rhsConst != 0)
+ break;
+ if (i == lhs.size()) {
+ // The modulo expression here simplifies to zero.
+ lhs.assign(lhs.size(), 0);
+ return;
+ }
+ // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
+ // q * expr2) where q is the existential quantifier introduced.
+ addExistentialQuantifier();
+ lhs = operandExprStack.back();
+ lhs[numDims + numSymbols + numLocals - 1] = -rhsConst;
+ }
+ void visitConstantExpr(AffineConstantExpr *expr) {
+ operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
+ auto &eq = operandExprStack.back();
+ eq[getNumCols() - 1] = expr->getValue();
+ }
+ void visitDimExpr(AffineDimExpr *expr) {
+ SmallVector<int64_t, 32> eq(getNumCols(), 0);
+ eq[expr->getPosition()] = 1;
+ operandExprStack.push_back(eq);
+ }
+ void visitSymbolExpr(AffineSymbolExpr *expr) {
+ SmallVector<int64_t, 32> eq(getNumCols(), 0);
+ eq[numDims + expr->getPosition()] = 1;
+ operandExprStack.push_back(eq);
+ }
+ void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
+ // TODO(bondhugula): handle ceildiv as well; won't simplify further through
+ // this analysis but will be handled (rest of the expr will simplify).
+ report_fatal_error("ceildiv expr simplification not supported here");
+ }
+ void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
+ // TODO(bondhugula): handle ceildiv as well; won't simplify further through
+ // this analysis but will be handled (rest of the expr will simplify).
+ report_fatal_error("floordiv expr simplification unimplemented");
+ }
+ // Add an existential quantifier (used to flatten a mod or a floordiv expr).
+ void addExistentialQuantifier() {
+ for (auto &subExpr : operandExprStack) {
+ subExpr.insert(subExpr.begin() + numDims + numSymbols + numLocals, 0);
+ }
+ numLocals++;
+ }
+
+ unsigned numDims;
+ unsigned numSymbols;
+ unsigned numLocals;
+};
+
+} // end anonymous namespace
+
+FunctionPass *mlir::createSimplifyAffineExprPass() {
+ return new SimplifyAffineExpr();
+}
+
+AffineMap *MutableAffineMap::getAffineMap() {
+ return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
+}
+
+void SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
+ struct MapSimplifier : public StmtWalker<MapSimplifier> {
+ MLIRContext *context;
+ MapSimplifier(MLIRContext *context) : context(context) {}
+
+ void visitOperationStmt(OperationStmt *opStmt) {
+ for (auto attr : opStmt->getAttrs()) {
+ if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
+ MutableAffineMap mMap(mapAttr->getValue(), context);
+ mMap.simplify();
+ auto *map = mMap.getAffineMap();
+ opStmt->setAttr(attr.first, AffineMapAttr::get(map, context));
+ }
+ }
+ }
+ };
+
+ MapSimplifier v(f->getContext());
+ v.walkPostOrder(f);
+}
+
+/// Get an affine expression from a flat ArrayRef. If there are local variables
+/// (existential quantifiers introduced during the flattening) that appear in
+/// the sum of products expression, we can't readily express it as an affine
+/// expression of dimension and symbol id's; return nullptr in such cases.
+static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
+ unsigned numSymbols, MLIRContext *context) {
+ // Check if any local variable has a non-zero coefficient.
+ for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
+ if (eq[j] != 0)
+ return nullptr;
+ }
+
+ AffineExpr *expr = AffineConstantExpr::get(0, context);
+ for (unsigned j = 0; j < numDims + numSymbols; j++) {
+ if (eq[j] != 0) {
+ AffineExpr *id =
+ j < numDims
+ ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
+ : AffineSymbolExpr::get(j - numDims, context);
+ expr = AffineBinaryOpExpr::get(
+ AffineExpr::Kind::Add, expr,
+ AffineBinaryOpExpr::get(AffineExpr::Kind::Mul,
+ AffineConstantExpr::get(eq[j], context), id,
+ context),
+ context);
+ }
+ }
+ unsigned constTerm = eq[eq.size() - 1];
+ if (constTerm != 0)
+ expr = AffineBinaryOpExpr::get(AffineExpr::Kind::Add, expr,
+ AffineConstantExpr::get(constTerm, context),
+ context);
+ return expr;
+}
+
+// Simplify the result affine expressions of this map. The expressions have to
+// be pure for the simplification implemented.
+void MutableAffineMap::simplify() {
+ // Simplify each of the results if possible.
+ for (unsigned i = 0, e = getNumResults(); i < e; i++) {
+ AffineExpr *result = getResult(i);
+ if (!result->isPureAffine())
+ continue;
+
+ AffineExprFlattener flattener(numDims, numSymbols);
+ flattener.walkPostOrder(result);
+ const auto &flattenedExpr = flattener.operandExprStack.back();
+ auto *expr = toAffineExpr(flattenedExpr, numDims, numSymbols, context);
+ if (expr)
+ results[i] = expr;
+ flattener.operandExprStack.pop_back();
+ assert(flattener.operandExprStack.empty());
+ }
+}