//===- LoopAnalysis.cpp - Misc loop analysis routines ----------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements miscellaneous loop analysis routines.
//
//===----------------------------------------------------------------------===//

#include "mlir/Analysis/LoopAnalysis.h"

#include "mlir/Analysis/AffineAnalysis.h"
#include "mlir/Analysis/AffineStructures.h"
#include "mlir/Analysis/MLFunctionMatcher.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/Statements.h"
#include "mlir/StandardOps/StandardOps.h"
#include "mlir/Support/MathExtras.h"

using namespace mlir;

/// Returns the trip count of the loop as an affine expression if it can be
/// expressed as one, and nullptr otherwise. The trip count expression is
/// simplified before returning.
AffineExpr mlir::getTripCountExpr(const ForStmt &forStmt) {
  // The trip count is ceildiv(upper_bound - lower_bound + 1, step); bounds are
  // inclusive.
  int64_t loopSpan;

  int64_t step = forStmt.getStep();
  auto *context = forStmt.getContext();

  if (forStmt.hasConstantBounds()) {
    int64_t lb = forStmt.getConstantLowerBound();
    int64_t ub = forStmt.getConstantUpperBound();
    loopSpan = ub - lb + 1;
  } else {
    auto lbMap = forStmt.getLowerBoundMap();
    auto ubMap = forStmt.getUpperBoundMap();
    // TODO(bondhugula): handle max/min of multiple expressions.
    if (lbMap.getNumResults() != 1 || ubMap.getNumResults() != 1)
      return nullptr;

    // TODO(bondhugula): handle bounds with different operands.
    // Bounds with different operands are unhandled for now.
    if (!forStmt.matchingBoundOperandList())
      return nullptr;

    // ub_expr - lb_expr + 1
    AffineExpr lbExpr(lbMap.getResult(0));
    AffineExpr ubExpr(ubMap.getResult(0));
    auto loopSpanExpr = simplifyAffineExpr(
        ubExpr - lbExpr + 1, std::max(lbMap.getNumDims(), ubMap.getNumDims()),
        std::max(lbMap.getNumSymbols(), ubMap.getNumSymbols()));
    auto cExpr = loopSpanExpr.dyn_cast<AffineConstantExpr>();
    if (!cExpr)
      return loopSpanExpr.ceilDiv(step);
    loopSpan = cExpr.getValue();
  }

  // A negative span means the loop executes zero iterations; return the
  // constant zero rather than a null expression (which signals "unknown").
  if (loopSpan < 0)
    return getAffineConstantExpr(0, context);

  return getAffineConstantExpr(static_cast<uint64_t>(ceilDiv(loopSpan, step)),
                               context);
}
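
// Worked example for getTripCountExpr (values are illustrative): a loop with
// constant lower bound 0, constant (inclusive) upper bound 127 and step 4 has
// loopSpan = 127 - 0 + 1 = 128 and a constant trip count of
// ceildiv(128, 4) = 32. With single-result bound maps sharing the same
// operand, e.g. lower bound (d0) -> (d0) and upper bound (d0) -> (d0 + 7), the
// span simplifies to the constant 8, so a step of 2 yields a trip count of 4.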

/// Returns the trip count of the loop if it's a constant, None otherwise. This
/// method uses affine expression analysis (in turn using getTripCountExpr) and
/// is able to determine constant trip counts in non-trivial cases.
llvm::Optional<uint64_t> mlir::getConstantTripCount(const ForStmt &forStmt) {
  auto tripCountExpr = getTripCountExpr(forStmt);

  if (!tripCountExpr)
    return None;

  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>())
    return constExpr.getValue();

  return None;
}
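
// Illustrative use of getConstantTripCount (the names `canFullyUnroll` and
// `kUnrollThreshold` are made up for exposition): a caller can decide to fully
// unroll only when the trip count is statically known, e.g.
//   if (auto tripCount = getConstantTripCount(forStmt))
//     canFullyUnroll = *tripCount <= kUnrollThreshold;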

/// Returns the greatest known integral divisor of the trip count. Affine
/// expression analysis is used (indirectly through getTripCountExpr), and this
/// method is thus able to determine non-trivial divisors.
uint64_t mlir::getLargestDivisorOfTripCount(const ForStmt &forStmt) {
  auto tripCountExpr = getTripCountExpr(forStmt);

  if (!tripCountExpr)
    return 1;

  if (auto constExpr = tripCountExpr.dyn_cast<AffineConstantExpr>()) {
    uint64_t tripCount = constExpr.getValue();

    // 0-iteration loops: every value divides zero, so report the largest
    // possible divisor (2^64 - 1).
    if (tripCount == 0)
      return ULONG_MAX;

    // Otherwise, the greatest divisor is the trip count itself.
    return tripCount;
  }

  // Trip count is not a known constant; return its largest known divisor.
  return tripCountExpr.getLargestKnownDivisor();
}
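
// Example: if the trip count simplifies to the affine expression 4 * s0 (with
// s0 a symbol), getLargestKnownDivisor returns 4, so unrolling by a factor of
// 2 or 4 is known to need no cleanup loop even though the exact trip count is
// unknown.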

/// Given a MemRef accessed by `indices` and a dimension `dim`, determines
/// whether indices[dim] is independent of the value `input`.
// For now we assume no layout map or an identity layout map in the MemRef.
// TODO(ntv): support more than the identity layout map.
static bool isAccessInvariant(MLValue *input, MemRefType *memRefType,
                              ArrayRef<MLValue *> indices, unsigned dim) {
  assert(indices.size() == memRefType->getRank());
  assert(dim < indices.size());
  auto layoutMap = memRefType->getAffineMaps();
  assert(memRefType->getAffineMaps().size() <= 1);
  // TODO(ntv): remove dependency on Builder once we support non-identity
  // layout maps.
  Builder b(memRefType->getContext());
  assert(layoutMap.empty() ||
         layoutMap[0] == b.getMultiDimIdentityMap(indices.size()));
  (void)layoutMap;

  SmallVector<OperationStmt *, 4> affineApplyOps;
  getReachableAffineApplyOps({indices[dim]}, affineApplyOps);

  if (affineApplyOps.empty()) {
    // No affine_apply feeds this index: it is invariant unless it is `input`
    // itself (pointer comparison suffices given MLValue pointer semantics).
    return indices[dim] != input;
  }

  assert(affineApplyOps.size() == 1 &&
         "CompositionAffineMapsPass must have been run: there should be at "
         "most one AffineApplyOp");
  auto composeOp = affineApplyOps[0]->getAs<AffineApplyOp>();
  return !AffineValueMap(*composeOp).isFunctionOf(dim, input);
}
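
// Example (with %i, %j standing for two loop induction variables and no
// affine_apply feeding the indices): for an access with indices {%i, %j},
// dimension 1 is invariant with respect to %i because %j is a different value,
// while dimension 0 is not invariant since indices[0] is %i itself.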

/// Determines whether a load or a store has a contiguous access along the
/// value `input`. Contiguous is defined as either invariant or varying only
/// along the fastest varying memory dimension.
// TODO(ntv): allow more advanced notions of contiguity (non-fastest varying,
// check strides, ...).
template <typename LoadOrStoreOpPointer>
static bool isContiguousAccess(MLValue *input, LoadOrStoreOpPointer memoryOp) {
  auto indicesAsOperandIterators = memoryOp->getIndices();
  auto *memRefType = cast<MemRefType>(memoryOp->getMemRef()->getType());
  SmallVector<MLValue *, 4> indices;
  for (auto *it : indicesAsOperandIterators) {
    indices.push_back(cast<MLValue>(it));
  }
  unsigned numIndices = indices.size();
  // Every index except the last one (the fastest varying memory dimension)
  // must be invariant along `input`.
  for (unsigned d = 0; d < numIndices - 1; ++d) {
    if (!isAccessInvariant(input, memRefType, indices, d)) {
      return false;
    }
  }
  return true;
}
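
// Example (again assuming indices {%i, %j} not produced by affine_apply ops):
// with input %j, only dimension 0 is checked and it is invariant along %j, so
// the access is contiguous along %j; with input %i the same access is not
// contiguous, since dimension 0 varies with %i. The last (fastest varying)
// dimension is never checked.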

/// Checks whether all the LoadOp and StoreOp matched have access indexing
/// functions that are either:
/// 1. invariant along the `loop` induction variable;
/// 2. varying along the fastest varying memory dimension only.
// TODO(ntv): Also need to check the contiguous dimension to discriminate
// between broadcast (i.e. stride 0), stride 1 and stride > 1, and return that
// information so we can build a cost model.
bool mlir::isVectorizableLoop(const ForStmt &loop) {
  // TODO(ntv): check parallel or reduction loop semantics.
  using matcher::LoadStores;
  auto *forStmt = &const_cast<ForStmt &>(loop);
  auto loadAndStores = LoadStores();
  auto &matches = loadAndStores.match(forStmt);
  for (auto ls : matches) {
    auto *op = cast<OperationStmt>(ls.first);
    auto load = op->getAs<LoadOp>();
    auto store = op->getAs<StoreOp>();
    bool contiguous = load ? isContiguousAccess(forStmt, load)
                           : isContiguousAccess(forStmt, store);
    if (!contiguous) {
      return false;
    }
  }
  return true;
}
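
// Example: a loop over %i whose body only loads and stores at indices where %i
// appears, if at all, in the fastest varying memory dimension is reported
// vectorizable; a single access in which %i appears in an outer memory
// dimension makes the whole loop non-vectorizable.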

/// Checks whether SSA dominance would be violated if a for stmt's body
/// statements are shifted by the specified shifts. This method checks if a
/// 'def' and all its uses have the same shift factor.
// TODO(mlir-team): extend this to check for memory-based dependence violation
// when we have the support.
bool mlir::isStmtwiseShiftValid(const ForStmt &forStmt,
                                ArrayRef<uint64_t> shifts) {
  assert(shifts.size() == forStmt.getStatements().size());
  unsigned s = 0;
  for (const auto &stmt : forStmt) {
    // A for or if stmt does not produce any def/results (that could be used
    // outside it).
    if (const auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; ++i) {
        const MLValue *result = opStmt->getResult(i);
        for (const StmtOperand &use : result->getUses()) {
          // If the user has no ancestor statement in forStmt's block, the use
          // is outside the loop body and there is no shift to check.
          // This is a naive approach; if performance becomes an issue, a map
          // can be used to store 'shifts' - to look up the shift for a
          // statement in constant time.
          if (auto *ancStmt = forStmt.findAncestorStmtInBlock(*use.getOwner()))
            if (shifts[s] != shifts[forStmt.findStmtPosInBlock(*ancStmt)])
              return false;
        }
      }
    }
    s++;
  }
  return true;
}
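
// Example: if statement 0 of the body defines %v and statement 1 uses %v, the
// shift vector {1, 1} is accepted, whereas {0, 1} is rejected because the def
// and its use would be shifted by different amounts, which could reorder the
// use before its def.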