blob: 489a98f8468b260c5f6014adddec9ef9c41b4f14 [file] [log] [blame]
//===- Unroll.cpp - Code to perform loop unrolling ------------------------===//
//
// Copyright 2019 The MLIR Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
//
// This file implements complete unrolling of innermost loops.
//
//===----------------------------------------------------------------------===//
21
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/CFGFunction.h"
24#include "mlir/IR/MLFunction.h"
25#include "mlir/IR/Module.h"
26#include "mlir/IR/OperationSet.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070027#include "mlir/IR/Pass.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070028#include "mlir/IR/Statements.h"
29#include "mlir/IR/StmtVisitor.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070030#include "mlir/Transforms/Passes.h"
Uday Bondhugula081d9e72018-07-27 10:58:14 -070031#include "llvm/Support/raw_ostream.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070032
33using namespace mlir;
34
35namespace {
36struct LoopUnroll : public MLFunctionPass {
Chris Lattneree0c2ae2018-07-29 12:37:35 -070037 void runOnMLFunction(MLFunction *f) override;
38 void runOnForStmt(ForStmt *forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070039};
Chris Lattneree0c2ae2018-07-29 12:37:35 -070040} // end anonymous namespace
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070041
42MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
43
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070044/// Unrolls all the innermost loops of this MLFunction.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070045void LoopUnroll::runOnMLFunction(MLFunction *f) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070046 // Gathers all innermost loops through a post order pruned walk.
47 // TODO: figure out the right reusable template here to better refactor code.
48 class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
49 public:
50 // Store innermost loops as we walk.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070051 std::vector<ForStmt *> loops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070052
53 // This method specialized to encode custom return logic.
54 typedef llvm::iplist<Statement> StmtListType;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070055 bool walkPostOrder(StmtListType::iterator Start,
56 StmtListType::iterator End) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070057 while (Start != End)
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070058 if (walkPostOrder(&(*Start++)))
Uday Bondhugula081d9e72018-07-27 10:58:14 -070059 return true;
60 return false;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070061 }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070062
63 // FIXME: can't use base class method for this because that in turn would
64 // need to use the derived class method above. CRTP doesn't allow it, and
65 // the compiler error resulting from it is also very misleading!
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070066 void walkPostOrder(MLFunction *f) { walkPostOrder(f->begin(), f->end()); }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070067
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070068 bool walkForStmtPostOrder(ForStmt *forStmt) {
69 bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
Uday Bondhugula081d9e72018-07-27 10:58:14 -070070 if (!hasInnerLoops)
71 loops.push_back(forStmt);
72 return true;
73 }
74
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070075 bool walkIfStmtPostOrder(IfStmt *ifStmt) {
76 if (walkPostOrder(ifStmt->getThenClause()->begin(),
77 ifStmt->getThenClause()->end()) ||
78 walkPostOrder(ifStmt->getElseClause()->begin(),
79 ifStmt->getElseClause()->end()))
Uday Bondhugula081d9e72018-07-27 10:58:14 -070080 return true;
81 return false;
82 }
83
84 bool walkOpStmt(OperationStmt *opStmt) { return false; }
85
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070086 using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070087 };
88
89 InnermostLoopGatherer ilg;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070090 ilg.walkPostOrder(f);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070091 auto &loops = ilg.loops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070092 for (auto *forStmt : loops)
Chris Lattneree0c2ae2018-07-29 12:37:35 -070093 runOnForStmt(forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070094}
95
96/// Unrolls this loop completely. Returns true if the unrolling happens.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070097void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070098 auto lb = forStmt->getLowerBound()->getValue();
99 auto ub = forStmt->getUpperBound()->getValue();
100 auto step = forStmt->getStep()->getValue();
101 auto trip_count = (ub - lb + 1) / step;
102
103 auto *block = forStmt->getBlock();
104
105 MLFuncBuilder builder(forStmt->Statement::getFunction());
106 builder.setInsertionPoint(block);
107
108 for (int i = 0; i < trip_count; i++) {
109 for (auto &stmt : forStmt->getStatements()) {
110 switch (stmt.getKind()) {
111 case Statement::Kind::For:
112 llvm_unreachable("unrolling loops that have only operations");
113 break;
114 case Statement::Kind::If:
115 llvm_unreachable("unrolling loops that have only operations");
116 break;
117 case Statement::Kind::Operation:
118 auto *op = cast<OperationStmt>(&stmt);
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700119 // TODO: clone operands and result types.
120 builder.createOperation(op->getName(), /*operands*/ {},
121 /*resultTypes*/ {}, op->getAttrs());
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700122 // TODO: loop iterator parsing not yet implemented; replace loop
123 // iterator uses in unrolled body appropriately.
124 break;
125 }
126 }
127 }
128
129 forStmt->eraseFromBlock();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700130}