blob: 9d7e7cddb05e8b8f5297a4c2a1c7e17b77072c8a [file] [log] [blame]
Uday Bondhugula041817a2018-09-28 12:17:26 -07001//===- LoopUtils.cpp ---- Misc utilities for loop transformation ----------===//
Uday Bondhugula3bae0412018-09-07 14:47:21 -07002//
Mehdi Amini56222a02019-12-23 09:35:36 -08003// Part of the MLIR Project, under the Apache License v2.0 with LLVM Exceptions.
4// See https://llvm.org/LICENSE.txt for license information.
5// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
Uday Bondhugula3bae0412018-09-07 14:47:21 -07006//
Mehdi Amini56222a02019-12-23 09:35:36 -08007//===----------------------------------------------------------------------===//
Uday Bondhugula3bae0412018-09-07 14:47:21 -07008//
Uday Bondhugula041817a2018-09-28 12:17:26 -07009// This file implements miscellaneous loop transformation routines.
Uday Bondhugula3bae0412018-09-07 14:47:21 -070010//
11//===----------------------------------------------------------------------===//
12
Uday Bondhugulaab479722018-09-18 10:22:03 -070013#include "mlir/Transforms/LoopUtils.h"
Uday Bondhugula64812a52018-09-12 10:21:23 -070014
MLIR Team8f5f2c72019-02-15 09:32:18 -080015#include "mlir/Analysis/AffineAnalysis.h"
Uday Bondhugula64812a52018-09-12 10:21:23 -070016#include "mlir/Analysis/LoopAnalysis.h"
Nicolas Vasilache48a1bae2019-07-22 04:30:50 -070017#include "mlir/Analysis/SliceAnalysis.h"
Uday Bondhugula4f32ae62019-09-14 13:23:18 -070018#include "mlir/Analysis/Utils.h"
River Riddleffde9752019-08-20 15:36:08 -070019#include "mlir/Dialect/AffineOps/AffineOps.h"
Nicolas Vasilachecca53e82019-07-15 02:50:09 -070020#include "mlir/Dialect/LoopOps/LoopOps.h"
Uday Bondhugulaab479722018-09-18 10:22:03 -070021#include "mlir/IR/AffineMap.h"
River Riddle451869f2019-01-24 12:25:30 -080022#include "mlir/IR/BlockAndValueMapping.h"
River Riddle9dbef0b2019-07-11 11:41:04 -070023#include "mlir/IR/Function.h"
Alex Zinenkofc044e82019-07-15 06:40:07 -070024#include "mlir/Transforms/RegionUtils.h"
Uday Bondhugula4f32ae62019-09-14 13:23:18 -070025#include "mlir/Transforms/Utils.h"
Uday Bondhugula041817a2018-09-28 12:17:26 -070026#include "llvm/ADT/DenseMap.h"
Uday Bondhugula4f32ae62019-09-14 13:23:18 -070027#include "llvm/ADT/MapVector.h"
Nicolas Vasilache48a1bae2019-07-22 04:30:50 -070028#include "llvm/ADT/SetVector.h"
Alex Zinenkofc044e82019-07-15 06:40:07 -070029#include "llvm/ADT/SmallPtrSet.h"
Uday Bondhugulaccfe5932018-10-22 13:44:31 -070030#include "llvm/Support/Debug.h"
Uday Bondhugula4f32ae62019-09-14 13:23:18 -070031#include "llvm/Support/raw_ostream.h"
Alex Zinenkofc044e82019-07-15 06:40:07 -070032
Uday Bondhugulaccfe5932018-10-22 13:44:31 -070033#define DEBUG_TYPE "LoopUtils"
Uday Bondhugula3bae0412018-09-07 14:47:21 -070034
Uday Bondhugulaab479722018-09-18 10:22:03 -070035using namespace mlir;
Nicolas Vasilache48a1bae2019-07-22 04:30:50 -070036using llvm::SetVector;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -070037using llvm::SmallMapVector;
Uday Bondhugulaab479722018-09-18 10:22:03 -070038
Uday Bondhugula075090f2019-03-12 08:00:52 -070039/// Computes the cleanup loop lower bound of the loop being unrolled with
40/// the specified unroll factor; this bound will also be upper bound of the main
41/// part of the unrolled loop. Computes the bound as an AffineMap with its
42/// operands or a null map when the trip count can't be expressed as an affine
43/// expression.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -070044void mlir::getCleanupLoopLowerBound(AffineForOp forOp, unsigned unrollFactor,
45 AffineMap *map,
River Riddlee62a6952019-12-23 14:45:01 -080046 SmallVectorImpl<Value> *operands,
Nicolas Vasilache08047502019-06-20 15:10:35 -070047 OpBuilder &b) {
River Riddleaf1abcc2019-03-25 11:13:31 -070048 auto lbMap = forOp.getLowerBoundMap();
Uday Bondhugulaab479722018-09-18 10:22:03 -070049
50 // Single result lower bound map only.
Uday Bondhugula075090f2019-03-12 08:00:52 -070051 if (lbMap.getNumResults() != 1) {
52 *map = AffineMap();
53 return;
54 }
Uday Bondhugulaab479722018-09-18 10:22:03 -070055
Uday Bondhugula075090f2019-03-12 08:00:52 -070056 AffineMap tripCountMap;
River Riddlee62a6952019-12-23 14:45:01 -080057 SmallVector<Value, 4> tripCountOperands;
Uday Bondhugula075090f2019-03-12 08:00:52 -070058 buildTripCountMapAndOperands(forOp, &tripCountMap, &tripCountOperands);
Uday Bondhugulaab479722018-09-18 10:22:03 -070059
60 // Sometimes the trip count cannot be expressed as an affine expression.
Uday Bondhugula075090f2019-03-12 08:00:52 -070061 if (!tripCountMap) {
62 *map = AffineMap();
63 return;
64 }
Uday Bondhugulaab479722018-09-18 10:22:03 -070065
River Riddleaf1abcc2019-03-25 11:13:31 -070066 unsigned step = forOp.getStep();
River Riddled6ee6a02019-12-07 10:35:01 -080067 auto lb = b.create<AffineApplyOp>(forOp.getLoc(), lbMap,
68 forOp.getLowerBoundOperands());
Uday Bondhugula075090f2019-03-12 08:00:52 -070069
70 // For each upper bound expr, get the range.
River Riddle832567b2019-03-25 10:14:34 -070071 // Eg: affine.for %i = lb to min (ub1, ub2),
Uday Bondhugula075090f2019-03-12 08:00:52 -070072 // where tripCountExprs yield (tr1, tr2), we create affine.apply's:
73 // lb + tr1 - tr1 % ufactor, lb + tr2 - tr2 % ufactor; the results of all
74 // these affine.apply's make up the cleanup loop lower bound.
75 SmallVector<AffineExpr, 4> bumpExprs(tripCountMap.getNumResults());
River Riddlee62a6952019-12-23 14:45:01 -080076 SmallVector<Value, 4> bumpValues(tripCountMap.getNumResults());
Uday Bondhugula075090f2019-03-12 08:00:52 -070077 for (unsigned i = 0, e = tripCountMap.getNumResults(); i < e; i++) {
78 auto tripCountExpr = tripCountMap.getResult(i);
79 bumpExprs[i] = (tripCountExpr - tripCountExpr % unrollFactor) * step;
River Riddle2acc2202019-10-17 20:08:01 -070080 auto bumpMap = AffineMap::get(tripCountMap.getNumDims(),
Nicolas Vasilache08047502019-06-20 15:10:35 -070081 tripCountMap.getNumSymbols(), bumpExprs[i]);
Uday Bondhugula075090f2019-03-12 08:00:52 -070082 bumpValues[i] =
Nicolas Vasilache08047502019-06-20 15:10:35 -070083 b.create<AffineApplyOp>(forOp.getLoc(), bumpMap, tripCountOperands);
Uday Bondhugula075090f2019-03-12 08:00:52 -070084 }
85
86 SmallVector<AffineExpr, 4> newUbExprs(tripCountMap.getNumResults());
87 for (unsigned i = 0, e = bumpExprs.size(); i < e; i++)
Nicolas Vasilache08047502019-06-20 15:10:35 -070088 newUbExprs[i] = b.getAffineDimExpr(0) + b.getAffineDimExpr(i + 1);
Uday Bondhugula075090f2019-03-12 08:00:52 -070089
90 operands->clear();
91 operands->push_back(lb);
92 operands->append(bumpValues.begin(), bumpValues.end());
River Riddle2acc2202019-10-17 20:08:01 -070093 *map = AffineMap::get(1 + tripCountMap.getNumResults(), 0, newUbExprs);
Uday Bondhugula075090f2019-03-12 08:00:52 -070094 // Simplify the map + operands.
95 fullyComposeAffineMapAndOperands(map, operands);
96 *map = simplifyAffineMap(*map);
97 canonicalizeMapAndOperands(map, operands);
98 // Remove any affine.apply's that became dead from the simplification above.
River Riddle35807bc2019-12-22 21:59:55 -080099 for (auto v : bumpValues) {
River Riddle2bdf33c2020-01-11 08:54:04 -0800100 if (v.use_empty())
101 v.getDefiningOp()->erase();
Uday Bondhugula075090f2019-03-12 08:00:52 -0700102 }
River Riddleaf1abcc2019-03-25 11:13:31 -0700103 if (lb.use_empty())
104 lb.erase();
Uday Bondhugulaab479722018-09-18 10:22:03 -0700105}
106
River Riddle5052bd82019-02-01 16:42:18 -0800107/// Promotes the loop body of a forOp to its containing block if the forOp
River Riddleba6fdc82019-03-06 17:37:14 -0800108/// was known to have a single iteration.
Uday Bondhugula64812a52018-09-12 10:21:23 -0700109// TODO(bondhugula): extend this for arbitrary affine bounds.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700110LogicalResult mlir::promoteIfSingleIteration(AffineForOp forOp) {
River Riddle5052bd82019-02-01 16:42:18 -0800111 Optional<uint64_t> tripCount = getConstantTripCount(forOp);
Uday Bondhugulaab479722018-09-18 10:22:03 -0700112 if (!tripCount.hasValue() || tripCount.getValue() != 1)
River Riddle0310d492019-03-10 15:32:54 -0700113 return failure();
Uday Bondhugula3bae0412018-09-07 14:47:21 -0700114
Uday Bondhugulaab479722018-09-18 10:22:03 -0700115 // TODO(mlir-team): there is no builder for a max.
River Riddleaf1abcc2019-03-25 11:13:31 -0700116 if (forOp.getLowerBoundMap().getNumResults() != 1)
River Riddle0310d492019-03-10 15:32:54 -0700117 return failure();
Uday Bondhugula3bae0412018-09-07 14:47:21 -0700118
119 // Replaces all IV uses to its single iteration value.
River Riddle35807bc2019-12-22 21:59:55 -0800120 auto iv = forOp.getInductionVar();
River Riddle99b87c92019-03-27 14:02:02 -0700121 Operation *op = forOp.getOperation();
River Riddle2bdf33c2020-01-11 08:54:04 -0800122 if (!iv.use_empty()) {
River Riddleaf1abcc2019-03-25 11:13:31 -0700123 if (forOp.hasConstantLowerBound()) {
River Riddlece502af2019-07-08 11:20:26 -0700124 OpBuilder topBuilder(op->getParentOfType<FuncOp>().getBody());
Chris Lattnerd2d89cb2018-10-06 17:21:53 -0700125 auto constOp = topBuilder.create<ConstantIndexOp>(
River Riddleaf1abcc2019-03-25 11:13:31 -0700126 forOp.getLoc(), forOp.getConstantLowerBound());
River Riddle2bdf33c2020-01-11 08:54:04 -0800127 iv.replaceAllUsesWith(constOp);
Uday Bondhugulaab479722018-09-18 10:22:03 -0700128 } else {
River Riddleaf1abcc2019-03-25 11:13:31 -0700129 AffineBound lb = forOp.getLowerBound();
River Riddlee62a6952019-12-23 14:45:01 -0800130 SmallVector<Value, 4> lbOperands(lb.operand_begin(), lb.operand_end());
River Riddlef1b848e2019-06-04 19:18:23 -0700131 OpBuilder builder(op->getBlock(), Block::iterator(op));
Uday Bondhugula94a03f82019-01-22 13:58:52 -0800132 if (lb.getMap() == builder.getDimIdentityMap()) {
River Riddle3227dee2019-02-06 11:08:18 -0800133 // No need of generating an affine.apply.
River Riddle2bdf33c2020-01-11 08:54:04 -0800134 iv.replaceAllUsesWith(lbOperands[0]);
Uday Bondhugula94a03f82019-01-22 13:58:52 -0800135 } else {
136 auto affineApplyOp = builder.create<AffineApplyOp>(
River Riddle99b87c92019-03-27 14:02:02 -0700137 op->getLoc(), lb.getMap(), lbOperands);
River Riddle2bdf33c2020-01-11 08:54:04 -0800138 iv.replaceAllUsesWith(affineApplyOp);
Uday Bondhugula94a03f82019-01-22 13:58:52 -0800139 }
Uday Bondhugulaab479722018-09-18 10:22:03 -0700140 }
141 }
River Riddle99b87c92019-03-27 14:02:02 -0700142 // Move the loop body operations, except for terminator, to the loop's
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700143 // containing block.
River Riddle99b87c92019-03-27 14:02:02 -0700144 auto *block = op->getBlock();
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700145 forOp.getBody()->getOperations().back().erase();
River Riddle99b87c92019-03-27 14:02:02 -0700146 block->getOperations().splice(Block::iterator(op),
River Riddlef9d91532019-03-26 17:05:09 -0700147 forOp.getBody()->getOperations());
River Riddleaf1abcc2019-03-25 11:13:31 -0700148 forOp.erase();
River Riddle0310d492019-03-10 15:32:54 -0700149 return success();
Uday Bondhugula3bae0412018-09-07 14:47:21 -0700150}
151
River Riddle8c443672019-07-09 16:17:55 -0700152/// Promotes all single iteration for op's in the FuncOp, i.e., moves
Chris Lattner315a4662018-12-28 13:07:39 -0800153/// their body into the containing Block.
River Riddle8c443672019-07-09 16:17:55 -0700154void mlir::promoteSingleIterationLoops(FuncOp f) {
Uday Bondhugula3bae0412018-09-07 14:47:21 -0700155 // Gathers all innermost loops through a post order pruned walk.
River Riddle4bfae662019-08-29 13:04:22 -0700156 f.walk([](AffineForOp forOp) { promoteIfSingleIteration(forOp); });
Uday Bondhugula3bae0412018-09-07 14:47:21 -0700157}
Uday Bondhugula041817a2018-09-28 12:17:26 -0700158
River Riddle99b87c92019-03-27 14:02:02 -0700159/// Generates a 'affine.for' op with the specified lower and upper bounds
160/// while generating the right IV remappings for the shifted operations. The
161/// operation blocks that go into the loop are specified in instGroupQueue
Uday Bondhugula041817a2018-09-28 12:17:26 -0700162/// starting from the specified offset, and in that order; the first element of
River Riddle99b87c92019-03-27 14:02:02 -0700163/// the pair specifies the shift applied to that group of operations; note
Chris Lattner456ad6a2018-12-28 16:05:35 -0800164/// that the shift is multiplied by the loop step before being applied. Returns
Uday Bondhugula041817a2018-09-28 12:17:26 -0700165/// nullptr if the generated loop simplifies to a single iteration one.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700166static AffineForOp
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800167generateLoop(AffineMap lbMap, AffineMap ubMap,
River Riddle99b87c92019-03-27 14:02:02 -0700168 const std::vector<std::pair<uint64_t, ArrayRef<Operation *>>>
Chris Lattner456ad6a2018-12-28 16:05:35 -0800169 &instGroupQueue,
Nicolas Vasilache08047502019-06-20 15:10:35 -0700170 unsigned offset, AffineForOp srcForInst, OpBuilder b) {
River Riddlee62a6952019-12-23 14:45:01 -0800171 SmallVector<Value, 4> lbOperands(srcForInst.getLowerBoundOperands());
172 SmallVector<Value, 4> ubOperands(srcForInst.getUpperBoundOperands());
Uday Bondhugula041817a2018-09-28 12:17:26 -0700173
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800174 assert(lbMap.getNumInputs() == lbOperands.size());
175 assert(ubMap.getNumInputs() == ubOperands.size());
176
River Riddle5052bd82019-02-01 16:42:18 -0800177 auto loopChunk =
Nicolas Vasilache08047502019-06-20 15:10:35 -0700178 b.create<AffineForOp>(srcForInst.getLoc(), lbOperands, lbMap, ubOperands,
179 ubMap, srcForInst.getStep());
River Riddle35807bc2019-12-22 21:59:55 -0800180 auto loopChunkIV = loopChunk.getInductionVar();
181 auto srcIV = srcForInst.getInductionVar();
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800182
River Riddle451869f2019-01-24 12:25:30 -0800183 BlockAndValueMapping operandMap;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700184
River Riddlef1b848e2019-06-04 19:18:23 -0700185 OpBuilder bodyBuilder = loopChunk.getBodyBuilder();
Chris Lattner456ad6a2018-12-28 16:05:35 -0800186 for (auto it = instGroupQueue.begin() + offset, e = instGroupQueue.end();
Uday Bondhugula041817a2018-09-28 12:17:26 -0700187 it != e; ++it) {
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800188 uint64_t shift = it->first;
Chris Lattner456ad6a2018-12-28 16:05:35 -0800189 auto insts = it->second;
River Riddle99b87c92019-03-27 14:02:02 -0700190 // All 'same shift' operations get added with their operands being
191 // remapped to results of cloned operations, and their IV used remapped.
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800192 // Generate the remapping if the shift is not zero: remappedIV = newIV -
193 // shift.
River Riddle2bdf33c2020-01-11 08:54:04 -0800194 if (!srcIV.use_empty() && shift != 0) {
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700195 auto ivRemap = bodyBuilder.create<AffineApplyOp>(
River Riddleaf1abcc2019-03-25 11:13:31 -0700196 srcForInst.getLoc(),
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700197 bodyBuilder.getSingleDimShiftAffineMap(
River Riddleaf1abcc2019-03-25 11:13:31 -0700198 -static_cast<int64_t>(srcForInst.getStep() * shift)),
Chris Lattnerb42bea22019-01-27 09:33:19 -0800199 loopChunkIV);
River Riddle36babbd2019-01-26 12:40:12 -0800200 operandMap.map(srcIV, ivRemap);
Uday Bondhugula041817a2018-09-28 12:17:26 -0700201 } else {
River Riddle36babbd2019-01-26 12:40:12 -0800202 operandMap.map(srcIV, loopChunkIV);
Uday Bondhugula041817a2018-09-28 12:17:26 -0700203 }
River Riddle99b87c92019-03-27 14:02:02 -0700204 for (auto *op : insts) {
River Riddled5b60ee2019-05-11 18:59:54 -0700205 if (!isa<AffineTerminatorOp>(op))
River Riddle99b87c92019-03-27 14:02:02 -0700206 bodyBuilder.clone(*op, operandMap);
Uday Bondhugula041817a2018-09-28 12:17:26 -0700207 }
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700208 };
River Riddleba6fdc82019-03-06 17:37:14 -0800209 if (succeeded(promoteIfSingleIteration(loopChunk)))
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700210 return AffineForOp();
Uday Bondhugula041817a2018-09-28 12:17:26 -0700211 return loopChunk;
212}
213
River Riddle99b87c92019-03-27 14:02:02 -0700214/// Skew the operations in the body of a 'affine.for' operation with the
215/// specified operation-wise shifts. The shifts are with respect to the
River Riddle832567b2019-03-25 10:14:34 -0700216/// original execution order, and are multiplied by the loop 'step' before being
River Riddle99b87c92019-03-27 14:02:02 -0700217/// applied. A shift of zero for each operation will lead to no change.
218// The skewing of operations with respect to one another can be used for
Chris Lattner456ad6a2018-12-28 16:05:35 -0800219// example to allow overlap of asynchronous operations (such as DMA
River Riddle99b87c92019-03-27 14:02:02 -0700220// communication) with computation, or just relative shifting of operations
Chris Lattner456ad6a2018-12-28 16:05:35 -0800221// for better register reuse, locality or parallelism. As such, the shifts are
River Riddle99b87c92019-03-27 14:02:02 -0700222// typically expected to be at most of the order of the number of operations.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800223// This method should not be used as a substitute for loop distribution/fission.
River Riddle99b87c92019-03-27 14:02:02 -0700224// This method uses an algorithm// in time linear in the number of operations
Chris Lattner456ad6a2018-12-28 16:05:35 -0800225// in the body of the for loop - (using the 'sweep line' paradigm). This method
Uday Bondhugula041817a2018-09-28 12:17:26 -0700226// asserts preservation of SSA dominance. A check for that as well as that for
Kazuaki Ishizaki8bfedb32019-10-20 00:11:03 -0700227// memory-based dependence preservation check rests with the users of this
Uday Bondhugula041817a2018-09-28 12:17:26 -0700228// method.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700229LogicalResult mlir::instBodySkew(AffineForOp forOp, ArrayRef<uint64_t> shifts,
River Riddle80d35682019-03-08 16:04:42 -0800230 bool unrollPrologueEpilogue) {
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700231 if (forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
River Riddle0310d492019-03-10 15:32:54 -0700232 return success();
Uday Bondhugula041817a2018-09-28 12:17:26 -0700233
234 // If the trip counts aren't constant, we would need versioning and
235 // conditional guards (or context information to prevent such versioning). The
236 // better way to pipeline for such loops is to first tile them and extract
237 // constant trip count "full tiles" before applying this.
River Riddle5052bd82019-02-01 16:42:18 -0800238 auto mayBeConstTripCount = getConstantTripCount(forOp);
Uday Bondhugulaccfe5932018-10-22 13:44:31 -0700239 if (!mayBeConstTripCount.hasValue()) {
River Riddleb14c4b42019-05-01 12:13:44 -0700240 LLVM_DEBUG(forOp.emitRemark("non-constant trip count loop not handled"));
River Riddle0310d492019-03-10 15:32:54 -0700241 return success();
Uday Bondhugulaccfe5932018-10-22 13:44:31 -0700242 }
Uday Bondhugula041817a2018-09-28 12:17:26 -0700243 uint64_t tripCount = mayBeConstTripCount.getValue();
244
River Riddle5052bd82019-02-01 16:42:18 -0800245 assert(isInstwiseShiftValid(forOp, shifts) &&
Uday Bondhugulaccfe5932018-10-22 13:44:31 -0700246 "shifts will lead to an invalid transformation\n");
Uday Bondhugula041817a2018-09-28 12:17:26 -0700247
River Riddleaf1abcc2019-03-25 11:13:31 -0700248 int64_t step = forOp.getStep();
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800249
River Riddlef9d91532019-03-26 17:05:09 -0700250 unsigned numChildInsts = forOp.getBody()->getOperations().size();
Uday Bondhugula041817a2018-09-28 12:17:26 -0700251
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800252 // Do a linear time (counting) sort for the shifts.
253 uint64_t maxShift = 0;
Chris Lattner456ad6a2018-12-28 16:05:35 -0800254 for (unsigned i = 0; i < numChildInsts; i++) {
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800255 maxShift = std::max(maxShift, shifts[i]);
Uday Bondhugula041817a2018-09-28 12:17:26 -0700256 }
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800257 // Such large shifts are not the typical use case.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800258 if (maxShift >= numChildInsts) {
River Riddleaf1abcc2019-03-25 11:13:31 -0700259 forOp.emitWarning("not shifting because shifts are unrealistically large");
River Riddle0310d492019-03-10 15:32:54 -0700260 return success();
Uday Bondhugulaccfe5932018-10-22 13:44:31 -0700261 }
Uday Bondhugula041817a2018-09-28 12:17:26 -0700262
River Riddle99b87c92019-03-27 14:02:02 -0700263 // An array of operation groups sorted by shift amount; each group has all
264 // operations with the same shift in the order in which they appear in the
265 // body of the 'affine.for' op.
266 std::vector<std::vector<Operation *>> sortedInstGroups(maxShift + 1);
Uday Bondhugula041817a2018-09-28 12:17:26 -0700267 unsigned pos = 0;
River Riddle99b87c92019-03-27 14:02:02 -0700268 for (auto &op : *forOp.getBody()) {
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800269 auto shift = shifts[pos++];
River Riddle99b87c92019-03-27 14:02:02 -0700270 sortedInstGroups[shift].push_back(&op);
Uday Bondhugula041817a2018-09-28 12:17:26 -0700271 }
272
273 // Unless the shifts have a specific pattern (which actually would be the
274 // common use case), prologue and epilogue are not meaningfully defined.
275 // Nevertheless, if 'unrollPrologueEpilogue' is set, we will treat the first
276 // loop generated as the prologue and the last as epilogue and unroll these
277 // fully.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700278 AffineForOp prologue;
279 AffineForOp epilogue;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700280
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800281 // Do a sweep over the sorted shifts while storing open groups in a
Uday Bondhugula041817a2018-09-28 12:17:26 -0700282 // vector, and generating loop portions as necessary during the sweep. A block
River Riddle99b87c92019-03-27 14:02:02 -0700283 // of operations is paired with its shift.
284 std::vector<std::pair<uint64_t, ArrayRef<Operation *>>> instGroupQueue;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700285
River Riddleaf1abcc2019-03-25 11:13:31 -0700286 auto origLbMap = forOp.getLowerBoundMap();
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800287 uint64_t lbShift = 0;
River Riddlef1b848e2019-06-04 19:18:23 -0700288 OpBuilder b(forOp.getOperation());
Chris Lattner456ad6a2018-12-28 16:05:35 -0800289 for (uint64_t d = 0, e = sortedInstGroups.size(); d < e; ++d) {
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800290 // If nothing is shifted by d, continue.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800291 if (sortedInstGroups[d].empty())
Uday Bondhugula041817a2018-09-28 12:17:26 -0700292 continue;
Chris Lattner456ad6a2018-12-28 16:05:35 -0800293 if (!instGroupQueue.empty()) {
Uday Bondhugula041817a2018-09-28 12:17:26 -0700294 assert(d >= 1 &&
295 "Queue expected to be empty when the first block is found");
296 // The interval for which the loop needs to be generated here is:
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800297 // [lbShift, min(lbShift + tripCount, d)) and the body of the
River Riddle99b87c92019-03-27 14:02:02 -0700298 // loop needs to have all operations in instQueue in that order.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700299 AffineForOp res;
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800300 if (lbShift + tripCount * step < d * step) {
301 res = generateLoop(
302 b.getShiftedAffineMap(origLbMap, lbShift),
303 b.getShiftedAffineMap(origLbMap, lbShift + tripCount * step),
Nicolas Vasilache08047502019-06-20 15:10:35 -0700304 instGroupQueue, 0, forOp, b);
River Riddle99b87c92019-03-27 14:02:02 -0700305 // Entire loop for the queued op groups generated, empty it.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800306 instGroupQueue.clear();
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800307 lbShift += tripCount * step;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700308 } else {
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800309 res = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
Chris Lattner456ad6a2018-12-28 16:05:35 -0800310 b.getShiftedAffineMap(origLbMap, d), instGroupQueue,
Nicolas Vasilache08047502019-06-20 15:10:35 -0700311 0, forOp, b);
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800312 lbShift = d * step;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700313 }
314 if (!prologue && res)
315 prologue = res;
316 epilogue = res;
317 } else {
318 // Start of first interval.
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800319 lbShift = d * step;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700320 }
River Riddle99b87c92019-03-27 14:02:02 -0700321 // Augment the list of operations that get into the current open interval.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800322 instGroupQueue.push_back({d, sortedInstGroups[d]});
Uday Bondhugula041817a2018-09-28 12:17:26 -0700323 }
324
River Riddle99b87c92019-03-27 14:02:02 -0700325 // Those operations groups left in the queue now need to be processed (FIFO)
Uday Bondhugula041817a2018-09-28 12:17:26 -0700326 // and their loops completed.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800327 for (unsigned i = 0, e = instGroupQueue.size(); i < e; ++i) {
328 uint64_t ubShift = (instGroupQueue[i].first + tripCount) * step;
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800329 epilogue = generateLoop(b.getShiftedAffineMap(origLbMap, lbShift),
330 b.getShiftedAffineMap(origLbMap, ubShift),
Nicolas Vasilache08047502019-06-20 15:10:35 -0700331 instGroupQueue, i, forOp, b);
Uday Bondhugulab9f53dc2018-12-10 15:17:25 -0800332 lbShift = ubShift;
Uday Bondhugula041817a2018-09-28 12:17:26 -0700333 if (!prologue)
334 prologue = epilogue;
335 }
336
River Riddle99b87c92019-03-27 14:02:02 -0700337 // Erase the original for op.
River Riddleaf1abcc2019-03-25 11:13:31 -0700338 forOp.erase();
Uday Bondhugula041817a2018-09-28 12:17:26 -0700339
340 if (unrollPrologueEpilogue && prologue)
341 loopUnrollFull(prologue);
River Riddle5052bd82019-02-01 16:42:18 -0800342 if (unrollPrologueEpilogue && !epilogue &&
River Riddlef9d91532019-03-26 17:05:09 -0700343 epilogue.getOperation() != prologue.getOperation())
Uday Bondhugula041817a2018-09-28 12:17:26 -0700344 loopUnrollFull(epilogue);
345
River Riddle0310d492019-03-10 15:32:54 -0700346 return success();
Uday Bondhugula041817a2018-09-28 12:17:26 -0700347}
Alex Zinenkocb406332018-11-14 14:15:48 -0800348
Alex Zinenko9d03f562019-07-09 06:37:17 -0700349// Collect perfectly nested loops starting from `rootForOps`. Loops are
350// perfectly nested if each loop is the first and only non-terminator operation
351// in the parent loop. Collect at most `maxLoops` loops and append them to
352// `forOps`.
353template <typename T>
354void getPerfectlyNestedLoopsImpl(
355 SmallVectorImpl<T> &forOps, T rootForOp,
356 unsigned maxLoops = std::numeric_limits<unsigned>::max()) {
357 for (unsigned i = 0; i < maxLoops; ++i) {
358 forOps.push_back(rootForOp);
Uday Bondhugula74eabdd2019-09-27 11:57:52 -0700359 Block &body = rootForOp.region().front();
Alex Zinenko9d03f562019-07-09 06:37:17 -0700360 if (body.begin() != std::prev(body.end(), 2))
361 return;
362
363 rootForOp = dyn_cast<T>(&body.front());
364 if (!rootForOp)
365 return;
366 }
367}
368
MLIR Team0cd589c2019-04-04 15:19:17 -0700369/// Get perfectly nested sequence of loops starting at root of loop nest
370/// (the first op being another AffineFor, and the second op - a terminator).
371/// A loop is perfectly nested iff: the first op in the loop's body is another
372/// AffineForOp, and the second op is a terminator).
373void mlir::getPerfectlyNestedLoops(SmallVectorImpl<AffineForOp> &nestedLoops,
374 AffineForOp root) {
Alex Zinenko9d03f562019-07-09 06:37:17 -0700375 getPerfectlyNestedLoopsImpl(nestedLoops, root);
MLIR Team0cd589c2019-04-04 15:19:17 -0700376}
377
Alex Zinenkofc044e82019-07-15 06:40:07 -0700378void mlir::getPerfectlyNestedLoops(SmallVectorImpl<loop::ForOp> &nestedLoops,
379 loop::ForOp root) {
380 getPerfectlyNestedLoopsImpl(nestedLoops, root);
381}
382
Alex Zinenkocb406332018-11-14 14:15:48 -0800383/// Unrolls this loop completely.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700384LogicalResult mlir::loopUnrollFull(AffineForOp forOp) {
River Riddle5052bd82019-02-01 16:42:18 -0800385 Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
Alex Zinenkocb406332018-11-14 14:15:48 -0800386 if (mayBeConstantTripCount.hasValue()) {
387 uint64_t tripCount = mayBeConstantTripCount.getValue();
388 if (tripCount == 1) {
River Riddle5052bd82019-02-01 16:42:18 -0800389 return promoteIfSingleIteration(forOp);
Alex Zinenkocb406332018-11-14 14:15:48 -0800390 }
River Riddle5052bd82019-02-01 16:42:18 -0800391 return loopUnrollByFactor(forOp, tripCount);
Alex Zinenkocb406332018-11-14 14:15:48 -0800392 }
River Riddle0310d492019-03-10 15:32:54 -0700393 return failure();
Alex Zinenkocb406332018-11-14 14:15:48 -0800394}
395
396/// Unrolls and jams this loop by the specified factor or by the trip count (if
397/// constant) whichever is lower.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700398LogicalResult mlir::loopUnrollUpToFactor(AffineForOp forOp,
River Riddle80d35682019-03-08 16:04:42 -0800399 uint64_t unrollFactor) {
River Riddle5052bd82019-02-01 16:42:18 -0800400 Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
Alex Zinenkocb406332018-11-14 14:15:48 -0800401
402 if (mayBeConstantTripCount.hasValue() &&
403 mayBeConstantTripCount.getValue() < unrollFactor)
River Riddle5052bd82019-02-01 16:42:18 -0800404 return loopUnrollByFactor(forOp, mayBeConstantTripCount.getValue());
405 return loopUnrollByFactor(forOp, unrollFactor);
Alex Zinenkocb406332018-11-14 14:15:48 -0800406}
407
River Riddleba6fdc82019-03-06 17:37:14 -0800408/// Unrolls this loop by the specified factor. Returns success if the loop
Alex Zinenkocb406332018-11-14 14:15:48 -0800409/// is successfully unrolled.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700410LogicalResult mlir::loopUnrollByFactor(AffineForOp forOp,
River Riddle80d35682019-03-08 16:04:42 -0800411 uint64_t unrollFactor) {
Alex Zinenkocb406332018-11-14 14:15:48 -0800412 assert(unrollFactor >= 1 && "unroll factor should be >= 1");
413
Uday Bondhugula92e9d942019-01-22 15:48:07 -0800414 if (unrollFactor == 1)
River Riddle5052bd82019-02-01 16:42:18 -0800415 return promoteIfSingleIteration(forOp);
Uday Bondhugula92e9d942019-01-22 15:48:07 -0800416
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700417 if (forOp.getBody()->empty() ||
418 forOp.getBody()->begin() == std::prev(forOp.getBody()->end()))
River Riddle0310d492019-03-10 15:32:54 -0700419 return failure();
Alex Zinenkocb406332018-11-14 14:15:48 -0800420
Uday Bondhugula075090f2019-03-12 08:00:52 -0700421 // Loops where the lower bound is a max expression isn't supported for
422 // unrolling since the trip count can be expressed as an affine function when
423 // both the lower bound and the upper bound are multi-result maps. However,
424 // one meaningful way to do such unrolling would be to specialize the loop for
425 // the 'hotspot' case and unroll that hotspot.
River Riddleaf1abcc2019-03-25 11:13:31 -0700426 if (forOp.getLowerBoundMap().getNumResults() != 1)
River Riddle0310d492019-03-10 15:32:54 -0700427 return failure();
Alex Zinenkocb406332018-11-14 14:15:48 -0800428
Alex Zinenkocb406332018-11-14 14:15:48 -0800429 // If the trip count is lower than the unroll factor, no unrolled body.
430 // TODO(bondhugula): option to specify cleanup loop unrolling.
Uday Bondhugula075090f2019-03-12 08:00:52 -0700431 Optional<uint64_t> mayBeConstantTripCount = getConstantTripCount(forOp);
Alex Zinenkocb406332018-11-14 14:15:48 -0800432 if (mayBeConstantTripCount.hasValue() &&
433 mayBeConstantTripCount.getValue() < unrollFactor)
River Riddle0310d492019-03-10 15:32:54 -0700434 return failure();
Alex Zinenkocb406332018-11-14 14:15:48 -0800435
436 // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
River Riddle99b87c92019-03-27 14:02:02 -0700437 Operation *op = forOp.getOperation();
River Riddle5052bd82019-02-01 16:42:18 -0800438 if (getLargestDivisorOfTripCount(forOp) % unrollFactor != 0) {
River Riddlef1b848e2019-06-04 19:18:23 -0700439 OpBuilder builder(op->getBlock(), ++Block::iterator(op));
River Riddleadca3c22019-05-11 17:57:32 -0700440 auto cleanupForInst = cast<AffineForOp>(builder.clone(*op));
Uday Bondhugula075090f2019-03-12 08:00:52 -0700441 AffineMap cleanupMap;
River Riddlee62a6952019-12-23 14:45:01 -0800442 SmallVector<Value, 4> cleanupOperands;
Uday Bondhugula075090f2019-03-12 08:00:52 -0700443 getCleanupLoopLowerBound(forOp, unrollFactor, &cleanupMap, &cleanupOperands,
Nicolas Vasilache08047502019-06-20 15:10:35 -0700444 builder);
Uday Bondhugula075090f2019-03-12 08:00:52 -0700445 assert(cleanupMap &&
446 "cleanup loop lower bound map for single result lower bound maps "
447 "can always be determined");
River Riddleaf1abcc2019-03-25 11:13:31 -0700448 cleanupForInst.setLowerBound(cleanupOperands, cleanupMap);
Alex Zinenkocb406332018-11-14 14:15:48 -0800449 // Promote the loop body up if this has turned into a single iteration loop.
Chris Lattner456ad6a2018-12-28 16:05:35 -0800450 promoteIfSingleIteration(cleanupForInst);
Alex Zinenkocb406332018-11-14 14:15:48 -0800451
Uday Bondhugula075090f2019-03-12 08:00:52 -0700452 // Adjust upper bound of the original loop; this is the same as the lower
453 // bound of the cleanup loop.
River Riddleaf1abcc2019-03-25 11:13:31 -0700454 forOp.setUpperBound(cleanupOperands, cleanupMap);
Alex Zinenkocb406332018-11-14 14:15:48 -0800455 }
456
457 // Scale the step of loop being unrolled by unroll factor.
River Riddleaf1abcc2019-03-25 11:13:31 -0700458 int64_t step = forOp.getStep();
459 forOp.setStep(step * unrollFactor);
Alex Zinenkocb406332018-11-14 14:15:48 -0800460
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700461 // Builder to insert unrolled bodies just before the terminator of the body of
462 // 'forOp'.
River Riddlef1b848e2019-06-04 19:18:23 -0700463 OpBuilder builder = forOp.getBodyBuilder();
Alex Zinenkocb406332018-11-14 14:15:48 -0800464
River Riddle99b87c92019-03-27 14:02:02 -0700465 // Keep a pointer to the last non-terminator operation in the original block
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700466 // so that we know what to clone (since we are doing this in-place).
467 Block::iterator srcBlockEnd = std::prev(forOp.getBody()->end(), 2);
Alex Zinenkocb406332018-11-14 14:15:48 -0800468
River Riddle5052bd82019-02-01 16:42:18 -0800469 // Unroll the contents of 'forOp' (append unrollFactor-1 additional copies).
River Riddle35807bc2019-12-22 21:59:55 -0800470 auto forOpIV = forOp.getInductionVar();
Alex Zinenkocb406332018-11-14 14:15:48 -0800471 for (unsigned i = 1; i < unrollFactor; i++) {
River Riddle451869f2019-01-24 12:25:30 -0800472 BlockAndValueMapping operandMap;
Alex Zinenkocb406332018-11-14 14:15:48 -0800473
474 // If the induction variable is used, create a remapping to the value for
475 // this unrolled instance.
River Riddle2bdf33c2020-01-11 08:54:04 -0800476 if (!forOpIV.use_empty()) {
Alex Zinenkocb406332018-11-14 14:15:48 -0800477 // iv' = iv + 1/2/3...unrollFactor-1;
478 auto d0 = builder.getAffineDimExpr(0);
River Riddle2acc2202019-10-17 20:08:01 -0700479 auto bumpMap = AffineMap::get(1, 0, {d0 + i * step});
Chris Lattnerb42bea22019-01-27 09:33:19 -0800480 auto ivUnroll =
River Riddleaf1abcc2019-03-25 11:13:31 -0700481 builder.create<AffineApplyOp>(forOp.getLoc(), bumpMap, forOpIV);
River Riddle5052bd82019-02-01 16:42:18 -0800482 operandMap.map(forOpIV, ivUnroll);
Alex Zinenkocb406332018-11-14 14:15:48 -0800483 }
484
River Riddle5052bd82019-02-01 16:42:18 -0800485 // Clone the original body of 'forOp'.
River Riddleaf1abcc2019-03-25 11:13:31 -0700486 for (auto it = forOp.getBody()->begin(); it != std::next(srcBlockEnd);
Chris Lattner1301f902018-12-23 08:17:48 -0800487 it++) {
Alex Zinenkocb406332018-11-14 14:15:48 -0800488 builder.clone(*it, operandMap);
489 }
490 }
491
492 // Promote the loop body up if this has turned into a single iteration loop.
River Riddle5052bd82019-02-01 16:42:18 -0800493 promoteIfSingleIteration(forOp);
River Riddle0310d492019-03-10 15:32:54 -0700494 return success();
Alex Zinenkocb406332018-11-14 14:15:48 -0800495}
MLIR Team8f5f2c72019-02-15 09:32:18 -0800496
497/// Performs loop interchange on 'forOpA' and 'forOpB', where 'forOpB' is
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700498/// nested within 'forOpA' as the only non-terminator operation in its block.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700499void mlir::interchangeLoops(AffineForOp forOpA, AffineForOp forOpB) {
River Riddlef9d91532019-03-26 17:05:09 -0700500 auto *forOpAInst = forOpA.getOperation();
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700501
River Riddlef9d91532019-03-26 17:05:09 -0700502 assert(&*forOpA.getBody()->begin() == forOpB.getOperation());
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700503 auto &forOpABody = forOpA.getBody()->getOperations();
504 auto &forOpBBody = forOpB.getBody()->getOperations();
505
506 // 1) Splice forOpA's non-terminator operations (which is just forOpB) just
507 // before forOpA (in ForOpA's parent's block) this should leave 'forOpA's
508 // body containing only the terminator.
509 forOpAInst->getBlock()->getOperations().splice(Block::iterator(forOpAInst),
510 forOpABody, forOpABody.begin(),
511 std::prev(forOpABody.end()));
512 // 2) Splice forOpB's non-terminator operations into the beginning of forOpA's
513 // body (this leaves forOpB's body containing only the terminator).
514 forOpABody.splice(forOpABody.begin(), forOpBBody, forOpBBody.begin(),
515 std::prev(forOpBBody.end()));
516 // 3) Splice forOpA into the beginning of forOpB's body.
517 forOpBBody.splice(forOpBBody.begin(), forOpAInst->getBlock()->getOperations(),
518 Block::iterator(forOpAInst));
MLIR Team8f5f2c72019-02-15 09:32:18 -0800519}
520
Andy Davis90d40232019-05-13 06:57:56 -0700521// Checks each dependence component against the permutation to see if the
522// desired loop interchange would violate dependences by making the
Kazuaki Ishizaki8bfedb32019-10-20 00:11:03 -0700523// dependence component lexicographically negative.
Andy Davis90d40232019-05-13 06:57:56 -0700524static bool checkLoopInterchangeDependences(
River Riddle4562e382019-12-18 09:28:48 -0800525 const std::vector<SmallVector<DependenceComponent, 2>> &depCompsVec,
Andy Davis90d40232019-05-13 06:57:56 -0700526 ArrayRef<AffineForOp> loops, ArrayRef<unsigned> loopPermMap) {
527 // Invert permutation map.
528 unsigned maxLoopDepth = loops.size();
River Riddle4562e382019-12-18 09:28:48 -0800529 SmallVector<unsigned, 4> loopPermMapInv;
Andy Davis90d40232019-05-13 06:57:56 -0700530 loopPermMapInv.resize(maxLoopDepth);
531 for (unsigned i = 0; i < maxLoopDepth; ++i)
532 loopPermMapInv[loopPermMap[i]] = i;
533
534 // Check each dependence component against the permutation to see if the
535 // desired loop interchange permutation would make the dependence vectors
536 // lexicographically negative.
537 // Example 1: [-1, 1][0, 0]
538 // Example 2: [0, 0][-1, 1]
539 for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
River Riddle4562e382019-12-18 09:28:48 -0800540 const SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
Andy Davis90d40232019-05-13 06:57:56 -0700541 assert(depComps.size() >= maxLoopDepth);
542 // Check if the first non-zero dependence component is positive.
543 // This iterates through loops in the desired order.
544 for (unsigned j = 0; j < maxLoopDepth; ++j) {
545 unsigned permIndex = loopPermMapInv[j];
546 assert(depComps[permIndex].lb.hasValue());
547 int64_t depCompLb = depComps[permIndex].lb.getValue();
548 if (depCompLb > 0)
549 break;
550 if (depCompLb < 0)
551 return false;
552 }
553 }
554 return true;
555}
556
557/// Checks if the loop interchange permutation 'loopPermMap' of the perfectly
558/// nested sequence of loops in 'loops' would violate dependences.
559bool mlir::isValidLoopInterchangePermutation(ArrayRef<AffineForOp> loops,
560 ArrayRef<unsigned> loopPermMap) {
561 // Gather dependence components for dependences between all ops in loop nest
562 // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
563 assert(loopPermMap.size() == loops.size());
564 unsigned maxLoopDepth = loops.size();
River Riddle4562e382019-12-18 09:28:48 -0800565 std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
Andy Davis90d40232019-05-13 06:57:56 -0700566 getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
567 return checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap);
568}
569
570/// Performs a sequence of loop interchanges of loops in perfectly nested
571/// sequence of loops in 'loops', as specified by permutation in 'loopPermMap'.
572unsigned mlir::interchangeLoops(ArrayRef<AffineForOp> loops,
573 ArrayRef<unsigned> loopPermMap) {
574 Optional<unsigned> loopNestRootIndex;
575 for (int i = loops.size() - 1; i >= 0; --i) {
576 int permIndex = static_cast<int>(loopPermMap[i]);
577 // Store the index of the for loop which will be the new loop nest root.
578 if (permIndex == 0)
579 loopNestRootIndex = i;
580 if (permIndex > i) {
581 // Sink loop 'i' by 'permIndex - i' levels deeper into the loop nest.
582 sinkLoop(loops[i], permIndex - i);
583 }
584 }
585 assert(loopNestRootIndex.hasValue());
586 return loopNestRootIndex.getValue();
587}
588
589// Sinks all sequential loops to the innermost levels (while preserving
590// relative order among them) and moves all parallel loops to the
591// outermost (while again preserving relative order among them).
592AffineForOp mlir::sinkSequentialLoops(AffineForOp forOp) {
593 SmallVector<AffineForOp, 4> loops;
594 getPerfectlyNestedLoops(loops, forOp);
595 if (loops.size() < 2)
596 return forOp;
597
598 // Gather dependence components for dependences between all ops in loop nest
599 // rooted at 'loops[0]', at loop depths in range [1, maxLoopDepth].
600 unsigned maxLoopDepth = loops.size();
River Riddle4562e382019-12-18 09:28:48 -0800601 std::vector<SmallVector<DependenceComponent, 2>> depCompsVec;
Andy Davis90d40232019-05-13 06:57:56 -0700602 getDependenceComponents(loops[0], maxLoopDepth, &depCompsVec);
603
604 // Mark loops as either parallel or sequential.
River Riddle4562e382019-12-18 09:28:48 -0800605 SmallVector<bool, 8> isParallelLoop(maxLoopDepth, true);
Andy Davis90d40232019-05-13 06:57:56 -0700606 for (unsigned i = 0, e = depCompsVec.size(); i < e; ++i) {
River Riddle4562e382019-12-18 09:28:48 -0800607 SmallVector<DependenceComponent, 2> &depComps = depCompsVec[i];
Andy Davis90d40232019-05-13 06:57:56 -0700608 assert(depComps.size() >= maxLoopDepth);
609 for (unsigned j = 0; j < maxLoopDepth; ++j) {
610 DependenceComponent &depComp = depComps[j];
611 assert(depComp.lb.hasValue() && depComp.ub.hasValue());
612 if (depComp.lb.getValue() != 0 || depComp.ub.getValue() != 0)
613 isParallelLoop[j] = false;
614 }
615 }
616
617 // Count the number of parallel loops.
618 unsigned numParallelLoops = 0;
619 for (unsigned i = 0, e = isParallelLoop.size(); i < e; ++i)
620 if (isParallelLoop[i])
621 ++numParallelLoops;
622
623 // Compute permutation of loops that sinks sequential loops (and thus raises
624 // parallel loops) while preserving relative order.
River Riddle4562e382019-12-18 09:28:48 -0800625 SmallVector<unsigned, 4> loopPermMap(maxLoopDepth);
Andy Davis90d40232019-05-13 06:57:56 -0700626 unsigned nextSequentialLoop = numParallelLoops;
627 unsigned nextParallelLoop = 0;
628 for (unsigned i = 0; i < maxLoopDepth; ++i) {
629 if (isParallelLoop[i]) {
630 loopPermMap[i] = nextParallelLoop++;
631 } else {
632 loopPermMap[i] = nextSequentialLoop++;
633 }
634 }
635
636 // Check if permutation 'loopPermMap' would violate dependences.
637 if (!checkLoopInterchangeDependences(depCompsVec, loops, loopPermMap))
638 return forOp;
639 // Perform loop interchange according to permutation 'loopPermMap'.
640 unsigned loopNestRootIndex = interchangeLoops(loops, loopPermMap);
641 return loops[loopNestRootIndex];
642}
643
MLIR Team8f5f2c72019-02-15 09:32:18 -0800644/// Performs a series of loop interchanges to sink 'forOp' 'loopDepth' levels
645/// deeper in the loop nest.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700646void mlir::sinkLoop(AffineForOp forOp, unsigned loopDepth) {
MLIR Team8f5f2c72019-02-15 09:32:18 -0800647 for (unsigned i = 0; i < loopDepth; ++i) {
River Riddleadca3c22019-05-11 17:57:32 -0700648 AffineForOp nextForOp = cast<AffineForOp>(forOp.getBody()->front());
MLIR Team8f5f2c72019-02-15 09:32:18 -0800649 interchangeLoops(forOp, nextForOp);
650 }
651}
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800652
Nicolas Vasilached6c650c2019-03-20 10:23:04 -0700653// Factors out common behavior to add a new `iv` (resp. `iv` + `offset`) to the
Nicolas Vasilachefc5bbdd2019-03-21 07:32:51 -0700654// lower (resp. upper) loop bound. When called for both the lower and upper
655// bounds, the resulting IR resembles:
Nicolas Vasilached6c650c2019-03-20 10:23:04 -0700656//
657// ```mlir
River Riddle832567b2019-03-25 10:14:34 -0700658// affine.for %i = max (`iv, ...) to min (`iv` + `offset`) {
Nicolas Vasilachefc5bbdd2019-03-21 07:32:51 -0700659// ...
660// }
661// ```
River Riddlee62a6952019-12-23 14:45:01 -0800662static void augmentMapAndBounds(OpBuilder &b, Value iv, AffineMap *map,
663 SmallVector<Value, 4> *operands,
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800664 int64_t offset = 0) {
665 auto bounds = llvm::to_vector<4>(map->getResults());
Nicolas Vasilache08047502019-06-20 15:10:35 -0700666 bounds.push_back(b.getAffineDimExpr(map->getNumDims()) + offset);
Nicolas Vasilached6c650c2019-03-20 10:23:04 -0700667 operands->insert(operands->begin() + map->getNumDims(), iv);
River Riddle2acc2202019-10-17 20:08:01 -0700668 *map = AffineMap::get(map->getNumDims() + 1, map->getNumSymbols(), bounds);
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800669 canonicalizeMapAndOperands(map, operands);
670}
671
672// Stripmines `forOp` by `factor` and sinks it under each of the `targets`.
673// Stripmine-sink is a primitive building block for generalized tiling of
674// imperfectly nested loops.
675// This transformation is purely mechanical and does not check legality,
676// profitability or even structural correctness. It is the user's
677// responsibility to specify `targets` that are dominated by `forOp`.
678// Returns the new AffineForOps, one per `targets`, nested immediately under
679// each of the `targets`.
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700680static SmallVector<AffineForOp, 8>
681stripmineSink(AffineForOp forOp, uint64_t factor,
682 ArrayRef<AffineForOp> targets) {
River Riddleaf1abcc2019-03-25 11:13:31 -0700683 auto originalStep = forOp.getStep();
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800684 auto scaledStep = originalStep * factor;
River Riddleaf1abcc2019-03-25 11:13:31 -0700685 forOp.setStep(scaledStep);
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800686
River Riddle99b87c92019-03-27 14:02:02 -0700687 auto *op = forOp.getOperation();
River Riddlef1b848e2019-06-04 19:18:23 -0700688 OpBuilder b(op->getBlock(), ++Block::iterator(op));
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800689
690 // Lower-bound map creation.
River Riddleaf1abcc2019-03-25 11:13:31 -0700691 auto lbMap = forOp.getLowerBoundMap();
River Riddlee62a6952019-12-23 14:45:01 -0800692 SmallVector<Value, 4> lbOperands(forOp.getLowerBoundOperands());
Nicolas Vasilache08047502019-06-20 15:10:35 -0700693 augmentMapAndBounds(b, forOp.getInductionVar(), &lbMap, &lbOperands);
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800694
695 // Upper-bound map creation.
River Riddleaf1abcc2019-03-25 11:13:31 -0700696 auto ubMap = forOp.getUpperBoundMap();
River Riddlee62a6952019-12-23 14:45:01 -0800697 SmallVector<Value, 4> ubOperands(forOp.getUpperBoundOperands());
Nicolas Vasilache08047502019-06-20 15:10:35 -0700698 augmentMapAndBounds(b, forOp.getInductionVar(), &ubMap, &ubOperands,
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800699 /*offset=*/scaledStep);
700
River Riddle35807bc2019-12-22 21:59:55 -0800701 auto iv = forOp.getInductionVar();
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700702 SmallVector<AffineForOp, 8> innerLoops;
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800703 for (auto t : targets) {
Alex Zinenko5a5bba02019-03-27 05:11:58 -0700704 // Insert newForOp before the terminator of `t`.
River Riddlef1b848e2019-06-04 19:18:23 -0700705 OpBuilder b = t.getBodyBuilder();
River Riddleaf1abcc2019-03-25 11:13:31 -0700706 auto newForOp = b.create<AffineForOp>(t.getLoc(), lbOperands, lbMap,
Nicolas Vasilached6c650c2019-03-20 10:23:04 -0700707 ubOperands, ubMap, originalStep);
Nicolas Vasilached2a87292019-07-19 05:02:39 -0700708 auto begin = t.getBody()->begin();
709 // Skip terminator and `newForOp` which is just before the terminator.
710 auto nOps = t.getBody()->getOperations().size() - 2;
711 newForOp.getBody()->getOperations().splice(
712 newForOp.getBody()->getOperations().begin(),
713 t.getBody()->getOperations(), begin, std::next(begin, nOps));
714 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
715 newForOp.region());
Nicolas Vasilached6c650c2019-03-20 10:23:04 -0700716 innerLoops.push_back(newForOp);
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800717 }
718
719 return innerLoops;
720}
721
River Riddlee62a6952019-12-23 14:45:01 -0800722static Loops stripmineSink(loop::ForOp forOp, Value factor,
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700723 ArrayRef<loop::ForOp> targets) {
River Riddle35807bc2019-12-22 21:59:55 -0800724 auto originalStep = forOp.step();
725 auto iv = forOp.getInductionVar();
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700726
727 OpBuilder b(forOp);
728 forOp.setStep(b.create<MulIOp>(forOp.getLoc(), originalStep, factor));
729
730 Loops innerLoops;
731 for (auto t : targets) {
732 // Save information for splicing ops out of t when done
733 auto begin = t.getBody()->begin();
734 auto nOps = t.getBody()->getOperations().size();
735
736 // Insert newForOp before the terminator of `t`.
737 OpBuilder b(t.getBodyBuilder());
River Riddlee62a6952019-12-23 14:45:01 -0800738 Value stepped = b.create<AddIOp>(t.getLoc(), iv, forOp.step());
739 Value less = b.create<CmpIOp>(t.getLoc(), CmpIPredicate::slt,
740 forOp.upperBound(), stepped);
741 Value ub =
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700742 b.create<SelectOp>(t.getLoc(), less, forOp.upperBound(), stepped);
743
744 // Splice [begin, begin + nOps - 1) into `newForOp` and replace uses.
745 auto newForOp = b.create<loop::ForOp>(t.getLoc(), iv, ub, originalStep);
746 newForOp.getBody()->getOperations().splice(
747 newForOp.getBody()->getOperations().begin(),
748 t.getBody()->getOperations(), begin, std::next(begin, nOps - 1));
749 replaceAllUsesInRegionWith(iv, newForOp.getInductionVar(),
750 newForOp.region());
751
752 innerLoops.push_back(newForOp);
753 }
754
755 return innerLoops;
756}
757
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800758// Stripmines a `forOp` by `factor` and sinks it under a single `target`.
759// Returns the new AffineForOps, nested immediately under `target`.
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700760template <typename ForType, typename SizeType>
761static ForType stripmineSink(ForType forOp, SizeType factor, ForType target) {
762 // TODO(ntv): Use cheap structural assertions that targets are nested under
763 // forOp and that targets are not nested under each other when DominanceInfo
764 // exposes the capability. It seems overkill to construct a whole function
765 // dominance tree at this point.
766 auto res = stripmineSink(forOp, factor, ArrayRef<ForType>{target});
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800767 assert(res.size() == 1 && "Expected 1 inner forOp");
768 return res[0];
769}
770
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700771template <typename ForType, typename SizeType>
772static SmallVector<SmallVector<ForType, 8>, 8>
773tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes,
774 ArrayRef<ForType> targets) {
775 SmallVector<SmallVector<ForType, 8>, 8> res;
776 SmallVector<ForType, 8> currentTargets(targets.begin(), targets.end());
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800777 for (auto it : llvm::zip(forOps, sizes)) {
778 auto step = stripmineSink(std::get<0>(it), std::get<1>(it), currentTargets);
779 res.push_back(step);
780 currentTargets = step;
781 }
782 return res;
783}
784
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700785SmallVector<SmallVector<AffineForOp, 8>, 8>
786mlir::tile(ArrayRef<AffineForOp> forOps, ArrayRef<uint64_t> sizes,
787 ArrayRef<AffineForOp> targets) {
788 return tileImpl(forOps, sizes, targets);
789}
790
791SmallVector<Loops, 8> mlir::tile(ArrayRef<loop::ForOp> forOps,
River Riddlee62a6952019-12-23 14:45:01 -0800792 ArrayRef<Value> sizes,
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700793 ArrayRef<loop::ForOp> targets) {
794 return tileImpl(forOps, sizes, targets);
795}
796
797template <typename ForType, typename SizeType>
798static SmallVector<ForType, 8>
799tileImpl(ArrayRef<ForType> forOps, ArrayRef<SizeType> sizes, ForType target) {
800 SmallVector<ForType, 8> res;
801 for (auto loops : tile(forOps, sizes, ArrayRef<ForType>{target})) {
802 assert(loops.size() == 1);
803 res.push_back(loops[0]);
804 }
805 return res;
806}
807
Chris Lattnerd9b5bc82019-03-24 19:53:05 -0700808SmallVector<AffineForOp, 8> mlir::tile(ArrayRef<AffineForOp> forOps,
809 ArrayRef<uint64_t> sizes,
810 AffineForOp target) {
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700811 return tileImpl(forOps, sizes, target);
Nicolas Vasilache62c54a22019-02-25 09:53:05 -0800812}
Alex Zinenko9d03f562019-07-09 06:37:17 -0700813
River Riddlee62a6952019-12-23 14:45:01 -0800814Loops mlir::tile(ArrayRef<loop::ForOp> forOps, ArrayRef<Value> sizes,
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700815 loop::ForOp target) {
816 return tileImpl(forOps, sizes, target);
Alex Zinenko9d03f562019-07-09 06:37:17 -0700817}
818
River Riddlee62a6952019-12-23 14:45:01 -0800819Loops mlir::tilePerfectlyNested(loop::ForOp rootForOp, ArrayRef<Value> sizes) {
Kazuaki Ishizaki8bfedb32019-10-20 00:11:03 -0700820 // Collect perfectly nested loops. If more size values provided than nested
Alex Zinenko9d03f562019-07-09 06:37:17 -0700821 // loops available, truncate `sizes`.
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700822 SmallVector<loop::ForOp, 4> forOps;
Alex Zinenko9d03f562019-07-09 06:37:17 -0700823 forOps.reserve(sizes.size());
824 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
825 if (forOps.size() < sizes.size())
826 sizes = sizes.take_front(forOps.size());
827
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700828 return ::tile(forOps, sizes, forOps.back());
Alex Zinenko9d03f562019-07-09 06:37:17 -0700829}
830
831// Build the IR that performs ceil division of a positive value by a constant:
832// ceildiv(a, B) = divis(a + (B-1), B)
Kazuaki Ishizaki8bfedb32019-10-20 00:11:03 -0700833// where divis is rounding-to-zero division.
River Riddlee62a6952019-12-23 14:45:01 -0800834static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
835 int64_t divisor) {
Alex Zinenko9d03f562019-07-09 06:37:17 -0700836 assert(divisor > 0 && "expected positive divisor");
River Riddle2bdf33c2020-01-11 08:54:04 -0800837 assert(dividend.getType().isIndex() && "expected index-typed value");
Alex Zinenko9d03f562019-07-09 06:37:17 -0700838
River Riddlee62a6952019-12-23 14:45:01 -0800839 Value divisorMinusOneCst = builder.create<ConstantIndexOp>(loc, divisor - 1);
840 Value divisorCst = builder.create<ConstantIndexOp>(loc, divisor);
841 Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOneCst);
Manuel Freiberger22954a02019-12-22 10:01:35 -0800842 return builder.create<SignedDivIOp>(loc, sum, divisorCst);
Alex Zinenko9d03f562019-07-09 06:37:17 -0700843}
844
Alex Zinenkofc044e82019-07-15 06:40:07 -0700845// Build the IR that performs ceil division of a positive value by another
846// positive value:
847// ceildiv(a, b) = divis(a + (b - 1), b)
848// where divis is rounding-to-zero division.
River Riddlee62a6952019-12-23 14:45:01 -0800849static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
850 Value divisor) {
River Riddle2bdf33c2020-01-11 08:54:04 -0800851 assert(dividend.getType().isIndex() && "expected index-typed value");
Alex Zinenko9d03f562019-07-09 06:37:17 -0700852
River Riddlee62a6952019-12-23 14:45:01 -0800853 Value cstOne = builder.create<ConstantIndexOp>(loc, 1);
854 Value divisorMinusOne = builder.create<SubIOp>(loc, divisor, cstOne);
855 Value sum = builder.create<AddIOp>(loc, dividend, divisorMinusOne);
Manuel Freiberger22954a02019-12-22 10:01:35 -0800856 return builder.create<SignedDivIOp>(loc, sum, divisor);
Alex Zinenko9d03f562019-07-09 06:37:17 -0700857}
858
Nicolas Vasilache48a1bae2019-07-22 04:30:50 -0700859// Hoist the ops within `outer` that appear before `inner`.
860// Such ops include the ops that have been introduced by parametric tiling.
861// Ops that come from triangular loops (i.e. that belong to the program slice
862// rooted at `outer`) and ops that have side effects cannot be hoisted.
863// Return failure when any op fails to hoist.
864static LogicalResult hoistOpsBetween(loop::ForOp outer, loop::ForOp inner) {
865 SetVector<Operation *> forwardSlice;
866 getForwardSlice(outer.getOperation(), &forwardSlice, [&inner](Operation *op) {
867 return op != inner.getOperation();
868 });
869 LogicalResult status = success();
870 SmallVector<Operation *, 8> toHoist;
871 for (auto &op : outer.getBody()->getOperations()) {
872 // Stop when encountering the inner loop.
873 if (&op == inner.getOperation())
874 break;
875 // Skip over non-hoistable ops.
876 if (forwardSlice.count(&op) > 0) {
877 status = failure();
878 continue;
879 }
880 // Skip loop::ForOp, these are not considered a failure.
881 if (op.getNumRegions() > 0)
882 continue;
883 // Skip other ops with regions.
884 if (op.getNumRegions() > 0) {
885 status = failure();
886 continue;
887 }
888 // Skip if op has side effects.
889 // TODO(ntv): loads to immutable memory regions are ok.
890 if (!op.hasNoSideEffect()) {
891 status = failure();
892 continue;
893 }
894 toHoist.push_back(&op);
895 }
896 auto *outerForOp = outer.getOperation();
897 for (auto *op : toHoist)
898 op->moveBefore(outerForOp);
899 return status;
900}
901
902// Traverse the interTile and intraTile loops and try to hoist ops such that
903// bands of perfectly nested loops are isolated.
904// Return failure if either perfect interTile or perfect intraTile bands cannot
905// be formed.
906static LogicalResult tryIsolateBands(const TileLoops &tileLoops) {
907 LogicalResult status = success();
908 auto &interTile = tileLoops.first;
909 auto &intraTile = tileLoops.second;
910 auto size = interTile.size();
911 assert(size == intraTile.size());
912 if (size <= 1)
913 return success();
914 for (unsigned s = 1; s < size; ++s)
915 status = succeeded(status) ? hoistOpsBetween(intraTile[0], intraTile[s])
916 : failure();
917 for (unsigned s = 1; s < size; ++s)
918 status = succeeded(status) ? hoistOpsBetween(interTile[0], interTile[s])
919 : failure();
920 return status;
921}
922
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700923TileLoops mlir::extractFixedOuterLoops(loop::ForOp rootForOp,
924 ArrayRef<int64_t> sizes) {
Kazuaki Ishizakiae05cf22019-12-09 09:23:15 -0800925 // Collect perfectly nested loops. If more size values provided than nested
Alex Zinenko9d03f562019-07-09 06:37:17 -0700926 // loops available, truncate `sizes`.
Nicolas Vasilachecca53e82019-07-15 02:50:09 -0700927 SmallVector<loop::ForOp, 4> forOps;
Alex Zinenko9d03f562019-07-09 06:37:17 -0700928 forOps.reserve(sizes.size());
929 getPerfectlyNestedLoopsImpl(forOps, rootForOp, sizes.size());
930 if (forOps.size() < sizes.size())
931 sizes = sizes.take_front(forOps.size());
932
Alex Zinenko9d03f562019-07-09 06:37:17 -0700933 // Compute the tile sizes such that i-th outer loop executes size[i]
934 // iterations. Given that the loop current executes
935 // numIterations = ceildiv((upperBound - lowerBound), step)
936 // iterations, we need to tile with size ceildiv(numIterations, size[i]).
River Riddlee62a6952019-12-23 14:45:01 -0800937 SmallVector<Value, 4> tileSizes;
Alex Zinenko9d03f562019-07-09 06:37:17 -0700938 tileSizes.reserve(sizes.size());
939 for (unsigned i = 0, e = sizes.size(); i < e; ++i) {
940 assert(sizes[i] > 0 && "expected strictly positive size for strip-mining");
941
942 auto forOp = forOps[i];
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700943 OpBuilder builder(forOp);
944 auto loc = forOp.getLoc();
River Riddlee62a6952019-12-23 14:45:01 -0800945 Value diff =
Alex Zinenko9d03f562019-07-09 06:37:17 -0700946 builder.create<SubIOp>(loc, forOp.upperBound(), forOp.lowerBound());
River Riddlee62a6952019-12-23 14:45:01 -0800947 Value numIterations = ceilDivPositive(builder, loc, diff, forOp.step());
948 Value iterationsPerBlock =
Alex Zinenko9d03f562019-07-09 06:37:17 -0700949 ceilDivPositive(builder, loc, numIterations, sizes[i]);
950 tileSizes.push_back(iterationsPerBlock);
951 }
952
953 // Call parametric tiling with the given sizes.
Nicolas Vasilache5bc34472019-07-19 01:52:36 -0700954 auto intraTile = tile(forOps, tileSizes, forOps.back());
Nicolas Vasilache48a1bae2019-07-22 04:30:50 -0700955 TileLoops tileLoops = std::make_pair(forOps, intraTile);
956
957 // TODO(ntv, zinenko) for now we just ignore the result of band isolation.
958 // In the future, mapping decisions may be impacted by the ability to
959 // isolate perfectly nested bands.
960 tryIsolateBands(tileLoops);
961
962 return tileLoops;
Alex Zinenko9d03f562019-07-09 06:37:17 -0700963}
Alex Zinenkofc044e82019-07-15 06:40:07 -0700964
965// Replaces all uses of `orig` with `replacement` except if the user is listed
966// in `exceptions`.
967static void
River Riddlee62a6952019-12-23 14:45:01 -0800968replaceAllUsesExcept(Value orig, Value replacement,
Alex Zinenkofc044e82019-07-15 06:40:07 -0700969 const SmallPtrSetImpl<Operation *> &exceptions) {
River Riddle2bdf33c2020-01-11 08:54:04 -0800970 for (auto &use : llvm::make_early_inc_range(orig.getUses())) {
Alex Zinenkofc044e82019-07-15 06:40:07 -0700971 if (exceptions.count(use.getOwner()) == 0)
972 use.set(replacement);
973 }
974}
975
976// Transform a loop with a strictly positive step
977// for %i = %lb to %ub step %s
978// into a 0-based loop with step 1
979// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
980// %i = %ii * %s + %lb
981// Insert the induction variable remapping in the body of `inner`, which is
982// expected to be either `loop` or another loop perfectly nested under `loop`.
983// Insert the definition of new bounds immediate before `outer`, which is
984// expected to be either `loop` or its parent in the loop nest.
985static void normalizeLoop(loop::ForOp loop, loop::ForOp outer,
986 loop::ForOp inner) {
987 OpBuilder builder(outer);
988 Location loc = loop.getLoc();
989
990 // Check if the loop is already known to have a constant zero lower bound or
991 // a constant one step.
992 bool isZeroBased = false;
993 if (auto ubCst =
River Riddle2bdf33c2020-01-11 08:54:04 -0800994 dyn_cast_or_null<ConstantIndexOp>(loop.lowerBound().getDefiningOp()))
Alex Zinenkofc044e82019-07-15 06:40:07 -0700995 isZeroBased = ubCst.getValue() == 0;
996
997 bool isStepOne = false;
998 if (auto stepCst =
River Riddle2bdf33c2020-01-11 08:54:04 -0800999 dyn_cast_or_null<ConstantIndexOp>(loop.step().getDefiningOp()))
Alex Zinenkofc044e82019-07-15 06:40:07 -07001000 isStepOne = stepCst.getValue() == 1;
1001
1002 if (isZeroBased && isStepOne)
1003 return;
1004
1005 // Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
1006 // assuming the step is strictly positive. Update the bounds and the step
1007 // of the loop to go from 0 to the number of iterations, if necessary.
1008 // TODO(zinenko): introduce support for negative steps or emit dynamic asserts
1009 // on step positivity, whatever gets implemented first.
River Riddlee62a6952019-12-23 14:45:01 -08001010 Value diff =
Alex Zinenkofc044e82019-07-15 06:40:07 -07001011 builder.create<SubIOp>(loc, loop.upperBound(), loop.lowerBound());
River Riddlee62a6952019-12-23 14:45:01 -08001012 Value numIterations = ceilDivPositive(builder, loc, diff, loop.step());
Alex Zinenkofc044e82019-07-15 06:40:07 -07001013 loop.setUpperBound(numIterations);
1014
River Riddlee62a6952019-12-23 14:45:01 -08001015 Value lb = loop.lowerBound();
Alex Zinenkofc044e82019-07-15 06:40:07 -07001016 if (!isZeroBased) {
River Riddlee62a6952019-12-23 14:45:01 -08001017 Value cst0 = builder.create<ConstantIndexOp>(loc, 0);
Alex Zinenkofc044e82019-07-15 06:40:07 -07001018 loop.setLowerBound(cst0);
1019 }
1020
River Riddlee62a6952019-12-23 14:45:01 -08001021 Value step = loop.step();
Alex Zinenkofc044e82019-07-15 06:40:07 -07001022 if (!isStepOne) {
River Riddlee62a6952019-12-23 14:45:01 -08001023 Value cst1 = builder.create<ConstantIndexOp>(loc, 1);
Alex Zinenkofc044e82019-07-15 06:40:07 -07001024 loop.setStep(cst1);
1025 }
1026
1027 // Insert code computing the value of the original loop induction variable
1028 // from the "normalized" one.
Nicolas Vasilache0002e292019-07-16 12:20:15 -07001029 builder.setInsertionPointToStart(inner.getBody());
River Riddlee62a6952019-12-23 14:45:01 -08001030 Value scaled =
Alex Zinenkofc044e82019-07-15 06:40:07 -07001031 isStepOne ? loop.getInductionVar()
1032 : builder.create<MulIOp>(loc, loop.getInductionVar(), step);
River Riddlee62a6952019-12-23 14:45:01 -08001033 Value shifted =
Alex Zinenkofc044e82019-07-15 06:40:07 -07001034 isZeroBased ? scaled : builder.create<AddIOp>(loc, scaled, lb);
1035
River Riddle2bdf33c2020-01-11 08:54:04 -08001036 SmallPtrSet<Operation *, 2> preserve{scaled.getDefiningOp(),
1037 shifted.getDefiningOp()};
Alex Zinenkofc044e82019-07-15 06:40:07 -07001038 replaceAllUsesExcept(loop.getInductionVar(), shifted, preserve);
1039}
1040
1041void mlir::coalesceLoops(MutableArrayRef<loop::ForOp> loops) {
1042 if (loops.size() < 2)
1043 return;
1044
1045 loop::ForOp innermost = loops.back();
1046 loop::ForOp outermost = loops.front();
1047
1048 // 1. Make sure all loops iterate from 0 to upperBound with step 1. This
1049 // allows the following code to assume upperBound is the number of iterations.
1050 for (auto loop : loops)
1051 normalizeLoop(loop, outermost, innermost);
1052
1053 // 2. Emit code computing the upper bound of the coalesced loop as product
1054 // of the number of iterations of all loops.
1055 OpBuilder builder(outermost);
1056 Location loc = outermost.getLoc();
River Riddlee62a6952019-12-23 14:45:01 -08001057 Value upperBound = outermost.upperBound();
Alex Zinenkofc044e82019-07-15 06:40:07 -07001058 for (auto loop : loops.drop_front())
1059 upperBound = builder.create<MulIOp>(loc, upperBound, loop.upperBound());
1060 outermost.setUpperBound(upperBound);
1061
Nicolas Vasilache0002e292019-07-16 12:20:15 -07001062 builder.setInsertionPointToStart(outermost.getBody());
Alex Zinenkofc044e82019-07-15 06:40:07 -07001063
1064 // 3. Remap induction variables. For each original loop, the value of the
1065 // induction variable can be obtained by dividing the induction variable of
1066 // the linearized loop by the total number of iterations of the loops nested
1067 // in it modulo the number of iterations in this loop (remove the values
1068 // related to the outer loops):
1069 // iv_i = floordiv(iv_linear, product-of-loop-ranges-until-i) mod range_i.
1070 // Compute these iteratively from the innermost loop by creating a "running
1071 // quotient" of division by the range.
River Riddlee62a6952019-12-23 14:45:01 -08001072 Value previous = outermost.getInductionVar();
Alex Zinenkofc044e82019-07-15 06:40:07 -07001073 for (unsigned i = 0, e = loops.size(); i < e; ++i) {
1074 unsigned idx = loops.size() - i - 1;
1075 if (i != 0)
Manuel Freiberger22954a02019-12-22 10:01:35 -08001076 previous = builder.create<SignedDivIOp>(loc, previous,
1077 loops[idx + 1].upperBound());
Alex Zinenkofc044e82019-07-15 06:40:07 -07001078
River Riddlee62a6952019-12-23 14:45:01 -08001079 Value iv = (i == e - 1) ? previous
1080 : builder.create<SignedRemIOp>(
1081 loc, previous, loops[idx].upperBound());
Alex Zinenkofc044e82019-07-15 06:40:07 -07001082 replaceAllUsesInRegionWith(loops[idx].getInductionVar(), iv,
1083 loops.back().region());
1084 }
1085
1086 // 4. Move the operations from the innermost just above the second-outermost
1087 // loop, delete the extra terminator and the second-outermost loop.
1088 loop::ForOp second = loops[1];
Nicolas Vasilache0002e292019-07-16 12:20:15 -07001089 innermost.getBody()->back().erase();
1090 outermost.getBody()->getOperations().splice(
Alex Zinenkofc044e82019-07-15 06:40:07 -07001091 Block::iterator(second.getOperation()),
Nicolas Vasilache0002e292019-07-16 12:20:15 -07001092 innermost.getBody()->getOperations());
Alex Zinenkofc044e82019-07-15 06:40:07 -07001093 second.erase();
1094}
Nicolas Vasilachedb4cd1c2019-07-19 04:47:27 -07001095
River Riddlee62a6952019-12-23 14:45:01 -08001096void mlir::mapLoopToProcessorIds(loop::ForOp forOp, ArrayRef<Value> processorId,
1097 ArrayRef<Value> numProcessors) {
Nicolas Vasilachedb4cd1c2019-07-19 04:47:27 -07001098 assert(processorId.size() == numProcessors.size());
1099 if (processorId.empty())
1100 return;
1101
1102 OpBuilder b(forOp);
1103 Location loc(forOp.getLoc());
River Riddlee62a6952019-12-23 14:45:01 -08001104 Value mul = processorId.front();
Nicolas Vasilachedb4cd1c2019-07-19 04:47:27 -07001105 for (unsigned i = 1, e = processorId.size(); i < e; ++i)
1106 mul = b.create<AddIOp>(loc, b.create<MulIOp>(loc, mul, numProcessors[i]),
1107 processorId[i]);
River Riddlee62a6952019-12-23 14:45:01 -08001108 Value lb = b.create<AddIOp>(loc, forOp.lowerBound(),
1109 b.create<MulIOp>(loc, forOp.step(), mul));
Nicolas Vasilachedb4cd1c2019-07-19 04:47:27 -07001110 forOp.setLowerBound(lb);
1111
River Riddlee62a6952019-12-23 14:45:01 -08001112 Value step = forOp.step();
River Riddle35807bc2019-12-22 21:59:55 -08001113 for (auto numProcs : numProcessors)
Nicolas Vasilachedb4cd1c2019-07-19 04:47:27 -07001114 step = b.create<MulIOp>(loc, step, numProcs);
1115 forOp.setStep(step);
1116}
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001117
1118/// Given a memref region, determine the lowest depth at which transfers can be
1119/// placed for it, and return the corresponding block, start and end positions
1120/// in the block for placing incoming (read) and outgoing (write) copies
1121/// respectively. The lowest depth depends on whether the region being accessed
1122/// is hoistable with respect to one or more immediately surrounding loops.
1123static void
1124findHighestBlockForPlacement(const MemRefRegion &region, Block &block,
1125 Block::iterator &begin, Block::iterator &end,
1126 Block **copyPlacementBlock,
1127 Block::iterator *copyInPlacementStart,
1128 Block::iterator *copyOutPlacementStart) {
1129 const auto *cst = region.getConstraints();
River Riddlee62a6952019-12-23 14:45:01 -08001130 SmallVector<Value, 4> symbols;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001131 cst->getIdValues(cst->getNumDimIds(), cst->getNumDimAndSymbolIds(), &symbols);
1132
1133 SmallVector<AffineForOp, 4> enclosingFors;
1134 getLoopIVs(*block.begin(), &enclosingFors);
1135 // Walk up loop parents till we find an IV on which this region is
1136 // symbolic/variant.
1137 auto it = enclosingFors.rbegin();
1138 for (auto e = enclosingFors.rend(); it != e; ++it) {
1139 // TODO(bondhugula): also need to be checking this for regions symbols that
1140 // aren't loop IVs, whether we are within their resp. defs' dominance scope.
1141 if (llvm::is_contained(symbols, it->getInductionVar()))
1142 break;
1143 }
1144
1145 if (it != enclosingFors.rbegin()) {
1146 auto lastInvariantIV = *std::prev(it);
1147 *copyInPlacementStart = Block::iterator(lastInvariantIV.getOperation());
1148 *copyOutPlacementStart = std::next(*copyInPlacementStart);
1149 *copyPlacementBlock = lastInvariantIV.getOperation()->getBlock();
1150 } else {
1151 *copyInPlacementStart = begin;
1152 *copyOutPlacementStart = end;
1153 *copyPlacementBlock = &block;
1154 }
1155}
1156
/// One level of striding for a copy/transfer.
struct StrideInfo {
  /// Distance, in elements of the original memref, between the starts of two
  /// consecutive contiguous chunks.
  int64_t stride;
  /// Number of elements transferred contiguously at every stride.
  int64_t numEltPerStride;
};
1162
1163/// Returns striding information for a copy/transfer of this region with
1164/// potentially multiple striding levels from outermost to innermost. For an
1165/// n-dimensional region, there can be at most n-1 levels of striding
1166/// successively nested.
1167// TODO(bondhugula): make this work with non-identity layout maps.
1168static void getMultiLevelStrides(const MemRefRegion &region,
1169 ArrayRef<int64_t> bufferShape,
1170 SmallVectorImpl<StrideInfo> *strideInfos) {
1171 if (bufferShape.size() <= 1)
1172 return;
1173
1174 int64_t numEltPerStride = 1;
1175 int64_t stride = 1;
1176 for (int d = bufferShape.size() - 1; d >= 1; d--) {
River Riddle2bdf33c2020-01-11 08:54:04 -08001177 int64_t dimSize = region.memref.getType().cast<MemRefType>().getDimSize(d);
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001178 stride *= dimSize;
1179 numEltPerStride *= bufferShape[d];
1180 // A stride is needed only if the region has a shorter extent than the
1181 // memref along the dimension *and* has an extent greater than one along the
1182 // next major dimension.
1183 if (bufferShape[d] < dimSize && bufferShape[d - 1] > 1) {
1184 strideInfos->push_back({stride, numEltPerStride});
1185 }
1186 }
1187}
1188
1189/// Generates a point-wise copy from/to `memref' to/from `fastMemRef' and
1190/// returns the outermost AffineForOp of the copy loop nest. `memIndicesStart'
1191/// holds the lower coordinates of the region in the original memref to copy
1192/// in/out. If `copyOut' is true, generates a copy-out; otherwise a copy-in.
River Riddlee62a6952019-12-23 14:45:01 -08001193static AffineForOp generatePointWiseCopy(Location loc, Value memref,
1194 Value fastMemRef,
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001195 AffineMap memAffineMap,
River Riddlee62a6952019-12-23 14:45:01 -08001196 ArrayRef<Value> memIndicesStart,
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001197 ArrayRef<int64_t> fastBufferShape,
1198 bool isCopyOut, OpBuilder b) {
1199 assert(!memIndicesStart.empty() && "only 1-d or more memrefs");
1200
1201 // The copy-in nest is generated as follows as an example for a 2-d region:
1202 // for x = ...
1203 // for y = ...
1204 // fast_buf[x][y] = buf[mem_x + x][mem_y + y]
1205
River Riddlee62a6952019-12-23 14:45:01 -08001206 SmallVector<Value, 4> fastBufIndices, memIndices;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001207 AffineForOp copyNestRoot;
1208 for (unsigned d = 0, e = fastBufferShape.size(); d < e; ++d) {
1209 auto forOp = b.create<AffineForOp>(loc, 0, fastBufferShape[d]);
1210 if (d == 0)
1211 copyNestRoot = forOp;
1212 b = forOp.getBodyBuilder();
1213 fastBufIndices.push_back(forOp.getInductionVar());
1214
River Riddlee62a6952019-12-23 14:45:01 -08001215 Value memBase =
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001216 (memAffineMap == b.getMultiDimIdentityMap(memAffineMap.getNumDims()))
1217 ? memIndicesStart[d]
1218 : b.create<AffineApplyOp>(
1219 loc,
River Riddle2acc2202019-10-17 20:08:01 -07001220 AffineMap::get(memAffineMap.getNumDims(),
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001221 memAffineMap.getNumSymbols(),
1222 memAffineMap.getResult(d)),
1223 memIndicesStart);
1224
1225 // Construct the subscript for the slow memref being copied.
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001226 auto memIndex = b.create<AffineApplyOp>(
1227 loc,
River Riddle2acc2202019-10-17 20:08:01 -07001228 AffineMap::get(2, 0, b.getAffineDimExpr(0) + b.getAffineDimExpr(1)),
Uday Bondhugulaa63f6e02019-12-09 06:26:05 -08001229 ValueRange({memBase, forOp.getInductionVar()}));
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001230 memIndices.push_back(memIndex);
1231 }
1232
1233 if (!isCopyOut) {
1234 // Copy in.
1235 auto load = b.create<AffineLoadOp>(loc, memref, memIndices);
1236 b.create<AffineStoreOp>(loc, load, fastMemRef, fastBufIndices);
1237 return copyNestRoot;
1238 }
1239
1240 // Copy out.
1241 auto load = b.create<AffineLoadOp>(loc, fastMemRef, fastBufIndices);
1242 b.create<AffineStoreOp>(loc, load, memref, memIndices);
1243 return copyNestRoot;
1244}
1245
1246static InFlightDiagnostic LLVM_ATTRIBUTE_UNUSED
1247emitRemarkForBlock(Block &block) {
1248 return block.getParentOp()->emitRemark();
1249}
1250
1251/// Creates a buffer in the faster memory space for the specified memref region;
1252/// generates a copy from the lower memory space to this one, and replaces all
1253/// loads/stores in the block range [`begin', `end') of `block' to load/store
1254/// from that buffer. Returns failure if copies could not be generated due to
1255/// yet unimplemented cases. `copyInPlacementStart` and `copyOutPlacementStart`
1256/// in copyPlacementBlock specify the insertion points where the incoming copies
1257/// and outgoing copies, respectively, should be inserted (the insertion happens
1258/// right before the insertion point). Since `begin` can itself be invalidated
1259/// due to the memref rewriting done from this method, the output argument
1260/// `nBegin` is set to its replacement (set to `begin` if no invalidation
1261/// happens). Since outgoing copies could have been inserted at `end`, the
1262/// output argument `nEnd` is set to the new end. `sizeInBytes` is set to the
1263/// size of the fast buffer allocated.
1264static LogicalResult generateCopy(
1265 const MemRefRegion &region, Block *block, Block::iterator begin,
1266 Block::iterator end, Block *copyPlacementBlock,
1267 Block::iterator copyInPlacementStart, Block::iterator copyOutPlacementStart,
River Riddlee62a6952019-12-23 14:45:01 -08001268 AffineCopyOptions copyOptions, DenseMap<Value, Value> &fastBufferMap,
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001269 DenseSet<Operation *> &copyNests, uint64_t *sizeInBytes,
1270 Block::iterator *nBegin, Block::iterator *nEnd) {
1271 *nBegin = begin;
1272 *nEnd = end;
1273
1274 FuncOp f = begin->getParentOfType<FuncOp>();
1275 OpBuilder topBuilder(f.getBody());
River Riddlee62a6952019-12-23 14:45:01 -08001276 Value zeroIndex = topBuilder.create<ConstantIndexOp>(f.getLoc(), 0);
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001277
1278 if (begin == end)
1279 return success();
1280
1281 // Is the copy out point at the end of the block where we are doing
1282 // explicit copying.
1283 bool isCopyOutAtEndOfBlock = (end == copyOutPlacementStart);
1284
1285 // Copies for read regions are going to be inserted at 'begin'.
1286 OpBuilder prologue(copyPlacementBlock, copyInPlacementStart);
1287 // Copies for write regions are going to be inserted at 'end'.
1288 OpBuilder epilogue(copyPlacementBlock, copyOutPlacementStart);
1289 OpBuilder &b = region.isWrite() ? epilogue : prologue;
1290
1291 // Builder to create constants at the top level.
1292 auto func = copyPlacementBlock->getParent()->getParentOfType<FuncOp>();
1293 OpBuilder top(func.getBody());
1294
1295 auto loc = region.loc;
River Riddle35807bc2019-12-22 21:59:55 -08001296 auto memref = region.memref;
River Riddle2bdf33c2020-01-11 08:54:04 -08001297 auto memRefType = memref.getType().cast<MemRefType>();
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001298
1299 auto layoutMaps = memRefType.getAffineMaps();
1300 if (layoutMaps.size() > 1 ||
1301 (layoutMaps.size() == 1 && !layoutMaps[0].isIdentity())) {
1302 LLVM_DEBUG(llvm::dbgs() << "Non-identity layout map not yet supported\n");
1303 return failure();
1304 }
1305
1306 // Indices to use for the copying.
1307 // Indices for the original memref being copied from/to.
River Riddlee62a6952019-12-23 14:45:01 -08001308 SmallVector<Value, 4> memIndices;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001309 // Indices for the faster buffer being copied into/from.
River Riddlee62a6952019-12-23 14:45:01 -08001310 SmallVector<Value, 4> bufIndices;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001311
1312 unsigned rank = memRefType.getRank();
1313 SmallVector<int64_t, 4> fastBufferShape;
1314
1315 // Compute the extents of the buffer.
1316 std::vector<SmallVector<int64_t, 4>> lbs;
1317 SmallVector<int64_t, 8> lbDivisors;
1318 lbs.reserve(rank);
1319 Optional<int64_t> numElements = region.getConstantBoundingSizeAndShape(
1320 &fastBufferShape, &lbs, &lbDivisors);
1321 if (!numElements.hasValue()) {
1322 LLVM_DEBUG(llvm::dbgs() << "Non-constant region size not supported\n");
1323 return failure();
1324 }
1325
1326 if (numElements.getValue() == 0) {
1327 LLVM_DEBUG(llvm::dbgs() << "Nothing to copy\n");
1328 *sizeInBytes = 0;
1329 return success();
1330 }
1331
1332 const FlatAffineConstraints *cst = region.getConstraints();
Kazuaki Ishizaki8bfedb32019-10-20 00:11:03 -07001333 // 'regionSymbols' hold values that this memory region is symbolic/parametric
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001334 // on; these typically include loop IVs surrounding the level at which the
1335 // copy generation is being done or other valid symbols in MLIR.
River Riddlee62a6952019-12-23 14:45:01 -08001336 SmallVector<Value, 8> regionSymbols;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001337 cst->getIdValues(rank, cst->getNumIds(), &regionSymbols);
1338
1339 // Construct the index expressions for the fast memory buffer. The index
1340 // expression for a particular dimension of the fast buffer is obtained by
1341 // subtracting out the lower bound on the original memref's data region
1342 // along the corresponding dimension.
1343
1344 // Index start offsets for faster memory buffer relative to the original.
1345 SmallVector<AffineExpr, 4> offsets;
1346 offsets.reserve(rank);
1347 for (unsigned d = 0; d < rank; d++) {
1348 assert(lbs[d].size() == cst->getNumCols() - rank && "incorrect bound size");
1349
1350 AffineExpr offset = top.getAffineConstantExpr(0);
1351 for (unsigned j = 0, e = cst->getNumCols() - rank - 1; j < e; j++) {
1352 offset = offset + lbs[d][j] * top.getAffineDimExpr(j);
1353 }
1354 assert(lbDivisors[d] > 0);
1355 offset =
1356 (offset + lbs[d][cst->getNumCols() - 1 - rank]).floorDiv(lbDivisors[d]);
1357
1358 // Set copy start location for this dimension in the lower memory space
1359 // memref.
1360 if (auto caf = offset.dyn_cast<AffineConstantExpr>()) {
1361 auto indexVal = caf.getValue();
1362 if (indexVal == 0) {
1363 memIndices.push_back(zeroIndex);
1364 } else {
1365 memIndices.push_back(
1366 top.create<ConstantIndexOp>(loc, indexVal).getResult());
1367 }
1368 } else {
1369 // The coordinate for the start location is just the lower bound along the
1370 // corresponding dimension on the memory region (stored in 'offset').
River Riddle2acc2202019-10-17 20:08:01 -07001371 auto map = AffineMap::get(
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001372 cst->getNumDimIds() + cst->getNumSymbolIds() - rank, 0, offset);
1373 memIndices.push_back(b.create<AffineApplyOp>(loc, map, regionSymbols));
1374 }
1375 // The fast buffer is copied into at location zero; addressing is relative.
1376 bufIndices.push_back(zeroIndex);
1377
1378 // Record the offsets since they are needed to remap the memory accesses of
1379 // the original memref further below.
1380 offsets.push_back(offset);
1381 }
1382
1383 // The faster memory space buffer.
River Riddlee62a6952019-12-23 14:45:01 -08001384 Value fastMemRef;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001385
1386 // Check if a buffer was already created.
1387 bool existingBuf = fastBufferMap.count(memref) > 0;
1388 if (!existingBuf) {
1389 AffineMap fastBufferLayout = b.getMultiDimIdentityMap(rank);
1390 auto fastMemRefType =
River Riddle2acc2202019-10-17 20:08:01 -07001391 MemRefType::get(fastBufferShape, memRefType.getElementType(),
1392 fastBufferLayout, copyOptions.fastMemorySpace);
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001393
1394 // Create the fast memory space buffer just before the 'affine.for'
1395 // operation.
1396 fastMemRef = prologue.create<AllocOp>(loc, fastMemRefType).getResult();
1397 // Record it.
1398 fastBufferMap[memref] = fastMemRef;
1399 // fastMemRefType is a constant shaped memref.
1400 *sizeInBytes = getMemRefSizeInBytes(fastMemRefType).getValue();
1401 LLVM_DEBUG(emitRemarkForBlock(*block)
1402 << "Creating fast buffer of type " << fastMemRefType
1403 << " and size " << llvm::divideCeil(*sizeInBytes, 1024)
1404 << " KiB\n");
1405 } else {
1406 // Reuse the one already created.
1407 fastMemRef = fastBufferMap[memref];
1408 *sizeInBytes = 0;
1409 }
1410
1411 auto numElementsSSA =
1412 top.create<ConstantIndexOp>(loc, numElements.getValue());
1413
1414 SmallVector<StrideInfo, 4> strideInfos;
1415 getMultiLevelStrides(region, fastBufferShape, &strideInfos);
1416
1417 // TODO(bondhugula): use all stride levels once DmaStartOp is extended for
1418 // multi-level strides.
1419 if (strideInfos.size() > 1) {
1420 LLVM_DEBUG(llvm::dbgs() << "Only up to one level of stride supported\n");
1421 return failure();
1422 }
1423
River Riddlee62a6952019-12-23 14:45:01 -08001424 Value stride = nullptr;
1425 Value numEltPerStride = nullptr;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001426 if (!strideInfos.empty()) {
1427 stride = top.create<ConstantIndexOp>(loc, strideInfos[0].stride);
1428 numEltPerStride =
1429 top.create<ConstantIndexOp>(loc, strideInfos[0].numEltPerStride);
1430 }
1431
1432 // Record the last operation where we want the memref replacement to end. We
1433 // later do the memref replacement only in [begin, postDomFilter] so
1434 // that the original memref's used in the data movement code themselves don't
1435 // get replaced.
1436 auto postDomFilter = std::prev(end);
1437
1438 // Create fully composed affine maps for each memref.
1439 auto memAffineMap = b.getMultiDimIdentityMap(memIndices.size());
1440 fullyComposeAffineMapAndOperands(&memAffineMap, &memIndices);
1441 auto bufAffineMap = b.getMultiDimIdentityMap(bufIndices.size());
1442 fullyComposeAffineMapAndOperands(&bufAffineMap, &bufIndices);
1443
1444 if (!copyOptions.generateDma) {
1445 // Point-wise copy generation.
1446 auto copyNest = generatePointWiseCopy(loc, memref, fastMemRef, memAffineMap,
1447 memIndices, fastBufferShape,
1448 /*isCopyOut=*/region.isWrite(), b);
1449
1450 // Record this so that we can skip it from yet another copy.
1451 copyNests.insert(copyNest);
1452
1453 // Since new ops are being appended (for copy out's), adjust the end to
1454 // mark end of block range being processed if necessary.
1455 if (region.isWrite() && isCopyOutAtEndOfBlock)
1456 *nEnd = Block::iterator(copyNest.getOperation());
1457 } else {
1458 // DMA generation.
1459 // Create a tag (single element 1-d memref) for the DMA.
River Riddle2acc2202019-10-17 20:08:01 -07001460 auto tagMemRefType = MemRefType::get({1}, top.getIntegerType(32), {},
1461 copyOptions.tagMemorySpace);
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001462 auto tagMemRef = prologue.create<AllocOp>(loc, tagMemRefType);
1463
River Riddlee62a6952019-12-23 14:45:01 -08001464 SmallVector<Value, 4> tagIndices({zeroIndex});
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001465 auto tagAffineMap = b.getMultiDimIdentityMap(tagIndices.size());
1466 fullyComposeAffineMapAndOperands(&tagAffineMap, &tagIndices);
1467 if (!region.isWrite()) {
1468 // DMA non-blocking read from original buffer to fast buffer.
1469 b.create<AffineDmaStartOp>(loc, memref, memAffineMap, memIndices,
1470 fastMemRef, bufAffineMap, bufIndices,
1471 tagMemRef, tagAffineMap, tagIndices,
1472 numElementsSSA, stride, numEltPerStride);
1473 } else {
1474 // DMA non-blocking write from fast buffer to the original memref.
1475 auto op = b.create<AffineDmaStartOp>(
1476 loc, fastMemRef, bufAffineMap, bufIndices, memref, memAffineMap,
1477 memIndices, tagMemRef, tagAffineMap, tagIndices, numElementsSSA,
1478 stride, numEltPerStride);
1479 // Since new ops may be appended at 'end' (for outgoing DMAs), adjust the
1480 // end to mark end of block range being processed.
1481 if (isCopyOutAtEndOfBlock)
1482 *nEnd = Block::iterator(op.getOperation());
1483 }
1484
1485 // Matching DMA wait to block on completion; tag always has a 0 index.
1486 b.create<AffineDmaWaitOp>(loc, tagMemRef, tagAffineMap, zeroIndex,
1487 numElementsSSA);
1488
1489 // Generate dealloc for the tag.
1490 auto tagDeallocOp = epilogue.create<DeallocOp>(loc, tagMemRef);
1491 if (*nEnd == end && isCopyOutAtEndOfBlock)
1492 // Since new ops are being appended (for outgoing DMAs), adjust the end to
1493 // mark end of range of the original.
1494 *nEnd = Block::iterator(tagDeallocOp.getOperation());
1495 }
1496
1497 // Generate dealloc for the buffer.
1498 if (!existingBuf) {
1499 auto bufDeallocOp = epilogue.create<DeallocOp>(loc, fastMemRef);
1500 // When generating pointwise copies, `nEnd' has to be set to deallocOp on
1501 // the fast buffer (since it marks the new end insertion point).
1502 if (!copyOptions.generateDma && *nEnd == end && isCopyOutAtEndOfBlock)
1503 *nEnd = Block::iterator(bufDeallocOp.getOperation());
1504 }
1505
1506 // Replace all uses of the old memref with the faster one while remapping
1507 // access indices (subtracting out lower bound offsets for each dimension).
1508 // Ex: to replace load %A[%i, %j] with load %Abuf[%i - %iT, %j - %jT],
1509 // index remap will be (%i, %j) -> (%i - %iT, %j - %jT),
1510 // i.e., affine.apply (d0, d1, d2, d3) -> (d2-d0, d3-d1) (%iT, %jT, %i, %j),
1511 // and (%iT, %jT) will be the 'extraOperands' for 'rep all memref uses with'.
1512 // d2, d3 correspond to the original indices (%i, %j).
1513 SmallVector<AffineExpr, 4> remapExprs;
1514 remapExprs.reserve(rank);
1515 for (unsigned i = 0; i < rank; i++) {
1516 // The starting operands of indexRemap will be regionSymbols (the symbols on
1517 // which the memref region is parametric); then those corresponding to
1518 // the memref's original indices follow.
1519 auto dimExpr = b.getAffineDimExpr(regionSymbols.size() + i);
1520 remapExprs.push_back(dimExpr - offsets[i]);
1521 }
River Riddle2acc2202019-10-17 20:08:01 -07001522 auto indexRemap = AffineMap::get(regionSymbols.size() + rank, 0, remapExprs);
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001523
1524 // Record the begin since it may be invalidated by memref replacement.
1525 Block::iterator prevOfBegin;
1526 bool isBeginAtStartOfBlock = (begin == block->begin());
1527 if (!isBeginAtStartOfBlock)
1528 prevOfBegin = std::prev(begin);
1529
1530 // *Only* those uses within the range [begin, end) of 'block' are replaced.
1531 replaceAllMemRefUsesWith(memref, fastMemRef,
1532 /*extraIndices=*/{}, indexRemap,
1533 /*extraOperands=*/regionSymbols,
Uday Bondhugula727a50a2019-09-18 11:25:33 -07001534 /*symbolOperands=*/{},
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001535 /*domInstFilter=*/&*begin,
1536 /*postDomInstFilter=*/&*postDomFilter);
1537
1538 *nBegin = isBeginAtStartOfBlock ? block->begin() : std::next(prevOfBegin);
1539
1540 return success();
1541}
1542
1543/// Construct the memref region to just include the entire memref. Returns false
1544/// dynamic shaped memref's for now. `numParamLoopIVs` is the number of
1545/// enclosing loop IVs of opInst (starting from the outermost) that the region
1546/// is parametric on.
1547static bool getFullMemRefAsRegion(Operation *opInst, unsigned numParamLoopIVs,
1548 MemRefRegion *region) {
1549 unsigned rank;
1550 if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
1551 rank = loadOp.getMemRefType().getRank();
1552 region->memref = loadOp.getMemRef();
1553 region->setWrite(false);
1554 } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
1555 rank = storeOp.getMemRefType().getRank();
1556 region->memref = storeOp.getMemRef();
1557 region->setWrite(true);
1558 } else {
1559 assert(false && "expected load or store op");
1560 return false;
1561 }
River Riddle2bdf33c2020-01-11 08:54:04 -08001562 auto memRefType = region->memref.getType().cast<MemRefType>();
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001563 if (!memRefType.hasStaticShape())
1564 return false;
1565
1566 auto *regionCst = region->getConstraints();
1567
1568 // Just get the first numSymbols IVs, which the memref region is parametric
1569 // on.
1570 SmallVector<AffineForOp, 4> ivs;
1571 getLoopIVs(*opInst, &ivs);
1572 ivs.resize(numParamLoopIVs);
River Riddlee62a6952019-12-23 14:45:01 -08001573 SmallVector<Value, 4> symbols;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001574 extractForInductionVars(ivs, &symbols);
1575 regionCst->reset(rank, numParamLoopIVs, 0);
1576 regionCst->setIdValues(rank, rank + numParamLoopIVs, symbols);
1577
1578 // Memref dim sizes provide the bounds.
1579 for (unsigned d = 0; d < rank; d++) {
1580 auto dimSize = memRefType.getDimSize(d);
1581 assert(dimSize > 0 && "filtered dynamic shapes above");
1582 regionCst->addConstantLowerBound(d, 0);
1583 regionCst->addConstantUpperBound(d, dimSize - 1);
1584 }
1585 return true;
1586}
1587
1588/// Generates copies for a contiguous sequence of operations in `block` in the
1589/// iterator range [`begin', `end'), where `end' can't be past the terminator of
1590/// the block (since additional operations are potentially inserted right before
1591/// `end'. Returns the total size of the fast buffers used.
1592// Since we generate alloc's and dealloc's for all fast buffers (before and
1593// after the range of operations resp.), all of the fast memory capacity is
1594// assumed to be available for processing this block range.
1595uint64_t mlir::affineDataCopyGenerate(Block::iterator begin,
1596 Block::iterator end,
1597 const AffineCopyOptions &copyOptions,
1598 DenseSet<Operation *> &copyNests) {
1599 if (begin == end)
1600 return 0;
1601
1602 assert(begin->getBlock() == std::prev(end)->getBlock() &&
1603 "Inconsistent block begin/end args");
1604 assert(end != end->getBlock()->end() && "end can't be the block terminator");
1605
1606 Block *block = begin->getBlock();
1607
1608 // Copies will be generated for this depth, i.e., symbolic in all loops
1609 // surrounding the this block range.
1610 unsigned copyDepth = getNestingDepth(*begin);
1611
1612 LLVM_DEBUG(llvm::dbgs() << "Generating copies at depth " << copyDepth
1613 << "\n");
1614 LLVM_DEBUG(llvm::dbgs() << "from begin: " << *begin << "\n");
1615 LLVM_DEBUG(llvm::dbgs() << "to inclusive end: " << *std::prev(end) << "\n");
1616
1617 // List of memory regions to copy for. We need a map vector to have a
1618 // guaranteed iteration order to write test cases. CHECK-DAG doesn't help here
1619 // since the alloc's for example are identical except for the SSA id.
River Riddlee62a6952019-12-23 14:45:01 -08001620 SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> readRegions;
1621 SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4> writeRegions;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001622
1623 // Map from original memref's to the fast buffers that their accesses are
1624 // replaced with.
River Riddlee62a6952019-12-23 14:45:01 -08001625 DenseMap<Value, Value> fastBufferMap;
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001626
1627 // To check for errors when walking the block.
1628 bool error = false;
1629
1630 // Walk this range of operations to gather all memory regions.
1631 block->walk(begin, end, [&](Operation *opInst) {
1632 // Gather regions to allocate to buffers in faster memory space.
1633 if (auto loadOp = dyn_cast<AffineLoadOp>(opInst)) {
1634 if ((loadOp.getMemRefType().getMemorySpace() !=
1635 copyOptions.slowMemorySpace))
1636 return;
1637 } else if (auto storeOp = dyn_cast<AffineStoreOp>(opInst)) {
1638 if (storeOp.getMemRefType().getMemorySpace() !=
1639 copyOptions.slowMemorySpace)
1640 return;
1641 } else {
1642 // Neither load nor a store op.
1643 return;
1644 }
1645
1646 // Compute the MemRefRegion accessed.
1647 auto region = std::make_unique<MemRefRegion>(opInst->getLoc());
1648 if (failed(region->compute(opInst, copyDepth))) {
1649 LLVM_DEBUG(llvm::dbgs()
1650 << "Error obtaining memory region: semi-affine maps?\n");
1651 LLVM_DEBUG(llvm::dbgs() << "over-approximating to the entire memref\n");
1652 if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
1653 LLVM_DEBUG(
MLIR Team1c73be72019-09-18 07:44:39 -07001654 opInst->emitError("non-constant memref sizes not yet supported"));
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001655 error = true;
1656 return;
1657 }
1658 }
1659
1660 // Each memref has a single buffer associated with it irrespective of how
1661 // many load's and store's happen on it.
1662 // TODO(bondhugula): in the future, when regions don't intersect and satisfy
1663 // other properties (based on load/store regions), we could consider
1664 // multiple buffers per memref.
1665
1666 // Add to the appropriate region if it's not already in it, or take a
1667 // bounding box union with the existing one if it's already in there.
1668 // Note that a memref may have both read and write regions - so update the
1669 // region in the other list if one exists (write in case of read and vice
1670 // versa) since there is a single bounding box for a memref across all reads
1671 // and writes that happen on it.
1672
1673 // Attempts to update; returns true if 'region' exists in targetRegions.
1674 auto updateRegion =
River Riddlee62a6952019-12-23 14:45:01 -08001675 [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001676 &targetRegions) {
1677 auto it = targetRegions.find(region->memref);
1678 if (it == targetRegions.end())
1679 return false;
1680
1681 // Perform a union with the existing region.
1682 if (failed(it->second->unionBoundingBox(*region))) {
1683 LLVM_DEBUG(llvm::dbgs()
1684 << "Memory region bounding box failed; "
1685 "over-approximating to the entire memref\n");
1686 // If the union fails, we will overapproximate.
1687 if (!getFullMemRefAsRegion(opInst, copyDepth, region.get())) {
1688 LLVM_DEBUG(opInst->emitError(
MLIR Team1c73be72019-09-18 07:44:39 -07001689 "non-constant memref sizes not yet supported"));
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001690 error = true;
1691 return true;
1692 }
1693 it->second->getConstraints()->clearAndCopyFrom(
1694 *region->getConstraints());
1695 } else {
1696 // Union was computed and stored in 'it->second': copy to 'region'.
1697 region->getConstraints()->clearAndCopyFrom(
1698 *it->second->getConstraints());
1699 }
1700 return true;
1701 };
1702
1703 bool existsInRead = updateRegion(readRegions);
1704 if (error)
1705 return;
1706 bool existsInWrite = updateRegion(writeRegions);
1707 if (error)
1708 return;
1709
1710 // Finally add it to the region list.
1711 if (region->isWrite() && !existsInWrite) {
1712 writeRegions[region->memref] = std::move(region);
1713 } else if (!region->isWrite() && !existsInRead) {
1714 readRegions[region->memref] = std::move(region);
1715 }
1716 });
1717
1718 if (error) {
1719 begin->emitError(
1720 "copy generation failed for one or more memref's in this block\n");
1721 return 0;
1722 }
1723
1724 uint64_t totalCopyBuffersSizeInBytes = 0;
1725 bool ret = true;
1726 auto processRegions =
River Riddlee62a6952019-12-23 14:45:01 -08001727 [&](const SmallMapVector<Value, std::unique_ptr<MemRefRegion>, 4>
Uday Bondhugula4f32ae62019-09-14 13:23:18 -07001728 &regions) {
1729 for (const auto &regionEntry : regions) {
1730 // For each region, hoist copy in/out past all hoistable
1731 // 'affine.for's.
1732 Block::iterator copyInPlacementStart, copyOutPlacementStart;
1733 Block *copyPlacementBlock;
1734 findHighestBlockForPlacement(
1735 *regionEntry.second, *block, begin, end, &copyPlacementBlock,
1736 &copyInPlacementStart, &copyOutPlacementStart);
1737
1738 uint64_t sizeInBytes;
1739 Block::iterator nBegin, nEnd;
1740 LogicalResult iRet = generateCopy(
1741 *regionEntry.second, block, begin, end, copyPlacementBlock,
1742 copyInPlacementStart, copyOutPlacementStart, copyOptions,
1743 fastBufferMap, copyNests, &sizeInBytes, &nBegin, &nEnd);
1744 if (succeeded(iRet)) {
1745 // begin/end could have been invalidated, and need update.
1746 begin = nBegin;
1747 end = nEnd;
1748 totalCopyBuffersSizeInBytes += sizeInBytes;
1749 }
1750 ret = ret & succeeded(iRet);
1751 }
1752 };
1753 processRegions(readRegions);
1754 processRegions(writeRegions);
1755
1756 if (!ret) {
1757 begin->emitError(
1758 "copy generation failed for one or more memref's in this block\n");
1759 return totalCopyBuffersSizeInBytes;
1760 }
1761
1762 // For a range of operations, a note will be emitted at the caller.
1763 AffineForOp forOp;
1764 uint64_t sizeInKib = llvm::divideCeil(totalCopyBuffersSizeInBytes, 1024);
1765 if (llvm::DebugFlag && (forOp = dyn_cast<AffineForOp>(&*begin))) {
1766 forOp.emitRemark()
1767 << sizeInKib
1768 << " KiB of copy buffers in fast memory space for this block\n";
1769 }
1770
1771 if (totalCopyBuffersSizeInBytes > copyOptions.fastMemCapacityBytes) {
1772 StringRef str = "Total size of all copy buffers' for this block "
1773 "exceeds fast memory capacity\n";
1774 block->getParentOp()->emitError(str);
1775 }
1776
1777 return totalCopyBuffersSizeInBytes;
1778}