//===- LoopAnalysis.cpp - Misc loop analysis routines --------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous loop analysis routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/LoopAnalysis.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/Analysis/VectorAnalysis.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/SuperVectorOps/SuperVectorOps.h"
#include "mlir/Support/Functional.h"
#include "mlir/Support/MathExtras.h"

#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/SmallString.h"
#include <limits>
#include <type_traits>

using namespace mlir;

/// Returns the trip count of the loop as an affine expression if the latter is
/// expressible as an affine expression, and nullptr otherwise. The trip count
/// expression is simplified before returning.
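///
/// For example (illustrative), a loop with constant bounds such as:
///
///   for %i = 0 to 128 step 4 { ... }
///
/// has the constant trip count expression (128 - 0) ceildiv 4 = 32.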
AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
  // upper_bound - lower_bound
  int64_t loopSpan;

  int64_t step = forStmt.getStep();
  auto *context = forStmt.getContext();

  if (forStmt.hasConstantBounds()) {
    int64_t lb = forStmt.getConstantLowerBound();
    int64_t ub = forStmt.getConstantUpperBound();
    loopSpan = ub - lb;
  } else {
    auto lbMap = forStmt.getLowerBoundMap();
    auto ubMap = forStmt.getUpperBoundMap();
    // TODO(bondhugula): handle max/min of multiple expressions.
    if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
      return nullptr;

    // TODO(bondhugula): handle bounds with different operands.
    // Bounds have different operands, unhandled for now.
    if (!forStmt.matchingBoundOperandList())
      return nullptr;

    // ub_expr - lb_expr
    AffineExpr lbExpr(lbMap.getResult(0));
    AffineExpr ubExpr(ubMap.getResult(0));
    auto loopSpanExpr = simplifyAffineExpr(
        ubExpr - lbExpr, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
        std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
    if (!cExpr)
      return loopSpanExpr.ceilDiv(step);
    loopSpan = cExpr.getValue();
  }

  // 0 iteration loops: the trip count is the constant 0. (Returning a null
  // AffineExpr here would instead mean "not expressible".)
  if (loopSpan < 0)
    return getAffineConstantExpr(0, context);

  return getAffineConstantExpr(static_cast<uint64_t>(ceilDiv(loopSpan, step)),
                               context);
}

/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCountExpr) and
/// is able to determine constant trip counts in non-trivial cases.
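///
/// For example (illustrative), a loop whose bounds share the same operand:
///
///   for %i = (d0) -> (d0)(%N) to (d0) -> (d0 + 8)(%N) { ... }
///
/// has span (%N + 8) - %N = 8, so the constant trip count 8 is returned even
/// though neither bound is constant.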
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
  auto tripCountExpr = getTripCountExpr(forStmt);

  if (!tripCountExpr)
    return None;

  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>())
    return constExpr.getValue();

  return None;
}

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCountExpr), and
/// this method is thus able to determine non-trivial divisors.
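///
/// For example (illustrative), a loop whose trip count simplifies to the
/// expression s0 * 4 has a largest known divisor of 4 for any value of the
/// symbol s0.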
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
  auto tripCountExpr = getTripCountExpr(forStmt);

  if (!tripCountExpr)
    return 1;

  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>()) {
    uint64_t tripCount = constExpr.getValue();

    // 0 iteration loops (greatest divisor is 2^64 - 1). Note: ULONG_MAX is
    // only 2^64 - 1 on LP64 targets, so use the uint64_t limit explicitly.
    if (tripCount == 0)
      return std::numeric_limits<uint64_t>::max();

    // The greatest divisor is the trip count.
    return tripCount;
  }

  // Trip count is not a known constant; return its largest known divisor.
  return tripCountExpr.getLargestKnownDivisor();
}
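/// Returns whether `index` is invariant with respect to the induction variable
/// `iv`, i.e. whether `index` is not a function of `iv`. Expects at most one
/// AffineApplyOp reachable from `index` (CompositionAffineMapsPass must have
/// been run) and conservatively returns false, emitting an error, otherwise.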
bool mlir::isAccessInvariant(const MLValue &iv, const MLValue &index) {
  assert(isa<ForStmt>(iv) && "iv must be a ForStmt");
  assert(index.getType().isa<IndexType>() && "index must be of IndexType");
  SmallVector<OperationStmt *, 4> affineApplyOps;
  getReachableAffineApplyOps({const_cast<MLValue *>(&index)}, affineApplyOps);

  if (affineApplyOps.empty()) {
    // Pointer equality test because of MLValue pointer semantics.
    return &index != &iv;
  }

  if (affineApplyOps.size() > 1) {
    affineApplyOps[0]->emitError(
        "CompositionAffineMapsPass must have been run: there should be at most "
        "one AffineApplyOp");
    return false;
  }

  auto composeOp = affineApplyOps[0]->cast<AffineApplyOp>();
  // We need yet another level of indirection because the `dim` index of the
  // access may not correspond to the `dim` index of composeOp.
  unsigned idx = std::numeric_limits<unsigned>::max();
  unsigned numResults = composeOp->getNumResults();
  for (unsigned i = 0; i < numResults; ++i) {
    if (&index == composeOp->getResult(i)) {
      idx = i;
      break;
    }
  }
  assert(idx < std::numeric_limits<unsigned>::max());
  return !AffineValueMap(*composeOp)
              .isFunctionOf(idx, &const_cast<MLValue &>(iv));
}
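/// Returns the subset of `indices` that are invariant with respect to `iv`, as
/// a set of pointers (MLValue has pointer semantics).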
llvm::DenseSet<const MLValue *>
mlir::getInvariantAccesses(const MLValue &iv,
                           llvm::ArrayRef<const MLValue *> indices) {
  llvm::DenseSet<const MLValue *> res;
  for (unsigned idx = 0, n = indices.size(); idx < n; ++idx) {
    auto *val = indices[idx];
    if (isAccessInvariant(iv, *val)) {
      res.insert(val);
    }
  }
  return res;
}

/// Given:
///   1. an induction variable `iv` of type ForStmt;
///   2. a `memoryOp` of type const LoadOp& or const StoreOp&;
///   3. the index of the `fastestVaryingDim` along which to check;
/// determines whether `memoryOp`[`fastestVaryingDim`] is a contiguous access
/// along `iv`.
/// Contiguous is defined as either invariant or varying only along
/// `fastestVaryingDim`.
///
/// Prerequisites:
///   1. `iv` is of the proper type;
///   2. the MemRef accessed by `memoryOp` has no layout map or at most an
///      identity layout map.
///
/// Currently only a missing or identity layout map is supported in the MemRef;
/// this conservatively returns false if the MemRef has a non-identity layout
/// map or more than one layout map.
///
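/// For example (illustrative), in a nest over %i and %j, an access such as
/// load %A[%i, %j] is contiguous along %j with `fastestVaryingDim` == 0: the
/// last (fastest varying) index varies with %j and every remaining index is
/// invariant in %j.
///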
// TODO(ntv): check strides.
template <typename LoadOrStoreOp>
static bool isContiguousAccess(const MLValue &iv, const LoadOrStoreOp &memoryOp,
                               unsigned fastestVaryingDim) {
  static_assert(std::is_same<LoadOrStoreOp, LoadOp>::value ||
                    std::is_same<LoadOrStoreOp, StoreOp>::value,
                "Must be called on either const LoadOp & or const StoreOp &");
  auto memRefType = memoryOp.getMemRefType();
  auto layoutMap = memRefType.getAffineMaps();
  // TODO(ntv): remove dependence on Builder once we support non-identity
  // layout map.
  Builder b(memoryOp.getOperation()->getContext());
  if (layoutMap.size() >= 2 ||
      (layoutMap.size() == 1 &&
       !(layoutMap[0] ==
         b.getMultiDimIdentityMap(layoutMap[0].getNumDims())))) {
    return memoryOp.emitError("NYI: non-trivial layoutMap"), false;
  }
  assert(fastestVaryingDim < memRefType.getRank());

  auto indices = memoryOp.getIndices();
  // TODO(clattner): should iterator_range have a size method?
  auto numIndices = indices.end() - indices.begin();
  unsigned d = 0;
  for (auto index : indices) {
    if (fastestVaryingDim == (numIndices - 1) - d++) {
      continue;
    }
    if (!isAccessInvariant(iv, cast<MLValue>(*index))) {
      return false;
    }
  }
  return true;
}
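/// Returns whether the MemRef accessed by `memoryOp` has a vector element
/// type, in which case the access already produces or consumes vectors.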
template <typename LoadOrStoreOpPointer>
static bool isVectorElement(LoadOrStoreOpPointer memoryOp) {
  auto memRefType = memoryOp->getMemRefType();
  return memRefType.getElementType().template isa<VectorType>();
}
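/// Returns whether `stmt` is a VectorTransferReadOp or a
/// VectorTransferWriteOp operation statement.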
static bool isVectorTransferReadOrWrite(const Statement &stmt) {
  const auto *opStmt = cast<OperationStmt>(&stmt);
  return opStmt->isa<VectorTransferReadOp>() ||
         opStmt->isa<VectorTransferWriteOp>();
}

using VectorizableStmtFun =
    std::function<bool(const ForStmt &, const OperationStmt &)>;

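/// Returns whether `loop` is vectorizable under the additional per-statement
/// predicate `isVectorizableStmt`. A loop qualifies if it is parallel or a
/// reduction, contains no conditionals and no pre-existing vector transfer
/// operations, and every load/store in it is scalar and satisfies the
/// predicate.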
static bool isVectorizableLoopWithCond(const ForStmt &loop,
                                       VectorizableStmtFun isVectorizableStmt) {
  if (!matcher::isParallelLoop(loop) && !matcher::isReductionLoop(loop)) {
    return false;
  }

  // No vectorization across conditionals for now.
  auto conditionals = matcher::If();
  auto *forStmt = const_cast<ForStmt *>(&loop);
  auto conditionalsMatched = conditionals.match(forStmt);
  if (!conditionalsMatched.empty()) {
    return false;
  }

  // No vectorization of loops that already contain vector transfers.
  auto vectorTransfers = matcher::Op(isVectorTransferReadOrWrite);
  auto vectorTransfersMatched = vectorTransfers.match(forStmt);
  if (!vectorTransfersMatched.empty()) {
    return false;
  }

  auto loadAndStores = matcher::Op(matcher::isLoadOrStore);
  auto loadAndStoresMatched = loadAndStores.match(forStmt);
  for (auto ls : loadAndStoresMatched) {
    auto *op = cast<OperationStmt>(ls.first);
    auto load = op->dyn_cast<LoadOp>();
    auto store = op->dyn_cast<StoreOp>();
    // Only scalar types are considered vectorizable; all loads/stores must be
    // vectorizable for a loop to qualify as vectorizable.
    // TODO(ntv): ponder whether we want to be more general here.
    bool vector = load ? isVectorElement(load) : isVectorElement(store);
    if (vector) {
      return false;
    }
    if (!isVectorizableStmt(loop, *op)) {
      return false;
    }
  }
  return true;
}
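/// Returns whether `loop` is vectorizable with the additional constraint that
/// every load and store is a contiguous access along `fastestVaryingDim` of
/// its MemRef.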
bool mlir::isVectorizableLoopAlongFastestVaryingMemRefDim(
    const ForStmt &loop, unsigned fastestVaryingDim) {
  VectorizableStmtFun fun(
      [fastestVaryingDim](const ForStmt &loop, const OperationStmt &op) {
        auto load = op.dyn_cast<LoadOp>();
        auto store = op.dyn_cast<StoreOp>();
        return load ? isContiguousAccess(loop, *load, fastestVaryingDim)
                    : isContiguousAccess(loop, *store, fastestVaryingDim);
      });
  return isVectorizableLoopWithCond(loop, fun);
}
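/// Returns whether `loop` is vectorizable with no constraint on memory access
/// contiguity: the per-statement predicate accepts every load/store.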
bool mlir::isVectorizableLoop(const ForStmt &loop) {
  VectorizableStmtFun fun(
      // TODO: implement me
      [](const ForStmt &loop, const OperationStmt &op) { return true; });
  return isVectorizableLoopWithCond(loop, fun);
}

/// Checks whether SSA dominance would be violated if a for stmt's body
/// statements are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence
// violation when we have the support.
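// For example (illustrative): if the statement at position 0 defines a value
// used by the statement at position 1, shifts {2, 2} (any equal pair) are
// valid, while unequal shifts such as {0, 1} are rejected.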
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
                                ArrayRef<uint64_t> shifts) {
  assert(shifts.size() == forStmt.getStatements().size());
  unsigned s = 0;
  for (const auto &stmt : forStmt) {
    // A for or if stmt does not produce any def/results (that are used
    // outside).
    if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
        const MLValue *result = opStmt->getResult(i);
        for (const StmtOperand &use : result->getUses()) {
          // If an ancestor statement doesn't lie in the block of forStmt,
          // there is no shift to check.
          // This is a naive way. If performance becomes an issue, a map can
          // be used to store 'shifts' - to look up the shift for a statement
          // in constant time.
          if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
            if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
              return false;
        }
      }
    }
    s++;
  }
  return true;
}