blob: 629a5eb4d92a715249de4cbb57a0858ae05ea518 [file] [log] [blame]
//===- AffineAnalysis.cpp - Affine structures analysis routines -----------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous analysis routines for affine structures
// (expressions, maps, sets), and other utilities relying on such analysis.
//
//===----------------------------------------------------------------------===//
22
23#include "mlir/Analysis/AffineAnalysis.h"
24#include "mlir/IR/AffineExprVisitor.h"
25#include "llvm/ADT/ArrayRef.h"
26
27using namespace mlir;
28
29/// Constructs an affine expression from a flat ArrayRef. If there are local
30/// identifiers (neither dimensional nor symbolic) that appear in the sum of
31/// products expression, 'localExprs' is expected to have the AffineExpr for it,
32/// and is substituted into. The ArrayRef 'eq' is expected to be in the format
33/// [dims, symbols, locals, constant term].
34static AffineExpr *toAffineExpr(ArrayRef<int64_t> eq, unsigned numDims,
35 unsigned numSymbols,
36 ArrayRef<AffineExpr *> localExprs,
37 MLIRContext *context) {
38 // Assert expected numLocals = eq.size() - numDims - numSymbols - 1
39 assert(eq.size() - numDims - numSymbols - 1 == localExprs.size() &&
40 "unexpected number of local expressions");
41
42 AffineExpr *expr = AffineConstantExpr::get(0, context);
43 // Dimensions and symbols.
44 for (unsigned j = 0; j < numDims + numSymbols; j++) {
45 if (eq[j] != 0) {
46 AffineExpr *id =
47 j < numDims
48 ? static_cast<AffineExpr *>(AffineDimExpr::get(j, context))
49 : AffineSymbolExpr::get(j - numDims, context);
50 auto *term = AffineBinaryOpExpr::getMul(
51 AffineConstantExpr::get(eq[j], context), id, context);
52 expr = AffineBinaryOpExpr::getAdd(expr, term, context);
53 }
54 }
55
56 // Local identifiers.
57 for (unsigned j = numDims + numSymbols; j < eq.size() - 1; j++) {
58 if (eq[j] != 0) {
59 auto *term = AffineBinaryOpExpr::getMul(
60 AffineConstantExpr::get(eq[j], context),
61 localExprs[j - numDims - numSymbols], context);
62 expr = AffineBinaryOpExpr::getAdd(expr, term, context);
63 }
64 }
65
66 // Constant term.
67 unsigned constTerm = eq[eq.size() - 1];
68 if (constTerm != 0)
69 expr = AffineBinaryOpExpr::getAdd(
70 expr, AffineConstantExpr::get(constTerm, context), context);
71 return expr;
72}
73
74namespace {
75
76// This class is used to flatten a pure affine expression (AffineExpr *, which
77// is in a tree form) into a sum of products (w.r.t constants) when possible,
78// and in that process simplifying the expression. The simplification performed
79// includes the accumulation of contributions for each dimensional and symbolic
80// identifier together, the simplification of floordiv/ceildiv/mod exprssions
81// and other simplifications that in turn happen as a result. A simplification
82// that this flattening naturally performs is of simplifying the numerator and
83// denominator of floordiv/ceildiv, and folding a modulo expression to a zero,
84// if possible. Three examples are below:
85//
86// (d0 + 3 * d1) + d0) - 2 * d1) - d0 simplified to d0 + d1
87// (d0 - d0 mod 4 + 4) mod 4 simplified to 0.
88// (3*d0 + 2*d1 + d0) floordiv 2 + d1 simplified to 2*d0 + 2*d1
89//
90// For a modulo, floordiv, or a ceildiv expression, an additional identifier
91// (called a local identifier) is introduced to rewrite it as a sum of products
92// (w.r.t constants). For example, for the second example above, d0 % 4 is
93// replaced by d0 - 4*q with q being introduced: the expression then simplifies
94// to: (d0 - (d0 - 4q) + 4) = 4q + 4, modulo of which w.r.t 4 simplifies to
95// zero. Note that an affine expression may not always be expressible in a sum
96// of products form due to the presence of modulo/floordiv/ceildiv expressions
97// that may not be eliminated after simplification; in such cases, the final
98// expression can be reconstructed by replacing the local identifier with its
99// explicit form stored in localExprs (note that the explicit form itself would
100// have been simplified and not necessarily the original form).
101//
102// This is a linear time post order walk for an affine expression that attempts
103// the above simplifications through visit methods, with partial results being
104// stored in 'operandExprStack'. When a parent expr is visited, the flattened
105// expressions corresponding to its two operands would already be on the stack -
106// the parent expr looks at the two flattened expressions and combines the two.
107// It pops off the operand expressions and pushes the combined result (although
108// this is done in-place on its LHS operand expr. When the walk is completed,
109// the flattened form of the top-level expression would be left on the stack.
110//
111class AffineExprFlattener : public AffineExprVisitor<AffineExprFlattener> {
112public:
113 // Flattend expression layout: [dims, symbols, locals, constant]
114 // Stack that holds the LHS and RHS operands while visiting a binary op expr.
115 // In future, consider adding a prepass to determine how big the SmallVector's
116 // will be, and linearize this to std::vector<int64_t> to prevent
117 // SmallVector moves on re-allocation.
118 std::vector<SmallVector<int64_t, 32>> operandExprStack;
119
120 inline unsigned getNumCols() const {
121 return numDims + numSymbols + numLocals + 1;
122 }
123
124 unsigned numDims;
125 unsigned numSymbols;
126 // Number of newly introduced identifiers to flatten mod/floordiv/ceildiv
127 // expressions that could not be simplified.
128 unsigned numLocals;
129 // AffineExpr's corresponding to the floordiv/ceildiv/mod expressions for
130 // which new identifiers were introduced; if the latter do not get canceled
131 // out, these expressions are needed to reconstruct the AffineExpr * / tree
132 // form. Note that these expressions themselves would have been simplified
133 // (recursively) by this pass. Eg. d0 + (d0 + 2*d1 + d0) ceildiv 4 will be
134 // simplified to d0 + q, where q = (d0 + d1) ceildiv 2. (d0 + d1) ceildiv 2
135 // would be the local expression stored for q.
136 SmallVector<AffineExpr *, 4> localExprs;
137 MLIRContext *context;
138
139 AffineExprFlattener(unsigned numDims, unsigned numSymbols,
140 MLIRContext *context)
141 : numDims(numDims), numSymbols(numSymbols), numLocals(0),
142 context(context) {
143 operandExprStack.reserve(8);
144 }
145
146 void visitMulExpr(AffineBinaryOpExpr *expr) {
147 assert(operandExprStack.size() >= 2);
148 // This is a pure affine expr; the RHS will be a constant.
149 assert(isa<AffineConstantExpr>(expr->getRHS()));
150 // Get the RHS constant.
151 auto rhsConst = operandExprStack.back()[getConstantIndex()];
152 operandExprStack.pop_back();
153 // Update the LHS in place instead of pop and push.
154 auto &lhs = operandExprStack.back();
155 for (unsigned i = 0, e = lhs.size(); i < e; i++) {
156 lhs[i] *= rhsConst;
157 }
158 }
159
160 void visitAddExpr(AffineBinaryOpExpr *expr) {
161 assert(operandExprStack.size() >= 2);
162 const auto &rhs = operandExprStack.back();
163 auto &lhs = operandExprStack[operandExprStack.size() - 2];
164 assert(lhs.size() == rhs.size());
165 // Update the LHS in place.
166 for (unsigned i = 0; i < rhs.size(); i++) {
167 lhs[i] += rhs[i];
168 }
169 // Pop off the RHS.
170 operandExprStack.pop_back();
171 }
172
173 void visitModExpr(AffineBinaryOpExpr *expr) {
174 assert(operandExprStack.size() >= 2);
175 // This is a pure affine expr; the RHS will be a constant.
176 assert(isa<AffineConstantExpr>(expr->getRHS()));
177 auto rhsConst = operandExprStack.back()[getConstantIndex()];
178 operandExprStack.pop_back();
179 auto &lhs = operandExprStack.back();
180 // TODO(bondhugula): handle modulo by zero case when this issue is fixed
181 // at the other places in the IR.
182 assert(rhsConst != 0 && "RHS constant can't be zero");
183
184 // Check if the LHS expression is a multiple of modulo factor.
185 unsigned i;
186 for (i = 0; i < lhs.size(); i++)
187 if (lhs[i] % rhsConst != 0)
188 break;
189 // If yes, modulo expression here simplifies to zero.
190 if (i == lhs.size()) {
191 lhs.assign(lhs.size(), 0);
192 return;
193 }
194
195 // Add an existential quantifier. expr1 % expr2 is replaced by (expr1 -
196 // q * expr2) where q is the existential quantifier introduced.
197 addLocalId(AffineBinaryOpExpr::getFloorDiv(
198 toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
199 AffineConstantExpr::get(rhsConst, context), context));
200 lhs[getLocalVarStartIndex() + numLocals - 1] = -rhsConst;
201 }
202 void visitCeilDivExpr(AffineBinaryOpExpr *expr) {
203 visitDivExpr(expr, /*isCeil=*/true);
204 }
205 void visitFloorDivExpr(AffineBinaryOpExpr *expr) {
206 visitDivExpr(expr, /*isCeil=*/false);
207 }
208 void visitDimExpr(AffineDimExpr *expr) {
209 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
210 auto &eq = operandExprStack.back();
211 eq[getDimStartIndex() + expr->getPosition()] = 1;
212 }
213 void visitSymbolExpr(AffineSymbolExpr *expr) {
214 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
215 auto &eq = operandExprStack.back();
216 eq[getSymbolStartIndex() + expr->getPosition()] = 1;
217 }
218 void visitConstantExpr(AffineConstantExpr *expr) {
219 operandExprStack.emplace_back(SmallVector<int64_t, 32>(getNumCols(), 0));
220 auto &eq = operandExprStack.back();
221 eq[getConstantIndex()] = expr->getValue();
222 }
223
224private:
225 void visitDivExpr(AffineBinaryOpExpr *expr, bool isCeil) {
226 assert(operandExprStack.size() >= 2);
227 assert(isa<AffineConstantExpr>(expr->getRHS()));
228 // This is a pure affine expr; the RHS is a positive constant.
229 auto rhsConst = operandExprStack.back()[getConstantIndex()];
230 // TODO(bondhugula): handle division by zero at the same time the issue is
231 // fixed at other places.
232 assert(rhsConst != 0 && "RHS constant can't be zero");
233 operandExprStack.pop_back();
234 auto &lhs = operandExprStack.back();
235
236 // Simplify the floordiv, ceildiv if possible by canceling out the greatest
237 // common divisors of the numerator and denominator.
238 uint64_t gcd = std::abs(rhsConst);
239 for (unsigned i = 0; i < lhs.size(); i++)
240 gcd = llvm::GreatestCommonDivisor64(gcd, std::abs(lhs[i]));
241 // Simplify the numerator and the denominator.
242 if (gcd != 1) {
243 for (unsigned i = 0; i < lhs.size(); i++)
244 lhs[i] = lhs[i] / gcd;
245 }
246 int64_t denominator = rhsConst / gcd;
247 // If the denominator becomes 1, the updated LHS is the result. (The
248 // denominator can't be negative since rhsConst is positive).
249 if (denominator == 1)
250 return;
251
252 // If the denominator cannot be simplified to one, we will have to retain
253 // the ceil/floor expr (simplified up until here). Add an existential
254 // quantifier to express its result, i.e., expr1 div expr2 is replaced
255 // by a new identifier, q.
256 auto divKind =
257 isCeil ? AffineExpr::Kind::CeilDiv : AffineExpr::Kind::FloorDiv;
258 addLocalId(AffineBinaryOpExpr::get(
259 divKind, toAffineExpr(lhs, numDims, numSymbols, localExprs, context),
260 AffineConstantExpr::get(denominator, context), context));
261 lhs.assign(lhs.size(), 0);
262 lhs[getLocalVarStartIndex() + numLocals - 1] = 1;
263 }
264
265 // Add an existential quantifier (used to flatten a mod, floordiv, ceildiv
266 // expr). localExpr is the simplified tree expression (AffineExpr *)
267 // corresponding to the quantifier.
268 void addLocalId(AffineExpr *localExpr) {
269 for (auto &subExpr : operandExprStack) {
270 subExpr.insert(subExpr.begin() + getLocalVarStartIndex() + numLocals, 0);
271 }
272 localExprs.push_back(localExpr);
273 numLocals++;
274 }
275
276 inline unsigned getConstantIndex() const { return getNumCols() - 1; }
277 inline unsigned getLocalVarStartIndex() const { return numDims + numSymbols; }
278 inline unsigned getSymbolStartIndex() const { return numDims; }
279 inline unsigned getDimStartIndex() const { return 0; }
280};
281
282} // end anonymous namespace
283
284AffineExpr *mlir::simplifyAffineExpr(AffineExpr *expr, unsigned numDims,
285 unsigned numSymbols,
286 MLIRContext *context) {
287 // TODO(bondhugula): only pure affine for now. The simplification here can be
288 // extended to semi-affine maps in the future.
289 if (!expr->isPureAffine())
290 return nullptr;
291
292 AffineExprFlattener flattener(numDims, numSymbols, context);
293 flattener.walkPostOrder(expr);
294 ArrayRef<int64_t> flattenedExpr = flattener.operandExprStack.back();
295 auto *simplifiedExpr = toAffineExpr(flattenedExpr, numDims, numSymbols,
296 flattener.localExprs, context);
297 flattener.operandExprStack.pop_back();
298 assert(flattener.operandExprStack.empty());
299 if (simplifiedExpr == expr)
300 return nullptr;
301 return simplifiedExpr;
302}