Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 1 | //===- LoopUnrollAndJam.cpp - Code to perform loop unroll and jam ---------===// |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 2 | // |
| 3 | // Copyright 2019 The MLIR Authors. |
| 4 | // |
| 5 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 6 | // you may not use this file except in compliance with the License. |
| 7 | // You may obtain a copy of the License at |
| 8 | // |
| 9 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 10 | // |
| 11 | // Unless required by applicable law or agreed to in writing, software |
| 12 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 13 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 14 | // See the License for the specific language governing permissions and |
| 15 | // limitations under the License. |
| 16 | // ============================================================================= |
| 17 | // |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 18 | // This file implements loop unroll and jam for MLFunctions. Unroll and jam is a |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 19 | // transformation that improves locality, in particular, register reuse, while |
| 20 | // also improving instruction level parallelism. The example below shows what it |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 21 | // does in nearly the general case. Loop unroll and jam currently works if the |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 22 | // bounds of the loops inner to the loop being unroll-jammed do not depend on |
| 23 | // the latter. |
| 24 | // |
// Before      After unroll and jam of i by factor 2:
//
//             for i, step = 2
// for i         S1(i);
//   S1;         S2(i);
//   S2;         S1(i+1);
//   for j       S2(i+1);
//     S3;       for j
//     S4;         S3(i, j);
//   S5;           S4(i, j);
//   S6;           S3(i+1, j)
//                 S4(i+1, j)
//               S5(i);
//               S6(i);
//               S5(i+1);
//               S6(i+1);
| 41 | // |
| 42 | // Note: 'if/else' blocks are not jammed. So, if there are loops inside if |
| 43 | // stmt's, bodies of those loops will not be jammed. |
| 44 | // |
| 45 | //===----------------------------------------------------------------------===// |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 46 | #include "mlir/Transforms/Passes.h" |
| 47 | |
| 48 | #include "mlir/Analysis/LoopAnalysis.h" |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 49 | #include "mlir/IR/AffineExpr.h" |
| 50 | #include "mlir/IR/Builders.h" |
| 51 | #include "mlir/IR/StandardOps.h" |
| 52 | #include "mlir/IR/StmtVisitor.h" |
| 53 | #include "mlir/Transforms/Pass.h" |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 54 | #include "llvm/ADT/DenseMap.h" |
| 55 | #include "llvm/Support/CommandLine.h" |
| 56 | |
| 57 | using namespace mlir; |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 58 | |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 59 | // Loop unroll and jam factor. |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 60 | static llvm::cl::opt<unsigned> |
| 61 | clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden, |
| 62 | llvm::cl::desc("Use this unroll jam factor for all loops" |
| 63 | " (default 4)")); |
| 64 | |
| 65 | namespace { |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 66 | /// Loop unroll jam pass. Currently, this just unroll jams the first |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 67 | /// outer loop in an MLFunction. |
| 68 | struct LoopUnrollAndJam : public MLFunctionPass { |
| 69 | Optional<unsigned> unrollJamFactor; |
| 70 | static const unsigned kDefaultUnrollJamFactor = 4; |
| 71 | |
| 72 | explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor) |
| 73 | : unrollJamFactor(unrollJamFactor) {} |
| 74 | |
| 75 | void runOnMLFunction(MLFunction *f) override; |
| 76 | bool runOnForStmt(ForStmt *forStmt); |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 77 | bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 78 | }; |
| 79 | } // end anonymous namespace |
| 80 | |
| 81 | MLFunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { |
| 82 | return new LoopUnrollAndJam( |
| 83 | unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor)); |
| 84 | } |
| 85 | |
| 86 | void LoopUnrollAndJam::runOnMLFunction(MLFunction *f) { |
| 87 | // Currently, just the outermost loop from the first loop nest is |
| 88 | // unroll-and-jammed by this pass. However, runOnForStmt can be called on any |
| 89 | // for Stmt. |
| 90 | if (!isa<ForStmt>(f->begin())) |
| 91 | return; |
| 92 | |
| 93 | auto *forStmt = cast<ForStmt>(f->begin()); |
| 94 | runOnForStmt(forStmt); |
| 95 | } |
| 96 | |
| 97 | /// Unroll and jam a 'for' stmt. Default unroll jam factor is |
| 98 | /// kDefaultUnrollJamFactor. Return false if nothing was done. |
| 99 | bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) { |
| 100 | // Unroll and jam by the factor that was passed if any. |
| 101 | if (unrollJamFactor.hasValue()) |
| 102 | return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue()); |
| 103 | // Otherwise, unroll jam by the command-line factor if one was specified. |
| 104 | if (clUnrollJamFactor.getNumOccurrences() > 0) |
| 105 | return loopUnrollJamByFactor(forStmt, clUnrollJamFactor); |
| 106 | |
| 107 | // Unroll and jam by four otherwise. |
| 108 | return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor); |
| 109 | } |
| 110 | |
| 111 | /// Unrolls and jams this loop by the specified factor. |
| 112 | bool LoopUnrollAndJam::loopUnrollJamByFactor(ForStmt *forStmt, |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 113 | uint64_t unrollJamFactor) { |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 114 | // Gathers all maximal sub-blocks of statements that do not themselves include |
| 115 | // a for stmt (a statement could have a descendant for stmt though in its |
| 116 | // tree). |
| 117 | class JamBlockGatherer : public StmtWalker<JamBlockGatherer> { |
| 118 | public: |
| 119 | typedef llvm::iplist<Statement> StmtListType; |
| 120 | |
| 121 | // Store iterators to the first and last stmt of each sub-block found. |
| 122 | std::vector<std::pair<StmtBlock::iterator, StmtBlock::iterator>> subBlocks; |
| 123 | |
| 124 | // This is a linear time walk. |
| 125 | void walk(StmtListType::iterator Start, StmtListType::iterator End) { |
| 126 | for (auto it = Start; it != End;) { |
| 127 | auto subBlockStart = it; |
| 128 | while (it != End && !isa<ForStmt>(it)) |
| 129 | ++it; |
| 130 | if (it != subBlockStart) |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 131 | subBlocks.push_back({subBlockStart, std::prev(it)}); |
| 132 | // Process all for stmts that appear next. |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 133 | while (it != End && isa<ForStmt>(it)) |
| 134 | walkForStmt(cast<ForStmt>(it++)); |
| 135 | } |
| 136 | } |
| 137 | }; |
| 138 | |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 139 | assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 140 | |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 141 | if (unrollJamFactor == 1 || forStmt->getStatements().empty()) |
| 142 | return false; |
| 143 | |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 144 | Optional<uint64_t> mayTripCount = getConstantTripCount(*forStmt).getValue(); |
| 145 | |
| 146 | if (!mayTripCount.hasValue()) |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 147 | return false; |
| 148 | |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 149 | uint64_t tripCount = mayTripCount.getValue(); |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 150 | int64_t lb = forStmt->getConstantLowerBound(); |
| 151 | int64_t step = forStmt->getStep(); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 152 | |
| 153 | // If the trip count is lower than the unroll jam factor, no unrolled body. |
| 154 | // TODO(bondhugula): option to specify cleanup loop unrolling. |
| 155 | if (tripCount < unrollJamFactor) |
| 156 | return true; |
| 157 | |
| 158 | // Gather all sub-blocks to jam upon the loop being unrolled. |
| 159 | JamBlockGatherer jbg; |
| 160 | jbg.walkForStmt(forStmt); |
| 161 | auto &subBlocks = jbg.subBlocks; |
| 162 | |
| 163 | // Generate the cleanup loop if trip count isn't a multiple of |
| 164 | // unrollJamFactor. |
| 165 | if (tripCount % unrollJamFactor) { |
| 166 | DenseMap<const MLValue *, MLValue *> operandMap; |
| 167 | // Insert the cleanup loop right after 'forStmt'. |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 168 | MLFuncBuilder builder(forStmt->getBlock(), |
| 169 | std::next(StmtBlock::iterator(forStmt))); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 170 | auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); |
| 171 | cleanupForStmt->setConstantLowerBound( |
| 172 | lb + (tripCount - tripCount % unrollJamFactor) * step); |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 173 | |
| 174 | // Promote the loop body up if this has turned into a single iteration loop. |
| 175 | promoteIfSingleIteration(cleanupForStmt); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 176 | } |
| 177 | |
| 178 | MLFuncBuilder b(forStmt); |
| 179 | forStmt->setStep(step * unrollJamFactor); |
| 180 | forStmt->setConstantUpperBound( |
| 181 | lb + (tripCount - tripCount % unrollJamFactor - 1) * step); |
| 182 | |
| 183 | for (auto &subBlock : subBlocks) { |
| 184 | // Builder to insert unroll-jammed bodies. Insert right at the end of |
| 185 | // sub-block. |
| 186 | MLFuncBuilder builder(subBlock.first->getBlock(), |
| 187 | std::next(subBlock.second)); |
| 188 | |
| 189 | // Unroll and jam (appends unrollJamFactor-1 additional copies). |
| 190 | for (unsigned i = 1; i < unrollJamFactor; i++) { |
| 191 | DenseMap<const MLValue *, MLValue *> operandMapping; |
| 192 | |
| 193 | // If the induction variable is used, create a remapping to the value for |
| 194 | // this unrolled instance. |
| 195 | if (!forStmt->use_empty()) { |
| 196 | // iv' = iv + i, i = 1 to unrollJamFactor-1. |
| 197 | auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0), |
| 198 | builder.getConstantExpr(i * step)); |
| 199 | auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {}); |
| 200 | auto *ivUnroll = |
| 201 | builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) |
| 202 | ->getResult(0); |
| 203 | operandMapping[forStmt] = cast<MLValue>(ivUnroll); |
| 204 | } |
Uday Bondhugula | cf4f4c4 | 2018-09-12 10:21:23 -0700 | [diff] [blame^] | 205 | // Clone the sub-block being unroll-jammed. |
| 206 | for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) { |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 207 | builder.clone(*it, operandMapping); |
| 208 | } |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 209 | } |
| 210 | } |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 211 | |
| 212 | // Promote the loop body up if this has turned into a single iteration loop. |
| 213 | promoteIfSingleIteration(forStmt); |
| 214 | |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 215 | return true; |
| 216 | } |