Add a utility to promote single-iteration loops, and methods for getting a
loop's constant trip count. Improve/refactor loop unroll and loop unroll-and-jam.
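- add a utility that promotes a single-iteration loop: its body is moved into
  the enclosing block and the loop itself is erased.
- use this utility to promote loops (both the unrolled loop and its cleanup
  loop) that have become single-iteration after unroll/unroll-and-jam.
- implement loopUnrollFull on top of loopUnrollByFactor (unrolling by the
  constant trip count) and remove most of loopUnrollFull's own code.
- add methods for getting a loop's constant trip count.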
PiperOrigin-RevId: 212039569
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 382b830..e663ce0 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -33,22 +33,21 @@
#include "mlir/Transforms/Passes.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/CommandLine.h"
-#include "llvm/Support/raw_ostream.h"
using namespace mlir;
-using namespace llvm;
// Loop unrolling factor.
static llvm::cl::opt<unsigned>
- clUnrollFactor("unroll-factor", cl::Hidden,
- cl::desc("Use this unroll factor for all loops"));
+ clUnrollFactor("unroll-factor", llvm::cl::Hidden,
+ llvm::cl::desc("Use this unroll factor for all loops"));
-static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
- cl::desc("Fully unroll loops"));
+// Whether to fully unroll loops.
+static llvm::cl::opt<bool> clUnrollFull("unroll-full", llvm::cl::Hidden,
+ llvm::cl::desc("Fully unroll loops"));
+// Trip count threshold: loops whose trip count is less than or equal to this
+// value are fully unrolled.
static llvm::cl::opt<unsigned> clUnrollFullThreshold(
- "unroll-full-threshold", cl::Hidden,
- cl::desc("Unroll all loops with trip count less than or equal to this"));
+ "unroll-full-threshold", llvm::cl::Hidden,
+ llvm::cl::desc(
+ "Unroll all loops with trip count less than or equal to this"));
namespace {
/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
@@ -67,7 +66,7 @@
/// Unroll this for stmt. Returns false if nothing was done.
bool runOnForStmt(ForStmt *forStmt);
bool loopUnrollFull(ForStmt *forStmt);
- bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
+ bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
};
} // end anonymous namespace
@@ -129,13 +128,8 @@
ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
void visitForStmt(ForStmt *forStmt) {
- if (!forStmt->hasConstantBounds())
- return;
- auto lb = forStmt->getConstantLowerBound();
- auto ub = forStmt->getConstantUpperBound();
- auto step = forStmt->getStep();
-
- if ((ub - lb) / step + 1 <= minTripCount)
+    // Collect loops whose trip count is a known constant no larger than
+    // minTripCount; loops without a constant trip count are skipped.
+ Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+ if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
loops.push_back(forStmt);
}
};
@@ -180,43 +174,14 @@
-// Unrolls this loop completely. Fails assertion if loop bounds are
-// non-constant.
+/// Unrolls this loop completely. Returns false, leaving the loop untouched, if
+/// its trip count is not a known constant or is zero.
bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
- auto lb = forStmt->getConstantLowerBound();
- auto ub = forStmt->getConstantUpperBound();
- auto step = forStmt->getStep();
-
- // Builder to add constants needed for the unrolled iterator.
- auto *mlFunc = forStmt->findFunction();
- MLFuncBuilder funcTopBuilder(&mlFunc->front());
-
- // Builder to insert the unrolled bodies. We insert right after the
- // ForStmt we're unrolling.
- MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
-
- // Unroll the contents of 'forStmt'.
- for (int64_t i = lb; i <= ub; i += step) {
- DenseMap<const MLValue *, MLValue *> operandMapping;
-
- // If the induction variable is used, create a constant for this unrolled
- // value and add an operand mapping for it.
- if (!forStmt->use_empty()) {
- auto *ivConst =
- funcTopBuilder.create<ConstantAffineIntOp>(forStmt->getLoc(), i)
- ->getResult();
- operandMapping[forStmt] = cast<MLValue>(ivConst);
- }
-
- // Clone the body of the loop.
- for (auto &childStmt : *forStmt) {
- builder.clone(childStmt, operandMapping);
- }
- }
- // Erase the original 'for' stmt from the block.
- forStmt->eraseFromBlock();
- return true;
+  Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+  // Nothing to do without a known trip count, or for a loop that never
+  // executes: a zero trip count must not reach loopUnrollByFactor, which
+  // asserts a factor >= 1 and uses the factor as a divisor.
+  if (!tripCount.hasValue() || tripCount.getValue() == 0)
+    return false;
+
+  // loopUnrollByFactor bails out early for a factor of 1, so a
+  // single-iteration loop is fully unrolled by promoting its body instead.
+  if (tripCount.getValue() == 1)
+    return promoteIfSingleIteration(forStmt);
+
+  return loopUnrollByFactor(forStmt, tripCount.getValue());
}
/// Unrolls this loop by the specified unroll factor.
-bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
if (unrollFactor == 1 || forStmt->getStatements().empty())
@@ -225,11 +190,9 @@
if (!forStmt->hasConstantBounds())
return false;
- auto lb = forStmt->getConstantLowerBound();
- auto ub = forStmt->getConstantUpperBound();
- auto step = forStmt->getStep();
-
- int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
+ int64_t lb = forStmt->getConstantLowerBound();
+ int64_t step = forStmt->getStep();
+ uint64_t tripCount = forStmt->getConstantTripCount().getValue();
// If the trip count is lower than the unroll factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -243,6 +206,8 @@
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setConstantLowerBound(
lb + (tripCount - tripCount % unrollFactor) * step);
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(cleanupForStmt);
}
// Builder to insert unrolled bodies right after the last statement in the
@@ -281,5 +246,9 @@
// Clone the last statement in the original body.
builder.clone(*srcBlockEnd, operandMapping);
}
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(forStmt);
+
return true;
}
diff --git a/lib/Transforms/LoopUnrollJam.cpp b/lib/Transforms/LoopUnrollJam.cpp
index eeab87c..6fb6134 100644
--- a/lib/Transforms/LoopUnrollJam.cpp
+++ b/lib/Transforms/LoopUnrollJam.cpp
@@ -23,7 +23,7 @@
// bounds of the loops inner to the loop being unroll-jammed do not depend on
// the latter.
//
-// Before After unroll-jam of i by factor 2:
+// Before After unroll and jam of i by factor 2:
//
// for i, step = 2
// for i S1(i);
@@ -54,7 +54,6 @@
#include "llvm/Support/CommandLine.h"
using namespace mlir;
-using namespace llvm::cl;
// Loop unroll jam factor.
static llvm::cl::opt<unsigned>
@@ -74,7 +73,7 @@
void runOnMLFunction(MLFunction *f) override;
bool runOnForStmt(ForStmt *forStmt);
- bool loopUnrollJamByFactor(ForStmt *forStmt, unsigned unrollJamFactor);
+ bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
};
} // end anonymous namespace
@@ -110,15 +109,7 @@
/// Unrolls and jams this loop by the specified factor.
bool LoopUnrollAndJam::loopUnrollJamByFactor(ForStmt *forStmt,
- unsigned unrollJamFactor) {
- assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
-
- if (unrollJamFactor == 1 || forStmt->getStatements().empty())
- return false;
-
- if (!forStmt->hasConstantBounds())
- return false;
-
+ uint64_t unrollJamFactor) {
// Gathers all maximal sub-blocks of statements that do not themselves include
// a for stmt (a statement could have a descendant for stmt though in its
// tree).
@@ -146,12 +137,17 @@
}
};
- auto lb = forStmt->getConstantLowerBound();
- auto ub = forStmt->getConstantUpperBound();
- auto step = forStmt->getStep();
+ assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
- int64_t tripCount = (ub - lb + 1) % step == 0 ? (ub - lb + 1) / step
- : (ub - lb + 1) / step + 1;
+ if (unrollJamFactor == 1 || forStmt->getStatements().empty())
+ return false;
+
+ if (!forStmt->hasConstantBounds())
+ return false;
+
+ int64_t lb = forStmt->getConstantLowerBound();
+ int64_t step = forStmt->getStep();
+ uint64_t tripCount = forStmt->getConstantTripCount().getValue();
// If the trip count is lower than the unroll jam factor, no unrolled body.
// TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -172,6 +168,9 @@
auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
cleanupForStmt->setConstantLowerBound(
lb + (tripCount - tripCount % unrollJamFactor) * step);
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(cleanupForStmt);
}
MLFuncBuilder b(forStmt);
@@ -210,5 +209,9 @@
builder.clone(*subBlock.second, operandMapping);
}
}
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ promoteIfSingleIteration(forStmt);
+
return true;
}
diff --git a/lib/Transforms/LoopUtils.cpp b/lib/Transforms/LoopUtils.cpp
new file mode 100644
index 0000000..cc2dc40
--- /dev/null
+++ b/lib/Transforms/LoopUtils.cpp
@@ -0,0 +1,63 @@
+//===- LoopUtils.cpp - Misc loop utilities for simplification -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop simplification routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/Transforms/Passes.h"
+
+/// Moves the body of 'forStmt' into its containing block and erases the loop
+/// if it is known to run exactly one iteration. Returns false otherwise.
+bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
+  Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+  if (!tripCount.hasValue() || tripCount.getValue() != 1 ||
+      !forStmt->hasConstantLowerBound())
+    return false;
+
+  // Replace all IV uses with its single iteration value. Skip creating the
+  // constant when the IV is unused so no dead constant is left behind.
+  if (!forStmt->use_empty()) {
+    MLFuncBuilder topBuilder(&forStmt->findFunction()->front());
+    auto constOp = topBuilder.create<ConstantAffineIntOp>(
+        forStmt->getLoc(), forStmt->getConstantLowerBound());
+    forStmt->replaceAllUsesWith(constOp->getResult());
+  }
+
+  // Move the body statements in front of the loop, then erase the loop.
+  forStmt->getBlock()->getStatements().splice(StmtBlock::iterator(forStmt),
+                                              forStmt->getStatements());
+  forStmt->eraseFromBlock();
+  return true;
+}
+
+/// Promotes every single-iteration 'for' stmt in the MLFunction, i.e., moves
+/// its body into the containing StmtBlock and erases the loop.
+void mlir::promoteSingleIterationLoops(MLFunction *f) {
+  // NOTE(review): promotion erases the visited stmt in the middle of the
+  // post-order walk; confirm StmtWalker tolerates that.
+  class LoopBodyPromoter : public StmtWalker<LoopBodyPromoter> {
+  public:
+    void visitForStmt(ForStmt *forStmt) { promoteIfSingleIteration(forStmt); }
+  };
+  LoopBodyPromoter promoter;
+  promoter.walkPostOrder(f);
+}