blob: 9592ef70718537efbff4cd31bb3a308c5fcd0dea [file] [log] [blame]
//===- Unroll.cpp - Code to perform loop unrolling ---------------------===//
2//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
// This file implements a pass that completely unrolls the innermost loops of
// ML functions.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/CFGFunction.h"
24#include "mlir/IR/MLFunction.h"
25#include "mlir/IR/Module.h"
26#include "mlir/IR/OperationSet.h"
27#include "mlir/IR/Statements.h"
28#include "mlir/IR/StmtVisitor.h"
29#include "mlir/Pass.h"
30#include "mlir/Transforms/Loop.h"
Uday Bondhugula081d9e72018-07-27 10:58:14 -070031#include "llvm/Support/raw_ostream.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070032
33using namespace mlir;
34
35namespace {
36struct LoopUnroll : public MLFunctionPass {
37 bool runOnMLFunction(MLFunction *f);
38 bool runOnForStmt(ForStmt *forStmt);
39 bool runLoopUnroll(MLFunction *f);
40};
41} // namespace
42
43MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
44
45/// Unrolls all the innermost loops of this Module.
46bool MLFunctionPass::runOnModule(Module *m) {
47 bool changed = false;
Chris Lattnera8e47672018-07-25 14:08:16 -070048 for (auto &fn : *m) {
49 if (auto *mlFunc = dyn_cast<MLFunction>(&fn))
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070050 changed |= runOnMLFunction(mlFunc);
51 }
52 return changed;
53}
54
55/// Unrolls all the innermost loops of this MLFunction.
56bool LoopUnroll::runOnMLFunction(MLFunction *f) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070057 // Gathers all innermost loops through a post order pruned walk.
58 // TODO: figure out the right reusable template here to better refactor code.
59 class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
60 public:
61 // Store innermost loops as we walk.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070062 std::vector<ForStmt *> loops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070063
64 // This method specialized to encode custom return logic.
65 typedef llvm::iplist<Statement> StmtListType;
66 bool walk(StmtListType::iterator Start, StmtListType::iterator End) {
67 while (Start != End)
68 if (walk(&(*Start++)))
69 return true;
70 return false;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070071 }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070072
73 // FIXME: can't use base class method for this because that in turn would
74 // need to use the derived class method above. CRTP doesn't allow it, and
75 // the compiler error resulting from it is also very misleading!
76 void walkMLFunction(MLFunction *f) { walk(f->begin(), f->end()); }
77
78 bool walkForStmt(ForStmt *forStmt) {
79 bool hasInnerLoops = walk(forStmt->begin(), forStmt->end());
80 if (!hasInnerLoops)
81 loops.push_back(forStmt);
82 return true;
83 }
84
85 bool walkIfStmt(IfStmt *ifStmt) {
86 if (walk(ifStmt->getThenClause()->begin(),
87 ifStmt->getThenClause()->end()) ||
88 walk(ifStmt->getElseClause()->begin(),
89 ifStmt->getElseClause()->end()))
90 return true;
91 return false;
92 }
93
94 bool walkOpStmt(OperationStmt *opStmt) { return false; }
95
96 using StmtWalker<InnermostLoopGatherer, bool>::walk;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070097 };
98
99 InnermostLoopGatherer ilg;
Uday Bondhugula081d9e72018-07-27 10:58:14 -0700100 ilg.walkMLFunction(f);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700101 auto &loops = ilg.loops;
102 bool changed = false;
103 for (auto *forStmt : loops)
104 changed |= runOnForStmt(forStmt);
105 return changed;
106}
107
108/// Unrolls this loop completely. Returns true if the unrolling happens.
109bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
110 auto lb = forStmt->getLowerBound()->getValue();
111 auto ub = forStmt->getUpperBound()->getValue();
112 auto step = forStmt->getStep()->getValue();
113 auto trip_count = (ub - lb + 1) / step;
114
115 auto *block = forStmt->getBlock();
116
117 MLFuncBuilder builder(forStmt->Statement::getFunction());
118 builder.setInsertionPoint(block);
119
120 for (int i = 0; i < trip_count; i++) {
121 for (auto &stmt : forStmt->getStatements()) {
122 switch (stmt.getKind()) {
123 case Statement::Kind::For:
124 llvm_unreachable("unrolling loops that have only operations");
125 break;
126 case Statement::Kind::If:
127 llvm_unreachable("unrolling loops that have only operations");
128 break;
129 case Statement::Kind::Operation:
130 auto *op = cast<OperationStmt>(&stmt);
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700131 // TODO: clone operands and result types.
132 builder.createOperation(op->getName(), /*operands*/ {},
133 /*resultTypes*/ {}, op->getAttrs());
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700134 // TODO: loop iterator parsing not yet implemented; replace loop
135 // iterator uses in unrolled body appropriately.
136 break;
137 }
138 }
139 }
140
141 forStmt->eraseFromBlock();
142 return true;
143}