blob: eea3bf7c88e0eb3eaf2c4857f6fd4f13b54c5fc7 [file] [log] [blame]
Chris Lattneree0c2ae2018-07-29 12:37:35 -07001//===- Unroll.cpp - Code to perform loop unrolling ------------------------===//
Uday Bondhugula0b4059b2018-07-24 20:01:16 -07002//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements loop unrolling.
19//
20//===----------------------------------------------------------------------===//
21
Uday Bondhugula15984952018-08-01 22:36:12 -070022#include "mlir/IR/Attributes.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070023#include "mlir/IR/Builders.h"
24#include "mlir/IR/CFGFunction.h"
25#include "mlir/IR/MLFunction.h"
26#include "mlir/IR/Module.h"
27#include "mlir/IR/OperationSet.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070028#include "mlir/IR/Pass.h"
Uday Bondhugula84b80952018-08-03 13:22:26 -070029#include "mlir/IR/StandardOps.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070030#include "mlir/IR/Statements.h"
31#include "mlir/IR/StmtVisitor.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070032#include "mlir/Transforms/Passes.h"
Uday Bondhugula081d9e72018-07-27 10:58:14 -070033#include "llvm/Support/raw_ostream.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070034
35using namespace mlir;
36
37namespace {
38struct LoopUnroll : public MLFunctionPass {
Chris Lattneree0c2ae2018-07-29 12:37:35 -070039 void runOnMLFunction(MLFunction *f) override;
40 void runOnForStmt(ForStmt *forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070041};
Uday Bondhugula134154e2018-08-06 18:40:34 -070042struct ShortLoopUnroll : public LoopUnroll {
43 const unsigned minTripCount;
44 void runOnMLFunction(MLFunction *f) override;
45 ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
46};
Chris Lattneree0c2ae2018-07-29 12:37:35 -070047} // end anonymous namespace
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070048
49MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
50
Uday Bondhugula134154e2018-08-06 18:40:34 -070051MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
52 return new ShortLoopUnroll(minTripCount);
53}
54
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070055/// Unrolls all the innermost loops of this MLFunction.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070056void LoopUnroll::runOnMLFunction(MLFunction *f) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070057 // Gathers all innermost loops through a post order pruned walk.
Uday Bondhugula081d9e72018-07-27 10:58:14 -070058 class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
59 public:
60 // Store innermost loops as we walk.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070061 std::vector<ForStmt *> loops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070062
63 // This method specialized to encode custom return logic.
64 typedef llvm::iplist<Statement> StmtListType;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070065 bool walkPostOrder(StmtListType::iterator Start,
66 StmtListType::iterator End) {
Uday Bondhugula15984952018-08-01 22:36:12 -070067 bool hasInnerLoops = false;
68 // We need to walk all elements since all innermost loops need to be
69 // gathered as opposed to determining whether this list has any inner
70 // loops or not.
Uday Bondhugula081d9e72018-07-27 10:58:14 -070071 while (Start != End)
Uday Bondhugula15984952018-08-01 22:36:12 -070072 hasInnerLoops |= walkPostOrder(&(*Start++));
73 return hasInnerLoops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070074 }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070075
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070076 bool walkForStmtPostOrder(ForStmt *forStmt) {
77 bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
Uday Bondhugula081d9e72018-07-27 10:58:14 -070078 if (!hasInnerLoops)
79 loops.push_back(forStmt);
80 return true;
81 }
82
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070083 bool walkIfStmtPostOrder(IfStmt *ifStmt) {
Uday Bondhugula15984952018-08-01 22:36:12 -070084 bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
85 ifStmt->getThenClause()->end());
86 hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
87 ifStmt->getElseClause()->end());
88 return hasInnerLoops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070089 }
90
Uday Bondhugula134154e2018-08-06 18:40:34 -070091 bool visitOperationStmt(OperationStmt *opStmt) { return false; }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070092
Uday Bondhugula134154e2018-08-06 18:40:34 -070093 // FIXME: can't use base class method for this because that in turn would
94 // need to use the derived class method above. CRTP doesn't allow it, and
95 // the compiler error resulting from it is also misleading.
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070096 using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070097 };
98
99 InnermostLoopGatherer ilg;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -0700100 ilg.walkPostOrder(f);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700101 auto &loops = ilg.loops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700102 for (auto *forStmt : loops)
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700103 runOnForStmt(forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700104}
105
Uday Bondhugula134154e2018-08-06 18:40:34 -0700106/// Unrolls all loops with trip count <= minTripCount.
107void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
108 // Gathers all loops with trip count <= minTripCount.
109 class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
110 public:
111 // Store short loops as we walk.
112 std::vector<ForStmt *> loops;
113 const unsigned minTripCount;
114 ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
Uday Bondhugula15984952018-08-01 22:36:12 -0700115
Uday Bondhugula134154e2018-08-06 18:40:34 -0700116 void visitForStmt(ForStmt *forStmt) {
117 auto lb = forStmt->getLowerBound()->getValue();
118 auto ub = forStmt->getUpperBound()->getValue();
119 auto step = forStmt->getStep()->getValue();
Uday Bondhugula15984952018-08-01 22:36:12 -0700120
Uday Bondhugula134154e2018-08-06 18:40:34 -0700121 if ((ub - lb) / step + 1 <= minTripCount)
122 loops.push_back(forStmt);
Uday Bondhugula15984952018-08-01 22:36:12 -0700123 }
124 };
125
Uday Bondhugula134154e2018-08-06 18:40:34 -0700126 ShortLoopGatherer slg(minTripCount);
127 slg.walk(f);
128 auto &loops = slg.loops;
129 for (auto *forStmt : loops)
130 runOnForStmt(forStmt);
131}
132
133/// Replace all uses of oldVal with newVal from begin to end.
134static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
135 MLValue *oldVal, MLValue *newVal) {
136 // TODO(bondhugula,clattner): do this more efficiently by walking those uses
137 // of oldVal that fall within this list of statements (instead of iterating
138 // through all statements / through all operands of operations found).
139 for (auto it = begin; it != end; it++) {
140 it->replaceUses(oldVal, newVal);
141 }
142}
143
144/// Replace all uses of oldVal with newVal.
145void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
146 // TODO(bondhugula,clattner): do this more efficiently by walking those uses
147 // of oldVal that fall within this StmtBlock (instead of iterating through
148 // all statements / through all operands of operations found).
149 for (auto it = block->begin(); it != block->end(); it++) {
150 it->replaceUses(oldVal, newVal);
151 }
152}
153
154/// Clone the list of stmt's from 'block' and insert into the current
155/// position of the builder.
156// TODO(bondhugula,clattner): replace this with a parameterizable clone.
157void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
158 // Pairs of <old op stmt result whose uses need to be replaced,
159 // new result generated by the corresponding cloned op stmt>.
160 SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
161
162 // Iterator pointing to just before 'this' (i^th) unrolled iteration.
163 StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
164
165 for (auto &stmt : block.getStatements()) {
166 auto *cloneStmt = builder->clone(stmt);
167 // Whenever we have an op stmt, we'll have a new ML Value defined: replace
168 // uses of the old result with this one.
169 if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
170 if (opStmt->getNumResults()) {
171 auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
172 for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
173 // Store old/new result pairs.
174 // TODO(bondhugula) *only* if needed later: storing of old/new
175 // results can be avoided by cloning the statement list in the
176 // reverse direction (and running the IR builder in the reverse
177 // (iplist.insertAfter()). That way, a newly created result can be
178 // immediately propagated to all its uses.
179 oldNewResultPairs.push_back(std::make_pair(
180 const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
181 &cloneOpStmt->getStmtResult(i)));
182 }
183 }
184 }
185 }
186
187 // Replace uses of old op results' with the new results.
188 StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
189 StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
190
191 // Replace uses of old op results' with the newly created ones.
192 for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
193 replaceUses(startOfUnrolledBody, endOfUnrolledBody,
194 oldNewResultPairs[i].first, oldNewResultPairs[i].second);
195 }
Uday Bondhugula15984952018-08-01 22:36:12 -0700196}
197
Uday Bondhugula84b80952018-08-03 13:22:26 -0700198/// Unroll this 'for stmt' / loop completely.
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700199void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700200 auto lb = forStmt->getLowerBound()->getValue();
201 auto ub = forStmt->getUpperBound()->getValue();
202 auto step = forStmt->getStep()->getValue();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700203
Uday Bondhugula84b80952018-08-03 13:22:26 -0700204 // Builder to add constants need for the unrolled iterator.
Uday Bondhugula15984952018-08-01 22:36:12 -0700205 auto *mlFunc = forStmt->Statement::findFunction();
206 MLFuncBuilder funcTopBuilder(mlFunc);
207 funcTopBuilder.setInsertionPointAtStart(mlFunc);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700208
Uday Bondhugula84b80952018-08-03 13:22:26 -0700209 // Builder to insert the unrolled bodies.
Uday Bondhugula15984952018-08-01 22:36:12 -0700210 MLFuncBuilder builder(forStmt->getBlock());
Uday Bondhugula84b80952018-08-03 13:22:26 -0700211 // Set insertion point to right after where the for stmt ends.
212 builder.setInsertionPoint(forStmt->getBlock(),
213 ++StmtBlock::iterator(forStmt));
214
215 // Unroll the contents of 'forStmt'.
Uday Bondhugula134154e2018-08-06 18:40:34 -0700216 for (int64_t i = lb; i <= ub; i += step) {
217 MLValue *ivConst = nullptr;
218 if (!forStmt->use_empty()) {
219 auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
220 ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
221 }
Uday Bondhugula84b80952018-08-03 13:22:26 -0700222 StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
223
Uday Bondhugula134154e2018-08-06 18:40:34 -0700224 // Clone the loop body and insert it right after the loop - the latter will
225 // be erased after all unrolling has been done.
226 cloneStmtListFromBlock(&builder, *forStmt);
Uday Bondhugula84b80952018-08-03 13:22:26 -0700227
Uday Bondhugula134154e2018-08-06 18:40:34 -0700228 // Replace unrolled loop IV with the unrolled constant.
229 if (ivConst) {
230 StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
231 StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
232 replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700233 }
234 }
Uday Bondhugula134154e2018-08-06 18:40:34 -0700235 // Erase the original 'for' stmt from the block.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700236 forStmt->eraseFromBlock();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700237}