blob: 0f6d42856ddad4bd2f54e68b18a7bd05a1cf74e1 [file] [log] [blame]
Chris Lattneree0c2ae2018-07-29 12:37:35 -07001//===- Unroll.cpp - Code to perform loop unrolling ------------------------===//
Uday Bondhugula0b4059b2018-07-24 20:01:16 -07002//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements loop unrolling.
19//
20//===----------------------------------------------------------------------===//
21
Uday Bondhugula15984952018-08-01 22:36:12 -070022#include "mlir/IR/Attributes.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070023#include "mlir/IR/Builders.h"
24#include "mlir/IR/CFGFunction.h"
25#include "mlir/IR/MLFunction.h"
26#include "mlir/IR/Module.h"
27#include "mlir/IR/OperationSet.h"
Uday Bondhugula84b80952018-08-03 13:22:26 -070028#include "mlir/IR/StandardOps.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070029#include "mlir/IR/Statements.h"
30#include "mlir/IR/StmtVisitor.h"
Uday Bondhugula6c1f6602018-08-13 17:25:13 -070031#include "mlir/Transforms/Pass.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070032#include "mlir/Transforms/Passes.h"
Chris Lattnere787b322018-08-08 11:14:57 -070033#include "llvm/ADT/DenseMap.h"
Uday Bondhugula081d9e72018-07-27 10:58:14 -070034#include "llvm/Support/raw_ostream.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070035
36using namespace mlir;
37
38namespace {
Uday Bondhugula0077e622018-08-16 13:51:44 -070039/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
40/// MLFunction.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070041struct LoopUnroll : public MLFunctionPass {
Chris Lattneree0c2ae2018-07-29 12:37:35 -070042 void runOnMLFunction(MLFunction *f) override;
43 void runOnForStmt(ForStmt *forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070044};
Uday Bondhugula0077e622018-08-16 13:51:44 -070045
46/// Unrolls all loops with trip count <= minTripCount.
Uday Bondhugula134154e2018-08-06 18:40:34 -070047struct ShortLoopUnroll : public LoopUnroll {
48 const unsigned minTripCount;
49 void runOnMLFunction(MLFunction *f) override;
50 ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
51};
Chris Lattneree0c2ae2018-07-29 12:37:35 -070052} // end anonymous namespace
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070053
54MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
55
Uday Bondhugula134154e2018-08-06 18:40:34 -070056MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
57 return new ShortLoopUnroll(minTripCount);
58}
59
Chris Lattneree0c2ae2018-07-29 12:37:35 -070060void LoopUnroll::runOnMLFunction(MLFunction *f) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070061 // Gathers all innermost loops through a post order pruned walk.
Uday Bondhugula081d9e72018-07-27 10:58:14 -070062 class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
63 public:
64 // Store innermost loops as we walk.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070065 std::vector<ForStmt *> loops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070066
67 // This method specialized to encode custom return logic.
68 typedef llvm::iplist<Statement> StmtListType;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070069 bool walkPostOrder(StmtListType::iterator Start,
70 StmtListType::iterator End) {
Uday Bondhugula15984952018-08-01 22:36:12 -070071 bool hasInnerLoops = false;
72 // We need to walk all elements since all innermost loops need to be
73 // gathered as opposed to determining whether this list has any inner
74 // loops or not.
Uday Bondhugula081d9e72018-07-27 10:58:14 -070075 while (Start != End)
Uday Bondhugula15984952018-08-01 22:36:12 -070076 hasInnerLoops |= walkPostOrder(&(*Start++));
77 return hasInnerLoops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070078 }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070079
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070080 bool walkForStmtPostOrder(ForStmt *forStmt) {
81 bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
Uday Bondhugula081d9e72018-07-27 10:58:14 -070082 if (!hasInnerLoops)
83 loops.push_back(forStmt);
Chris Lattnere787b322018-08-08 11:14:57 -070084
Uday Bondhugula081d9e72018-07-27 10:58:14 -070085 return true;
86 }
87
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070088 bool walkIfStmtPostOrder(IfStmt *ifStmt) {
Chris Lattnere787b322018-08-08 11:14:57 -070089 bool hasInnerLoops =
90 walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
91 hasInnerLoops |=
92 walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
Uday Bondhugula15984952018-08-01 22:36:12 -070093 return hasInnerLoops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070094 }
95
Uday Bondhugula134154e2018-08-06 18:40:34 -070096 bool visitOperationStmt(OperationStmt *opStmt) { return false; }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070097
Uday Bondhugula134154e2018-08-06 18:40:34 -070098 // FIXME: can't use base class method for this because that in turn would
99 // need to use the derived class method above. CRTP doesn't allow it, and
100 // the compiler error resulting from it is also misleading.
Uday Bondhugula8572d1a2018-07-30 10:49:49 -0700101 using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700102 };
103
104 InnermostLoopGatherer ilg;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -0700105 ilg.walkPostOrder(f);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700106 auto &loops = ilg.loops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700107 for (auto *forStmt : loops)
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700108 runOnForStmt(forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700109}
110
Uday Bondhugula134154e2018-08-06 18:40:34 -0700111void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
112 // Gathers all loops with trip count <= minTripCount.
113 class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
114 public:
115 // Store short loops as we walk.
116 std::vector<ForStmt *> loops;
117 const unsigned minTripCount;
118 ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
Uday Bondhugula15984952018-08-01 22:36:12 -0700119
Uday Bondhugula134154e2018-08-06 18:40:34 -0700120 void visitForStmt(ForStmt *forStmt) {
121 auto lb = forStmt->getLowerBound()->getValue();
122 auto ub = forStmt->getUpperBound()->getValue();
123 auto step = forStmt->getStep()->getValue();
Uday Bondhugula15984952018-08-01 22:36:12 -0700124
Uday Bondhugula134154e2018-08-06 18:40:34 -0700125 if ((ub - lb) / step + 1 <= minTripCount)
126 loops.push_back(forStmt);
Uday Bondhugula15984952018-08-01 22:36:12 -0700127 }
128 };
129
Uday Bondhugula134154e2018-08-06 18:40:34 -0700130 ShortLoopGatherer slg(minTripCount);
Uday Bondhugula0077e622018-08-16 13:51:44 -0700131 // Do a post order walk so that loops are gathered from innermost to
132 // outermost (or else unrolling an outer one may delete gathered inner ones).
133 slg.walkPostOrder(f);
Uday Bondhugula134154e2018-08-06 18:40:34 -0700134 auto &loops = slg.loops;
135 for (auto *forStmt : loops)
136 runOnForStmt(forStmt);
137}
138
Chris Lattnere787b322018-08-08 11:14:57 -0700139/// Unroll this For loop completely.
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700140void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700141 auto lb = forStmt->getLowerBound()->getValue();
142 auto ub = forStmt->getUpperBound()->getValue();
143 auto step = forStmt->getStep()->getValue();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700144
Uday Bondhugula84b80952018-08-03 13:22:26 -0700145 // Builder to add constants need for the unrolled iterator.
Chris Lattnere787b322018-08-08 11:14:57 -0700146 auto *mlFunc = forStmt->findFunction();
147 MLFuncBuilder funcTopBuilder(&mlFunc->front());
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700148
Chris Lattnere787b322018-08-08 11:14:57 -0700149 // Builder to insert the unrolled bodies. We insert right after the
150 /// ForStmt we're unrolling.
151 MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
Uday Bondhugula84b80952018-08-03 13:22:26 -0700152
153 // Unroll the contents of 'forStmt'.
Uday Bondhugula134154e2018-08-06 18:40:34 -0700154 for (int64_t i = lb; i <= ub; i += step) {
Chris Lattnere787b322018-08-08 11:14:57 -0700155 DenseMap<const MLValue *, MLValue *> operandMapping;
156
157 // If the induction variable is used, create a constant for this unrolled
158 // value and add an operand mapping for it.
Uday Bondhugula134154e2018-08-06 18:40:34 -0700159 if (!forStmt->use_empty()) {
Chris Lattnere787b322018-08-08 11:14:57 -0700160 auto *ivConst =
161 funcTopBuilder.create<ConstantAffineIntOp>(i)->getResult();
162 operandMapping[forStmt] = cast<MLValue>(ivConst);
Uday Bondhugula134154e2018-08-06 18:40:34 -0700163 }
Uday Bondhugula84b80952018-08-03 13:22:26 -0700164
Chris Lattnere787b322018-08-08 11:14:57 -0700165 // Clone the body of the loop.
166 for (auto &childStmt : *forStmt) {
167 (void)builder.clone(childStmt, operandMapping);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700168 }
169 }
Uday Bondhugula134154e2018-08-06 18:40:34 -0700170 // Erase the original 'for' stmt from the block.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700171 forStmt->eraseFromBlock();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700172}