Rework the statement cloning infrastructure so that it takes and updates an
operand mapping, which simplifies it a bit. Implement cloning for IfStmt, and
rename getThenClause()/getElseClause() to getThen()/getElse(), which are
unambiguous and less repetitive at their use sites.
PiperOrigin-RevId: 207915990
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index eea3bf7..5c2dd86 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -30,6 +30,7 @@
#include "mlir/IR/Statements.h"
#include "mlir/IR/StmtVisitor.h"
#include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/DenseMap.h"
#include "llvm/Support/raw_ostream.h"
using namespace mlir;
@@ -77,14 +78,15 @@
bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
if (!hasInnerLoops)
loops.push_back(forStmt);
+
return true;
}
bool walkIfStmtPostOrder(IfStmt *ifStmt) {
- bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
- ifStmt->getThenClause()->end());
- hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
- ifStmt->getElseClause()->end());
+ bool hasInnerLoops =
+ walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
+ hasInnerLoops |=
+ walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
return hasInnerLoops;
}
@@ -130,106 +132,35 @@
runOnForStmt(forStmt);
}
-/// Replace all uses of oldVal with newVal from begin to end.
-static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
- MLValue *oldVal, MLValue *newVal) {
- // TODO(bondhugula,clattner): do this more efficiently by walking those uses
- // of oldVal that fall within this list of statements (instead of iterating
- // through all statements / through all operands of operations found).
- for (auto it = begin; it != end; it++) {
- it->replaceUses(oldVal, newVal);
- }
-}
-
-/// Replace all uses of oldVal with newVal.
-void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
- // TODO(bondhugula,clattner): do this more efficiently by walking those uses
- // of oldVal that fall within this StmtBlock (instead of iterating through
- // all statements / through all operands of operations found).
- for (auto it = block->begin(); it != block->end(); it++) {
- it->replaceUses(oldVal, newVal);
- }
-}
-
-/// Clone the list of stmt's from 'block' and insert into the current
-/// position of the builder.
-// TODO(bondhugula,clattner): replace this with a parameterizable clone.
-void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
- // Pairs of <old op stmt result whose uses need to be replaced,
- // new result generated by the corresponding cloned op stmt>.
- SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
-
- // Iterator pointing to just before 'this' (i^th) unrolled iteration.
- StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
-
- for (auto &stmt : block.getStatements()) {
- auto *cloneStmt = builder->clone(stmt);
- // Whenever we have an op stmt, we'll have a new ML Value defined: replace
- // uses of the old result with this one.
- if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
- if (opStmt->getNumResults()) {
- auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
- for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
- // Store old/new result pairs.
- // TODO(bondhugula) *only* if needed later: storing of old/new
- // results can be avoided by cloning the statement list in the
- // reverse direction (and running the IR builder in the reverse
- // (iplist.insertAfter()). That way, a newly created result can be
- // immediately propagated to all its uses.
- oldNewResultPairs.push_back(std::make_pair(
- const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
- &cloneOpStmt->getStmtResult(i)));
- }
- }
- }
- }
-
- // Replace uses of old op results' with the new results.
- StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
- StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
-
- // Replace uses of old op results' with the newly created ones.
- for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
- replaceUses(startOfUnrolledBody, endOfUnrolledBody,
- oldNewResultPairs[i].first, oldNewResultPairs[i].second);
- }
-}
-
-/// Unroll this 'for stmt' / loop completely.
+/// Unroll this For loop completely.
void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
auto lb = forStmt->getLowerBound()->getValue();
auto ub = forStmt->getUpperBound()->getValue();
auto step = forStmt->getStep()->getValue();
// Builder to add constants need for the unrolled iterator.
- auto *mlFunc = forStmt->Statement::findFunction();
- MLFuncBuilder funcTopBuilder(mlFunc);
- funcTopBuilder.setInsertionPointAtStart(mlFunc);
+ auto *mlFunc = forStmt->findFunction();
+ MLFuncBuilder funcTopBuilder(&mlFunc->front());
- // Builder to insert the unrolled bodies.
- MLFuncBuilder builder(forStmt->getBlock());
- // Set insertion point to right after where the for stmt ends.
- builder.setInsertionPoint(forStmt->getBlock(),
- ++StmtBlock::iterator(forStmt));
+ // Builder to insert the unrolled bodies. We insert right after the
+ // ForStmt we're unrolling.
+ MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
// Unroll the contents of 'forStmt'.
for (int64_t i = lb; i <= ub; i += step) {
- MLValue *ivConst = nullptr;
+ DenseMap<const MLValue *, MLValue *> operandMapping;
+
+ // If the induction variable is used, create a constant for this unrolled
+ // value and add an operand mapping for it.
if (!forStmt->use_empty()) {
- auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
- ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+ auto *ivConst =
+ funcTopBuilder.create<ConstantAffineIntOp>(i)->getResult();
+ operandMapping[forStmt] = cast<MLValue>(ivConst);
}
- StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
- // Clone the loop body and insert it right after the loop - the latter will
- // be erased after all unrolling has been done.
- cloneStmtListFromBlock(&builder, *forStmt);
-
- // Replace unrolled loop IV with the unrolled constant.
- if (ivConst) {
- StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
- StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
- replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
+ // Clone the body of the loop.
+ for (auto &childStmt : *forStmt) {
+ (void)builder.clone(childStmt, operandMapping);
}
}
// Erase the original 'for' stmt from the block.