blob: fe110d21d669daaf5dc37b0c25c4a6d5e408ffbb [file] [log] [blame]
Chris Lattneree0c2ae2018-07-29 12:37:35 -07001//===- Unroll.cpp - Code to perform loop unrolling ------------------------===//
Uday Bondhugula0b4059b2018-07-24 20:01:16 -07002//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements loop unrolling.
19//
20//===----------------------------------------------------------------------===//
21
Uday Bondhugula15984952018-08-01 22:36:12 -070022#include "mlir/IR/Attributes.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070023#include "mlir/IR/Builders.h"
24#include "mlir/IR/CFGFunction.h"
25#include "mlir/IR/MLFunction.h"
26#include "mlir/IR/Module.h"
27#include "mlir/IR/OperationSet.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070028#include "mlir/IR/Pass.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070029#include "mlir/IR/Statements.h"
30#include "mlir/IR/StmtVisitor.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070031#include "mlir/Transforms/Passes.h"
Uday Bondhugula081d9e72018-07-27 10:58:14 -070032#include "llvm/Support/raw_ostream.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070033
34using namespace mlir;
35
36namespace {
37struct LoopUnroll : public MLFunctionPass {
Chris Lattneree0c2ae2018-07-29 12:37:35 -070038 void runOnMLFunction(MLFunction *f) override;
39 void runOnForStmt(ForStmt *forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070040};
Chris Lattneree0c2ae2018-07-29 12:37:35 -070041} // end anonymous namespace
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070042
43MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
44
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070045/// Unrolls all the innermost loops of this MLFunction.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070046void LoopUnroll::runOnMLFunction(MLFunction *f) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070047 // Gathers all innermost loops through a post order pruned walk.
48 // TODO: figure out the right reusable template here to better refactor code.
49 class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
50 public:
51 // Store innermost loops as we walk.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070052 std::vector<ForStmt *> loops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070053
54 // This method specialized to encode custom return logic.
55 typedef llvm::iplist<Statement> StmtListType;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070056 bool walkPostOrder(StmtListType::iterator Start,
57 StmtListType::iterator End) {
Uday Bondhugula15984952018-08-01 22:36:12 -070058 bool hasInnerLoops = false;
59 // We need to walk all elements since all innermost loops need to be
60 // gathered as opposed to determining whether this list has any inner
61 // loops or not.
Uday Bondhugula081d9e72018-07-27 10:58:14 -070062 while (Start != End)
Uday Bondhugula15984952018-08-01 22:36:12 -070063 hasInnerLoops |= walkPostOrder(&(*Start++));
64 return hasInnerLoops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070065 }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070066
67 // FIXME: can't use base class method for this because that in turn would
68 // need to use the derived class method above. CRTP doesn't allow it, and
69 // the compiler error resulting from it is also very misleading!
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070070 void walkPostOrder(MLFunction *f) { walkPostOrder(f->begin(), f->end()); }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070071
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070072 bool walkForStmtPostOrder(ForStmt *forStmt) {
73 bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
Uday Bondhugula081d9e72018-07-27 10:58:14 -070074 if (!hasInnerLoops)
75 loops.push_back(forStmt);
76 return true;
77 }
78
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070079 bool walkIfStmtPostOrder(IfStmt *ifStmt) {
Uday Bondhugula15984952018-08-01 22:36:12 -070080 bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
81 ifStmt->getThenClause()->end());
82 hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
83 ifStmt->getElseClause()->end());
84 return hasInnerLoops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070085 }
86
87 bool walkOpStmt(OperationStmt *opStmt) { return false; }
88
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070089 using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070090 };
91
92 InnermostLoopGatherer ilg;
Uday Bondhugula8572d1a2018-07-30 10:49:49 -070093 ilg.walkPostOrder(f);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070094 auto &loops = ilg.loops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070095 for (auto *forStmt : loops)
Chris Lattneree0c2ae2018-07-29 12:37:35 -070096 runOnForStmt(forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070097}
98
Uday Bondhugula15984952018-08-01 22:36:12 -070099/// Replace an IV with a constant value.
100static void replaceIterator(Statement *stmt, const ForStmt &iv,
101 MLValue *constVal) {
102 struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
103 // IV to be replaced.
104 const ForStmt *iv;
105 // Constant to be replaced with.
106 MLValue *constVal;
107
108 ReplaceIterator(const ForStmt &iv, MLValue *constVal)
109 : iv(&iv), constVal(constVal){};
110
111 void visitOperationStmt(OperationStmt *os) {
112 for (auto &operand : os->getStmtOperands()) {
113 if (operand.get() == static_cast<const MLValue *>(iv)) {
114 operand.set(constVal);
115 }
116 }
117 }
118 };
119
120 ReplaceIterator ri(iv, constVal);
121 ri.walk(stmt);
122}
123
124/// Unrolls this loop completely.
Chris Lattneree0c2ae2018-07-29 12:37:35 -0700125void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700126 auto lb = forStmt->getLowerBound()->getValue();
127 auto ub = forStmt->getUpperBound()->getValue();
128 auto step = forStmt->getStep()->getValue();
129 auto trip_count = (ub - lb + 1) / step;
130
Uday Bondhugula15984952018-08-01 22:36:12 -0700131 auto *mlFunc = forStmt->Statement::findFunction();
132 MLFuncBuilder funcTopBuilder(mlFunc);
133 funcTopBuilder.setInsertionPointAtStart(mlFunc);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700134
Uday Bondhugula15984952018-08-01 22:36:12 -0700135 MLFuncBuilder builder(forStmt->getBlock());
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700136 for (int i = 0; i < trip_count; i++) {
Uday Bondhugula15984952018-08-01 22:36:12 -0700137 auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700138 for (auto &stmt : forStmt->getStatements()) {
139 switch (stmt.getKind()) {
140 case Statement::Kind::For:
141 llvm_unreachable("unrolling loops that have only operations");
142 break;
143 case Statement::Kind::If:
144 llvm_unreachable("unrolling loops that have only operations");
145 break;
146 case Statement::Kind::Operation:
Uday Bondhugula15984952018-08-01 22:36:12 -0700147 auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
148 // TODO(bondhugula): only generate constants when the IV actually
149 // appears in the body.
150 replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700151 break;
152 }
153 }
154 }
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700155 forStmt->eraseFromBlock();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700156}