blob: 3c72887c3018439234e8b3b630bb1c5f272b0665 [file] [log] [blame]
Uday Bondhugula83a41c92018-08-30 17:35:15 -07001//===- SimplifyAffineExpr.cpp - MLIR Affine Structures Class-----*- C++ -*-===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements a pass to simplify affine expressions.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Analysis/AffineStructures.h"
23#include "mlir/IR/AffineExprVisitor.h"
24#include "mlir/IR/AffineMap.h"
25#include "mlir/IR/Attributes.h"
26#include "mlir/IR/StmtVisitor.h"
27
28#include "mlir/Transforms/Pass.h"
29#include "mlir/Transforms/Passes.h"
30
31using namespace mlir;
32using llvm::report_fatal_error;
33
34namespace {
35
36/// Simplify all affine expressions appearing in the operation statements of the
37/// MLFunction.
38// TODO(someone): Gradually, extend this to all affine map references found in
39// ML functions and CFG functions.
40struct SimplifyAffineExpr : public FunctionPass {
41 explicit SimplifyAffineExpr() {}
42
43 void runOnMLFunction(MLFunction *f);
44 // Does nothing on CFG functions for now. No reusable walkers/visitors exist
45 // for this yet? TODO(someone).
46 void runOnCFGFunction(CFGFunction *f) {}
47};
48
49// This class is used to flatten a pure affine expression into a sum of products
50// (w.r.t constants) when possible, and in that process accumulating
51// contributions for each dimensional and symbolic identifier together. Note
52// that an affine expression may not always be expressible that way due to the
53// preesnce of modulo, floordiv, and ceildiv expressions. A simplification that
54// this flattening naturally performs is to fold a modulo expression to a zero,
55// if possible. Two examples are below:
56//
57// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
58// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
59//
60// For modulo and floordiv expressions, an additional variable is introduced to
61// rewrite it as a sum of products (w.r.t constants). For example, for the
62// second example above, d0 % 4 is replaced by d0 - 4*q with q being introduced:
63// the expression simplifies to:
64// (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to zero.
65//
66// This is a linear time post order walk for an affine expression that attempts
67// the above simplifications through visit methods, with partial results being
68// stored in 'operandExprStack'. When a parent expr is visited, the flattened
69// expressions corresponding to its two operands would already be on the stack -
70// the parent expr looks at the two flattened expressions and combines the two.
71// It pops off the operand expressions and pushes the combined result (although
72// this is done in-place on its LHS operand expr. When the walk is completed,
73// the flattened form of the top-level expression would be left on the stack.
74//
75class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
76public:
77 std::vector<SmallVector<int64_t, 32>> operandExprStack;
78
79 // The layout of the flattened expressions is dimensions, symbols, locals,
80 // and constant term.
81 unsigned getNumCols() const { return numDims + numSymbols + numLocals + 1; }
82
83 AffineExprFlattener(unsigned numDims, unsigned numSymbols)
84 : numDims(numDims), numSymbols(numSymbols), numLocals(0) {}
85
86 void visitMulExpr(AffineBinaryOpExpr *expr) {
87 assert(expr->isPureAffine());
88 // Get the RHS constant.
89 auto rhsConst = operandExprStack.back()[getNumCols() - 1];
90 operandExprStack.pop_back();
91 // Update the LHS in place instead of pop and push.
92 auto &lhs = operandExprStack.back();
93 for (unsigned i = 0, e = lhs.size(); i < e; i++) {
94 lhs[i] *= rhsConst;
95 }
96 }
97 void visitAddExpr(AffineBinaryOpExpr *expr) {
98 const auto &rhs = operandExprStack.back();
99 auto &lhs = operandExprStack[operandExprStack.size() - 2];
100 assert(lhs.size() == rhs.size());
101 // Update the LHS in place.
102 for (unsigned i = 0; i < rhs.size(); i++) {
103 lhs[i] += rhs[i];
104 }
105 // Pop off the RHS.
106 operandExprStack.pop_back();
107 }
108 void visitModExpr(AffineBinaryOpExpr *expr) {
109 assert(expr->isPureAffine());
110 // This is a pure affine expr; the RHS is a constant.
111 auto rhsConst = operandExprStack.back()[getNumCols() - 1];
112 operandExprStack.pop_back();
113 auto &lhs = operandExprStack.back();
114 assert(rhsConst != 0 && "RHS constant can't be zero");
115 unsigned i;
116 for (i = 0; i < lhs.size(); i++)
117 if (lhs[i] % rhsConst != 0)
118 break;
119 if (i == lhs.size()) {
120 // The modulo expression here simplifies to zero.
121 lhs.assign(lhs.size(), 0);
122 return;
123 }
124 // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
125 // q * expr2) where q is the existential quantifier introduced.
126 addExistentialQuantifier();
127 lhs = operandExprStack.back();
128 lhs[numDims + numSymbols + numLocals - 1] = -rhsConst;
129 }
130 void visitConstantExpr(AffineConstantExpr *expr) {
131 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
132 auto &eq = operandExprStack.back();
133 eq[getNumCols() - 1] = expr->getValue();
134 }
135 void visitDimExpr(AffineDimExpr *expr) {
136 SmallVector<int64_t, 32> eq(getNumCols(), 0);
137 eq[expr->getPosition()] = 1;
138 operandExprStack.push_back(eq);
139 }
140 void visitSymbolExpr(AffineSymbolExpr *expr) {
141 SmallVector<int64_t, 32> eq(getNumCols(), 0);
142 eq[numDims + expr->getPosition()] = 1;
143 operandExprStack.push_back(eq);
144 }
145 void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
146 // TODO(bondhugula): handle ceildiv as well; won't simplify further through
147 // this analysis but will be handled (rest of the expr will simplify).
148 report_fatal_error("ceildiv expr simplification not supported here");
149 }
150 void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
151 // TODO(bondhugula): handle ceildiv as well; won't simplify further through
152 // this analysis but will be handled (rest of the expr will simplify).
153 report_fatal_error("floordiv expr simplification unimplemented");
154 }
155 // Add an existential quantifier (used to flatten a mod or a floordiv expr).
156 void addExistentialQuantifier() {
157 for (auto &subExpr : operandExprStack) {
158 subExpr.insert(subExpr.begin() + numDims + numSymbols + numLocals, 0);
159 }
160 numLocals++;
161 }
162
163 unsigned numDims;
164 unsigned numSymbols;
165 unsigned numLocals;
166};
167
168} // end anonymous namespace
169
170FunctionPass *mlir::createSimplifyAffineExprPass() {
171 return new SimplifyAffineExpr();
172}
173
174AffineMap *MutableAffineMap::getAffineMap() {
175 return AffineMap::get(numDims, numSymbols, results, rangeSizes, context);
176}
177
178void SimplifyAffineExpr::runOnMLFunction(MLFunction *f) {
179 struct MapSimplifier : public StmtWalker<MapSimplifier> {
180 MLIRContext *context;
181 MapSimplifier(MLIRContext *context) : context(context) {}
182
183 void visitOperationStmt(OperationStmt *opStmt) {
184 for (auto attr : opStmt->getAttrs()) {
185 if (auto *mapAttr = dyn_cast<AffineMapAttr>(attr.second)) {
186 MutableAffineMap mMap(mapAttr->getValue(), context);
187 mMap.simplify();
188 auto *map = mMap.getAffineMap();
189 opStmt->setAttr(attr.first, AffineMapAttr::get(map, context));
190 }
191 }
192 }
193 };
194
195 MapSimplifier v(f->getContext());
196 v.walkPostOrder(f);
197}
198
199/// Get an affine expression from a flat ArrayRef. If there are local variables
200/// (existential quantifiers introduced during the flattening) that appear in
201/// the sum of products expression, we can't readily express it as an affine
202/// expression of dimension and symbol id's; return nullptr in such cases.
203static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
204 unsigned numSymbols, MLIRContext *context) {
205 // Check if any local variable has a non-zero coefficient.
206 for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
207 if (eq[j] != 0)
208 return nullptr;
209 }
210
211 AffineExpr *expr = AffineConstantExpr::get(0, context);
212 for (unsigned j = 0; j < numDims + numSymbols; j++) {
213 if (eq[j] != 0) {
214 AffineExpr *id =
215 j < numDims
216 ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
217 : AffineSymbolExpr::get(j - numDims, context);
218 expr = AffineBinaryOpExpr::get(
219 AffineExpr::Kind::Add, expr,
220 AffineBinaryOpExpr::get(AffineExpr::Kind::Mul,
221 AffineConstantExpr::get(eq[j], context), id,
222 context),
223 context);
224 }
225 }
226 unsigned constTerm = eq[eq.size() - 1];
227 if (constTerm != 0)
228 expr = AffineBinaryOpExpr::get(AffineExpr::Kind::Add, expr,
229 AffineConstantExpr::get(constTerm, context),
230 context);
231 return expr;
232}
233
234// Simplify the result affine expressions of this map. The expressions have to
235// be pure for the simplification implemented.
236void MutableAffineMap::simplify() {
237 // Simplify each of the results if possible.
238 for (unsigned i = 0, e = getNumResults(); i < e; i++) {
239 AffineExpr *result = getResult(i);
240 if (!result->isPureAffine())
241 continue;
242
243 AffineExprFlattener flattener(numDims, numSymbols);
244 flattener.walkPostOrder(result);
245 const auto &flattenedExpr = flattener.operandExprStack.back();
246 auto *expr = toAffineExpr(flattenedExpr, numDims, numSymbols, context);
247 if (expr)
248 results[i] = expr;
249 flattener.operandExprStack.pop_back();
250 assert(flattener.operandExprStack.empty());
251 }
252}