blob: bff76a837fdb46b37b3b8bebccc4bb8663359aeb [file] [log] [blame]
Uday Bondhugula257339b2018-08-21 10:32:24 -07001//===- AffineStructures.cpp - MLIR Affine Structures Class-------*- C++ -*-===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// Structures for affine/polyhedral analysis of MLIR functions.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/Analysis/AffineStructures.h"
Uday Bondhugula83a41c92018-08-30 17:35:15 -070023
Uday Bondhugula83a41c92018-08-30 17:35:15 -070024#include "mlir/IR/AffineExprVisitor.h"
Uday Bondhugula257339b2018-08-21 10:32:24 -070025#include "mlir/IR/AffineMap.h"
26#include "mlir/IR/IntegerSet.h"
Uday Bondhugula257339b2018-08-21 10:32:24 -070027#include "mlir/IR/StandardOps.h"
Uday Bondhugula83a41c92018-08-30 17:35:15 -070028#include "llvm/Support/raw_ostream.h"
Uday Bondhugula257339b2018-08-21 10:32:24 -070029
Uday Bondhugula128c7aa2018-09-04 15:55:38 -070030using namespace mlir;
31
32/// Constructs an affine expression from a flat ArrayRef. If there are local
33/// identifiers (neither dimensional nor symbolic) that appear in the sum of
34/// products expression, 'localExprs' is expected to have the AffineExpr for it,
35/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
36/// [dims, symbols, locals, constant term].
37static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
38 unsigned numSymbols,
39 ArrayRef<AffineExpr *> localExprs,
40 MLIRContext *context) {
MLIR Teamd651bc32018-09-05 07:38:20 -070041 // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
42 assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
Uday Bondhugula128c7aa2018-09-04 15:55:38 -070043 "unexpected number of local expressions");
44
45 AffineExpr *expr = AffineConstantExpr::get(0, context);
46 // Dimensions and symbols.
47 for (unsigned j = 0; j < numDims + numSymbols; j++) {
48 if (eq[j] != 0) {
49 AffineExpr *id =
50 j < numDims
51 ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
52 : AffineSymbolExpr::get(j - numDims, context);
53 auto *term = AffineBinaryOpExpr::getMul(
54 AffineConstantExpr::get(eq[j], context), id, context);
55 expr = AffineBinaryOpExpr::getAdd(expr, term, context);
56 }
57 }
58
59 // Local identifiers.
60 for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
61 if (eq[j] != 0) {
62 auto *term = AffineBinaryOpExpr::getMul(
63 AffineConstantExpr::get(eq[j], context),
64 localExprs[j - numDims - numSymbols], context);
65 expr = AffineBinaryOpExpr::getAdd(expr, term, context);
66 }
67 }
68
69 // Constant term.
70 unsigned constTerm = eq[eq.size() - 1];
71 if (constTerm != 0)
72 expr = AffineBinaryOpExpr::getAdd(
73 expr, AffineConstantExpr::get(constTerm, context), context);
74 return expr;
75}
76
77namespace {
78
79// This class is used to flatten a pure affine expression (AffineExpr *, which
80// is in a tree form) into a sum of products (w.r.t constants) when possible,
81// and in that process simplifying the expression. The simplification performed
82// includes the accumulation of contributions for each dimensional and symbolic
83// identifier together, the simplification of floordiv/ceildiv/mod exprssions
84// and other simplifications that in turn happen as a result. A simplification
85// that this flattening naturally performs is of simplifying the numerator and
86// denominator of floordiv/ceildiv, and folding a modulo expression to a zero,
87// if possible. Three examples are below:
88//
89// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
90// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
91// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
92//
93// For a modulo, floordiv, or a ceildiv expression, an additional identifier
94// (called a local identifier) is introduced to rewrite it as a sum of products
95// (w.r.t constants). For example, for the second example above, d0 % 4 is
96// replaced by d0 - 4*q with q being introduced: the expression then simplifies
97// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
98// zero. Note that an affine expression may not always be expressible in a sum
99// of products form due to the presence of modulo/floordiv/ceildiv expressions
100// that may not be eliminated after simplification; in such cases, the final
101// expression can be reconstructed by replacing the local identifier with its
102// explicit form stored in localExprs (note that the explicit form itself would
103// have been simplified and not necessarily the original form).
104//
105// This is a linear time post order walk for an affine expression that attempts
106// the above simplifications through visit methods, with partial results being
107// stored in 'operandExprStack'. When a parent expr is visited, the flattened
108// expressions corresponding to its two operands would already be on the stack -
109// the parent expr looks at the two flattened expressions and combines the two.
110// It pops off the operand expressions and pushes the combined result (although
111// this is done in-place on its LHS operand expr. When the walk is completed,
112// the flattened form of the top-level expression would be left on the stack.
113//
114class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
115public:
116 // Flattend expression layout: [dims, symbols, locals, constant]
117 // Stack that holds the LHS and RHS operands while visiting a binary op expr.
118 // In future, consider adding a prepass to determine how big the SmallVector's
119 // will be, and linearize this to std::vector<int64_t> to prevent
120 // SmallVector moves on re-allocation.
121 std::vector<SmallVector<int64_t, 32>> operandExprStack;
122
123 inline unsigned getNumCols() const {
124 return numDims + numSymbols + numLocals + 1;
125 }
126
127 unsigned numDims;
128 unsigned numSymbols;
129 // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
130 // expressions that could not be simplified.
131 unsigned numLocals;
132 // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
133 // which new identifiers were introduced; if the latter do not get canceled
134 // out, these expressions are needed to reconstruct the AffineExpr * / tree
135 // form. Note that these expressions themselves would have been simplified
136 // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
137 // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
138 // would be the local expression stored for q.
139 SmallVector<AffineExpr *, 4> localExprs;
140 MLIRContext *context;
141
142 AffineExprFlattener(unsigned numDims, unsigned numSymbols,
143 MLIRContext *context)
144 : numDims(numDims), numSymbols(numSymbols), numLocals(0),
145 context(context) {
146 operandExprStack.reserve(8);
147 }
148
149 void visitMulExpr(AffineBinaryOpExpr *expr) {
150 assert(operandExprStack.size() >= 2);
151 // This is a pure affine expr; the RHS will be a constant.
152 assert(isa<AffineConstantExpr>(expr->getRHS()));
153 // Get the RHS constant.
154 auto rhsConst = operandExprStack.back()[getConstantIndex()];
155 operandExprStack.pop_back();
156 // Update the LHS in place instead of pop and push.
157 auto &lhs = operandExprStack.back();
158 for (unsigned i = 0, e = lhs.size(); i < e; i++) {
159 lhs[i] *= rhsConst;
160 }
161 }
162
163 void visitAddExpr(AffineBinaryOpExpr *expr) {
164 assert(operandExprStack.size() >= 2);
165 const auto &rhs = operandExprStack.back();
166 auto &lhs = operandExprStack[operandExprStack.size() - 2];
167 assert(lhs.size() == rhs.size());
168 // Update the LHS in place.
169 for (unsigned i = 0; i < rhs.size(); i++) {
170 lhs[i] += rhs[i];
171 }
172 // Pop off the RHS.
173 operandExprStack.pop_back();
174 }
175
176 void visitModExpr(AffineBinaryOpExpr *expr) {
177 assert(operandExprStack.size() >= 2);
178 // This is a pure affine expr; the RHS will be a constant.
179 assert(isa<AffineConstantExpr>(expr->getRHS()));
180 auto rhsConst = operandExprStack.back()[getConstantIndex()];
181 operandExprStack.pop_back();
182 auto &lhs = operandExprStack.back();
183 // TODO(bondhugula): handle modulo by zero case when this issue is fixed
184 // at the other places in the IR.
185 assert(rhsConst != 0 && "RHS constant can't be zero");
186
187 // Check if the LHS expression is a multiple of modulo factor.
188 unsigned i;
189 for (i = 0; i < lhs.size(); i++)
190 if (lhs[i] % rhsConst != 0)
191 break;
192 // If yes, modulo expression here simplifies to zero.
193 if (i == lhs.size()) {
194 lhs.assign(lhs.size(), 0);
195 return;
196 }
197
198 // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
199 // q * expr2) where q is the existential quantifier introduced.
200 addLocalId(AffineBinaryOpExpr::get(
201 AffineExpr::Kind::FloorDiv,
202 toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
203 AffineConstantExpr::get(rhsConst, context), context));
204 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
205 }
206 void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
207 visitDivExpr(expr, /*isCeil=*/true);
208 }
209 void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
210 visitDivExpr(expr, /*isCeil=*/false);
211 }
212 void visitDimExpr(AffineDimExpr *expr) {
213 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
214 auto &eq = operandExprStack.back();
215 eq[getDimStartIndex() + expr->getPosition()] = 1;
216 }
217 void visitSymbolExpr(AffineSymbolExpr *expr) {
218 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
219 auto &eq = operandExprStack.back();
220 eq[getSymbolStartIndex() + expr->getPosition()] = 1;
221 }
222 void visitConstantExpr(AffineConstantExpr *expr) {
223 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
224 auto &eq = operandExprStack.back();
225 eq[getConstantIndex()] = expr->getValue();
226 }
227
228private:
229 void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
230 assert(operandExprStack.size() >= 2);
231 assert(isa<AffineConstantExpr>(expr->getRHS()));
232 // This is a pure affine expr; the RHS is a positive constant.
233 auto rhsConst = operandExprStack.back()[getConstantIndex()];
234 // TODO(bondhugula): handle division by zero at the same time the issue is
235 // fixed at other places.
236 assert(rhsConst != 0 && "RHS constant can't be zero");
237 operandExprStack.pop_back();
238 auto &lhs = operandExprStack.back();
239
240 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
241 // common divisors of the numerator and denominator.
242 uint64_t gcd = std::abs(rhsConst);
243 for (unsigned i = 0; i < lhs.size(); i++)
244 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
245 // Simplify the numerator and the denominator.
246 if (gcd != 1) {
247 for (unsigned i = 0; i < lhs.size(); i++)
248 lhs[i] = lhs[i] / gcd;
249 }
250 int64_t denominator = rhsConst / gcd;
251 // If the denominator becomes 1, the updated LHS is the result. (The
252 // denominator can't be negative since rhsConst is positive).
253 if (denominator == 1)
254 return;
255
256 // If the denominator cannot be simplified to one, we will have to retain
257 // the ceil/floor expr (simplified up until here). Add an existential
258 // quantifier to express its result, i.e., expr1 div expr2 is replaced
259 // by a new identifier, q.
260 auto divKind =
261 isCeil ? AffineExpr::Kind::CeilDiv : AffineExpr::Kind::FloorDiv;
262 addLocalId(AffineBinaryOpExpr::get(
263 divKind, toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
264 AffineConstantExpr::get(denominator, context), context));
265 lhs.assign(lhs.size(), 0);
266 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
267 }
268
269 // Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
270 // expr). localExpr is the simplified tree expression (AffineExpr *)
271 // corresponding to the quantifier.
272 void addLocalId(AffineExpr *localExpr) {
273 for (auto &subExpr : operandExprStack) {
274 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
275 }
276 localExprs.push_back(localExpr);
277 numLocals++;
278 }
279
280 inline unsigned getConstantIndex() const { return getNumCols() - 1; }
281 inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
282 inline unsigned getSymbolStartIndex() const { return numDims; }
283 inline unsigned getDimStartIndex() const { return 0; }
284};
285
286} // end anonymous namespace
287
288AffineExpr *mlir::simplifyAffineExpr(AffineExpr *expr, unsigned numDims,
289 unsigned numSymbols,
290 MLIRContext *context) {
291 // TODO(bondhugula): only pure affine for now. The simplification here can be
292 // extended to semi-affine maps as well.
293 if (!expr->isPureAffine())
294 return nullptr;
295
296 AffineExprFlattener flattener(numDims, numSymbols, context);
297 flattener.walkPostOrder(expr);
298 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
299 auto *simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
300 flattener.localExprs, context);
301 flattener.operandExprStack.pop_back();
302 assert(flattener.operandExprStack.empty());
303 if (simplifiedExpr == expr)
304 return nullptr;
305 return simplifiedExpr;
306}
Uday Bondhugula257339b2018-08-21 10:32:24 -0700307
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700308MutableAffineMap::MutableAffineMap(AffineMap *map, MLIRContext *context)
309 : numDims(map->getNumDims()), numSymbols(map->getNumSymbols()),
310 context(context) {
Uday Bondhugula257339b2018-08-21 10:32:24 -0700311 for (auto *result : map->getResults())
312 results.push_back(result);
313 for (auto *rangeSize : map->getRangeSizes())
314 results.push_back(rangeSize);
315}
316
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700317bool MutableAffineMap::isMultipleOf(unsigned idx, int64_t factor) const {
318 if (results[idx]->isMultipleOf(factor))
319 return true;
Uday Bondhugula257339b2018-08-21 10:32:24 -0700320
Uday Bondhugula128c7aa2018-09-04 15:55:38 -0700321 // TODO(bondhugula): use simplifyAffineExpr and FlatAffineConstraints to
322 // complete this (for a more powerful analysis).
Uday Bondhugulab553adb2018-08-25 17:17:56 -0700323 return false;
Uday Bondhugula257339b2018-08-21 10:32:24 -0700324}
325
Uday Bondhugula128c7aa2018-09-04 15:55:38 -0700326// Simplifies the result affine expressions of this map. The expressions have to
327// be pure for the simplification implemented.
328void MutableAffineMap::simplify() {
329 // Simplify each of the results if possible.
330 for (unsigned i = 0, e = getNumResults(); i < e; i++) {
331 AffineExpr *sExpr =
332 simplifyAffineExpr(getResult(i), numDims, numSymbols, context);
333 if (sExpr)
334 results[i] = sExpr;
335 }
336}
337
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700338MutableIntegerSet::MutableIntegerSet(IntegerSet *set, MLIRContext *context)
339 : numDims(set->getNumDims()), numSymbols(set->getNumSymbols()),
340 context(context) {
341 // TODO(bondhugula)
342}
343
344// Universal set.
345MutableIntegerSet::MutableIntegerSet(unsigned numDims, unsigned numSymbols,
346 MLIRContext *context)
347 : numDims(numDims), numSymbols(numSymbols), context(context) {}
348
349AffineValueMap::AffineValueMap(const AffineApplyOp &op, MLIRContext *context)
350 : map(op.getAffineMap(), context) {
351 // TODO: pull operands and results in.
352}
353
354inline bool AffineValueMap::isMultipleOf(unsigned idx, int64_t factor) const {
355 return map.isMultipleOf(idx, factor);
356}
357
Uday Bondhugula257339b2018-08-21 10:32:24 -0700358AffineValueMap::~AffineValueMap() {}
359
Uday Bondhugula83a41c92018-08-30 17:35:15 -0700360void FlatAffineConstraints::addEquality(ArrayRef<int64_t> eq) {
361 assert(eq.size() == getNumCols());
362 unsigned offset = equalities.size();
363 equalities.resize(equalities.size() + eq.size());
364 for (unsigned i = 0, e = eq.size(); i < e; i++) {
365 equalities[offset + i] = eq[i];
366 }
367}