Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 1 | //===- LoopUnrollAndJam.cpp - Code to perform loop unroll jam |
| 2 | //----------------===// |
| 3 | // |
| 4 | // Copyright 2019 The MLIR Authors. |
| 5 | // |
| 6 | // Licensed under the Apache License, Version 2.0 (the "License"); |
| 7 | // you may not use this file except in compliance with the License. |
| 8 | // You may obtain a copy of the License at |
| 9 | // |
| 10 | // http://www.apache.org/licenses/LICENSE-2.0 |
| 11 | // |
| 12 | // Unless required by applicable law or agreed to in writing, software |
| 13 | // distributed under the License is distributed on an "AS IS" BASIS, |
| 14 | // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 15 | // See the License for the specific language governing permissions and |
| 16 | // limitations under the License. |
| 17 | // ============================================================================= |
| 18 | // |
| 19 | // This file implements loop unroll jam for MLFunctions. Unroll and jam is a |
| 20 | // transformation that improves locality, in particular, register reuse, while |
| 21 | // also improving instruction level parallelism. The example below shows what it |
| 22 | // does in nearly the general case. Loop unroll jam currently works if the |
| 23 | // bounds of the loops inner to the loop being unroll-jammed do not depend on |
| 24 | // the latter. |
| 25 | // |
// Before                      After unroll and jam of i by factor 2:
//
//                             for i, step = 2
// for i                         S1(i);
//   S1;                         S2(i);
//   S2;                         S1(i+1);
//   for j                       S2(i+1);
//     S3;                       for j
//     S4;                         S3(i, j);
//   S5;                           S4(i, j);
//   S6;                           S3(i+1, j)
//                                 S4(i+1, j)
//                               S5(i);
//                               S6(i);
//                               S5(i+1);
//                               S6(i+1);
| 42 | // |
| 43 | // Note: 'if/else' blocks are not jammed. So, if there are loops inside if |
| 44 | // stmt's, bodies of those loops will not be jammed. |
| 45 | // |
| 46 | //===----------------------------------------------------------------------===// |
| 47 | #include "mlir/IR/AffineExpr.h" |
| 48 | #include "mlir/IR/Builders.h" |
| 49 | #include "mlir/IR/StandardOps.h" |
| 50 | #include "mlir/IR/StmtVisitor.h" |
| 51 | #include "mlir/Transforms/Pass.h" |
| 52 | #include "mlir/Transforms/Passes.h" |
| 53 | #include "llvm/ADT/DenseMap.h" |
| 54 | #include "llvm/Support/CommandLine.h" |
| 55 | |
| 56 | using namespace mlir; |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 57 | |
| 58 | // Loop unroll jam factor. |
| 59 | static llvm::cl::opt<unsigned> |
| 60 | clUnrollJamFactor("unroll-jam-factor", llvm::cl::Hidden, |
| 61 | llvm::cl::desc("Use this unroll jam factor for all loops" |
| 62 | " (default 4)")); |
| 63 | |
| 64 | namespace { |
| 65 | /// Loop unroll jam pass. For test purposes, this just unroll jams the first |
| 66 | /// outer loop in an MLFunction. |
| 67 | struct LoopUnrollAndJam : public MLFunctionPass { |
| 68 | Optional<unsigned> unrollJamFactor; |
| 69 | static const unsigned kDefaultUnrollJamFactor = 4; |
| 70 | |
| 71 | explicit LoopUnrollAndJam(Optional<unsigned> unrollJamFactor) |
| 72 | : unrollJamFactor(unrollJamFactor) {} |
| 73 | |
| 74 | void runOnMLFunction(MLFunction *f) override; |
| 75 | bool runOnForStmt(ForStmt *forStmt); |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 76 | bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 77 | }; |
| 78 | } // end anonymous namespace |
| 79 | |
| 80 | MLFunctionPass *mlir::createLoopUnrollAndJamPass(int unrollJamFactor) { |
| 81 | return new LoopUnrollAndJam( |
| 82 | unrollJamFactor == -1 ? None : Optional<unsigned>(unrollJamFactor)); |
| 83 | } |
| 84 | |
| 85 | void LoopUnrollAndJam::runOnMLFunction(MLFunction *f) { |
| 86 | // Currently, just the outermost loop from the first loop nest is |
| 87 | // unroll-and-jammed by this pass. However, runOnForStmt can be called on any |
| 88 | // for Stmt. |
| 89 | if (!isa<ForStmt>(f->begin())) |
| 90 | return; |
| 91 | |
| 92 | auto *forStmt = cast<ForStmt>(f->begin()); |
| 93 | runOnForStmt(forStmt); |
| 94 | } |
| 95 | |
| 96 | /// Unroll and jam a 'for' stmt. Default unroll jam factor is |
| 97 | /// kDefaultUnrollJamFactor. Return false if nothing was done. |
| 98 | bool LoopUnrollAndJam::runOnForStmt(ForStmt *forStmt) { |
| 99 | // Unroll and jam by the factor that was passed if any. |
| 100 | if (unrollJamFactor.hasValue()) |
| 101 | return loopUnrollJamByFactor(forStmt, unrollJamFactor.getValue()); |
| 102 | // Otherwise, unroll jam by the command-line factor if one was specified. |
| 103 | if (clUnrollJamFactor.getNumOccurrences() > 0) |
| 104 | return loopUnrollJamByFactor(forStmt, clUnrollJamFactor); |
| 105 | |
| 106 | // Unroll and jam by four otherwise. |
| 107 | return loopUnrollJamByFactor(forStmt, kDefaultUnrollJamFactor); |
| 108 | } |
| 109 | |
| 110 | /// Unrolls and jams this loop by the specified factor. |
| 111 | bool LoopUnrollAndJam::loopUnrollJamByFactor(ForStmt *forStmt, |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 112 | uint64_t unrollJamFactor) { |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 113 | // Gathers all maximal sub-blocks of statements that do not themselves include |
| 114 | // a for stmt (a statement could have a descendant for stmt though in its |
| 115 | // tree). |
| 116 | class JamBlockGatherer : public StmtWalker<JamBlockGatherer> { |
| 117 | public: |
| 118 | typedef llvm::iplist<Statement> StmtListType; |
| 119 | |
| 120 | // Store iterators to the first and last stmt of each sub-block found. |
| 121 | std::vector<std::pair<StmtBlock::iterator, StmtBlock::iterator>> subBlocks; |
| 122 | |
| 123 | // This is a linear time walk. |
| 124 | void walk(StmtListType::iterator Start, StmtListType::iterator End) { |
| 125 | for (auto it = Start; it != End;) { |
| 126 | auto subBlockStart = it; |
| 127 | while (it != End && !isa<ForStmt>(it)) |
| 128 | ++it; |
| 129 | if (it != subBlockStart) |
| 130 | // Record the last statement (one behind the iterator) while not |
| 131 | // changing the iterator position. |
| 132 | subBlocks.push_back({subBlockStart, (--it)++}); |
| 133 | // Process all for Stmts that appear next. |
| 134 | while (it != End && isa<ForStmt>(it)) |
| 135 | walkForStmt(cast<ForStmt>(it++)); |
| 136 | } |
| 137 | } |
| 138 | }; |
| 139 | |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 140 | assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1"); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 141 | |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 142 | if (unrollJamFactor == 1 || forStmt->getStatements().empty()) |
| 143 | return false; |
| 144 | |
| 145 | if (!forStmt->hasConstantBounds()) |
| 146 | return false; |
| 147 | |
| 148 | int64_t lb = forStmt->getConstantLowerBound(); |
| 149 | int64_t step = forStmt->getStep(); |
| 150 | uint64_t tripCount = forStmt->getConstantTripCount().getValue(); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 151 | |
| 152 | // If the trip count is lower than the unroll jam factor, no unrolled body. |
| 153 | // TODO(bondhugula): option to specify cleanup loop unrolling. |
| 154 | if (tripCount < unrollJamFactor) |
| 155 | return true; |
| 156 | |
| 157 | // Gather all sub-blocks to jam upon the loop being unrolled. |
| 158 | JamBlockGatherer jbg; |
| 159 | jbg.walkForStmt(forStmt); |
| 160 | auto &subBlocks = jbg.subBlocks; |
| 161 | |
| 162 | // Generate the cleanup loop if trip count isn't a multiple of |
| 163 | // unrollJamFactor. |
| 164 | if (tripCount % unrollJamFactor) { |
| 165 | DenseMap<const MLValue *, MLValue *> operandMap; |
| 166 | // Insert the cleanup loop right after 'forStmt'. |
| 167 | MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt)); |
| 168 | auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap)); |
| 169 | cleanupForStmt->setConstantLowerBound( |
| 170 | lb + (tripCount - tripCount % unrollJamFactor) * step); |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 171 | |
| 172 | // Promote the loop body up if this has turned into a single iteration loop. |
| 173 | promoteIfSingleIteration(cleanupForStmt); |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 174 | } |
| 175 | |
| 176 | MLFuncBuilder b(forStmt); |
| 177 | forStmt->setStep(step * unrollJamFactor); |
| 178 | forStmt->setConstantUpperBound( |
| 179 | lb + (tripCount - tripCount % unrollJamFactor - 1) * step); |
| 180 | |
| 181 | for (auto &subBlock : subBlocks) { |
| 182 | // Builder to insert unroll-jammed bodies. Insert right at the end of |
| 183 | // sub-block. |
| 184 | MLFuncBuilder builder(subBlock.first->getBlock(), |
| 185 | std::next(subBlock.second)); |
| 186 | |
| 187 | // Unroll and jam (appends unrollJamFactor-1 additional copies). |
| 188 | for (unsigned i = 1; i < unrollJamFactor; i++) { |
| 189 | DenseMap<const MLValue *, MLValue *> operandMapping; |
| 190 | |
| 191 | // If the induction variable is used, create a remapping to the value for |
| 192 | // this unrolled instance. |
| 193 | if (!forStmt->use_empty()) { |
| 194 | // iv' = iv + i, i = 1 to unrollJamFactor-1. |
| 195 | auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0), |
| 196 | builder.getConstantExpr(i * step)); |
| 197 | auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {}); |
| 198 | auto *ivUnroll = |
| 199 | builder.create<AffineApplyOp>(forStmt->getLoc(), bumpMap, forStmt) |
| 200 | ->getResult(0); |
| 201 | operandMapping[forStmt] = cast<MLValue>(ivUnroll); |
| 202 | } |
| 203 | // Clone the sub-block being unroll-jammed (this doesn't include the last |
| 204 | // stmt because subBlock.second is inclusive). |
| 205 | for (auto it = subBlock.first; it != subBlock.second; ++it) { |
| 206 | builder.clone(*it, operandMapping); |
| 207 | } |
| 208 | // Clone the last statement of the sub-block. |
| 209 | builder.clone(*subBlock.second, operandMapping); |
| 210 | } |
| 211 | } |
Uday Bondhugula | 832b17a | 2018-09-07 14:47:21 -0700 | [diff] [blame] | 212 | |
| 213 | // Promote the loop body up if this has turned into a single iteration loop. |
| 214 | promoteIfSingleIteration(forStmt); |
| 215 | |
Uday Bondhugula | 6cd3502 | 2018-08-28 18:24:27 -0700 | [diff] [blame] | 216 | return true; |
| 217 | } |