blob: c631bda886023678ff427e0c3fc5e9776262a9e4 [file] [log] [blame]
Chris Lattneree0c2ae2018-07-29 12:37:35 -07001//===- Unroll.cpp - Code to perform loop unrolling ------------------------===//
Uday Bondhugula0b4059b2018-07-24 20:01:16 -07002//
3// Copyright 2019 The MLIR Authors.
4//
5// Licensed under the Apache License, Version 2.0 (the "License");
6// you may not use this file except in compliance with the License.
7// You may obtain a copy of the License at
8//
9// http://www.apache.org/licenses/LICENSE-2.0
10//
11// Unless required by applicable law or agreed to in writing, software
12// distributed under the License is distributed on an "AS IS" BASIS,
13// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14// See the License for the specific language governing permissions and
15// limitations under the License.
16// =============================================================================
17//
18// This file implements loop unrolling.
19//
20//===----------------------------------------------------------------------===//
21
22#include "mlir/IR/Builders.h"
23#include "mlir/IR/CFGFunction.h"
24#include "mlir/IR/MLFunction.h"
25#include "mlir/IR/Module.h"
26#include "mlir/IR/OperationSet.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070027#include "mlir/IR/Pass.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070028#include "mlir/IR/Statements.h"
29#include "mlir/IR/StmtVisitor.h"
Chris Lattneree0c2ae2018-07-29 12:37:35 -070030#include "mlir/Transforms/Passes.h"
Uday Bondhugula081d9e72018-07-27 10:58:14 -070031#include "llvm/Support/raw_ostream.h"
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070032
33using namespace mlir;
34
35namespace {
36struct LoopUnroll : public MLFunctionPass {
Chris Lattneree0c2ae2018-07-29 12:37:35 -070037 void runOnMLFunction(MLFunction *f) override;
38 void runOnForStmt(ForStmt *forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070039};
Chris Lattneree0c2ae2018-07-29 12:37:35 -070040} // end anonymous namespace
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070041
42MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
43
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070044/// Unrolls all the innermost loops of this MLFunction.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070045void LoopUnroll::runOnMLFunction(MLFunction *f) {
Uday Bondhugula081d9e72018-07-27 10:58:14 -070046 // Gathers all innermost loops through a post order pruned walk.
47 // TODO: figure out the right reusable template here to better refactor code.
48 class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
49 public:
50 // Store innermost loops as we walk.
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070051 std::vector<ForStmt *> loops;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070052
53 // This method specialized to encode custom return logic.
54 typedef llvm::iplist<Statement> StmtListType;
55 bool walk(StmtListType::iterator Start, StmtListType::iterator End) {
56 while (Start != End)
57 if (walk(&(*Start++)))
58 return true;
59 return false;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070060 }
Uday Bondhugula081d9e72018-07-27 10:58:14 -070061
62 // FIXME: can't use base class method for this because that in turn would
63 // need to use the derived class method above. CRTP doesn't allow it, and
64 // the compiler error resulting from it is also very misleading!
65 void walkMLFunction(MLFunction *f) { walk(f->begin(), f->end()); }
66
67 bool walkForStmt(ForStmt *forStmt) {
68 bool hasInnerLoops = walk(forStmt->begin(), forStmt->end());
69 if (!hasInnerLoops)
70 loops.push_back(forStmt);
71 return true;
72 }
73
74 bool walkIfStmt(IfStmt *ifStmt) {
75 if (walk(ifStmt->getThenClause()->begin(),
76 ifStmt->getThenClause()->end()) ||
77 walk(ifStmt->getElseClause()->begin(),
78 ifStmt->getElseClause()->end()))
79 return true;
80 return false;
81 }
82
83 bool walkOpStmt(OperationStmt *opStmt) { return false; }
84
85 using StmtWalker<InnermostLoopGatherer, bool>::walk;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070086 };
87
88 InnermostLoopGatherer ilg;
Uday Bondhugula081d9e72018-07-27 10:58:14 -070089 ilg.walkMLFunction(f);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070090 auto &loops = ilg.loops;
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070091 for (auto *forStmt : loops)
Chris Lattneree0c2ae2018-07-29 12:37:35 -070092 runOnForStmt(forStmt);
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070093}
94
95/// Unrolls this loop completely. Returns true if the unrolling happens.
Chris Lattneree0c2ae2018-07-29 12:37:35 -070096void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
Uday Bondhugula0b4059b2018-07-24 20:01:16 -070097 auto lb = forStmt->getLowerBound()->getValue();
98 auto ub = forStmt->getUpperBound()->getValue();
99 auto step = forStmt->getStep()->getValue();
100 auto trip_count = (ub - lb + 1) / step;
101
102 auto *block = forStmt->getBlock();
103
104 MLFuncBuilder builder(forStmt->Statement::getFunction());
105 builder.setInsertionPoint(block);
106
107 for (int i = 0; i < trip_count; i++) {
108 for (auto &stmt : forStmt->getStatements()) {
109 switch (stmt.getKind()) {
110 case Statement::Kind::For:
111 llvm_unreachable("unrolling loops that have only operations");
112 break;
113 case Statement::Kind::If:
114 llvm_unreachable("unrolling loops that have only operations");
115 break;
116 case Statement::Kind::Operation:
117 auto *op = cast<OperationStmt>(&stmt);
Tatiana Shpeisman60bf7be2018-07-26 18:09:20 -0700118 // TODO: clone operands and result types.
119 builder.createOperation(op->getName(), /*operands*/ {},
120 /*resultTypes*/ {}, op->getAttrs());
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700121 // TODO: loop iterator parsing not yet implemented; replace loop
122 // iterator uses in unrolled body appropriately.
123 break;
124 }
125 }
126 }
127
128 forStmt->eraseFromBlock();
Uday Bondhugula0b4059b2018-07-24 20:01:16 -0700129}