Rework the cloning infrastructure for statements so that it takes and updates
an operand mapping, which simplifies it a bit.  Implement cloning for IfStmt,
and rename getThenClause()/getElseClause() to getThen()/getElse(), which are
unambiguous and less repetitive at call sites.

PiperOrigin-RevId: 207915990
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index eea3bf7..5c2dd86 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -30,6 +30,7 @@
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/Transforms/Passes.h"
+#include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
@@ -77,14 +78,15 @@
       bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
       if (!hasInnerLoops)
         loops.push_back(forStmt);
+
       return true;
     }
 
     bool walkIfStmtPostOrder(IfStmt *ifStmt) {
-      bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
-                                         ifStmt->getThenClause()->end());
-      hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
-                                     ifStmt->getElseClause()->end());
+      bool hasInnerLoops =
+          walkPostOrder(ifStmt->getThen()->begin(), ifStmt->getThen()->end());
+      hasInnerLoops |=
+          walkPostOrder(ifStmt->getElse()->begin(), ifStmt->getElse()->end());
       return hasInnerLoops;
     }
 
@@ -130,106 +132,35 @@
     runOnForStmt(forStmt);
 }
 
-/// Replace all uses of oldVal with newVal from begin to end.
-static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
-                        MLValue *oldVal, MLValue *newVal) {
-  // TODO(bondhugula,clattner): do this more efficiently by walking those uses
-  // of oldVal that fall within this list of statements (instead of iterating
-  // through all statements / through all operands of operations found).
-  for (auto it = begin; it != end; it++) {
-    it->replaceUses(oldVal, newVal);
-  }
-}
-
-/// Replace all uses of oldVal with newVal.
-void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
-  // TODO(bondhugula,clattner): do this more efficiently by walking those uses
-  // of oldVal that fall within this StmtBlock (instead of iterating through
-  // all statements / through all operands of operations found).
-  for (auto it = block->begin(); it != block->end(); it++) {
-    it->replaceUses(oldVal, newVal);
-  }
-}
-
-/// Clone the list of stmt's from 'block' and insert into the current
-/// position of the builder.
-// TODO(bondhugula,clattner): replace this with a parameterizable clone.
-void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
-  // Pairs of <old op stmt result whose uses need to be replaced,
-  // new result generated by the corresponding cloned op stmt>.
-  SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
-
-  // Iterator pointing to just before 'this' (i^th) unrolled iteration.
-  StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
-
-  for (auto &stmt : block.getStatements()) {
-    auto *cloneStmt = builder->clone(stmt);
-    // Whenever we have an op stmt, we'll have a new ML Value defined: replace
-    // uses of the old result with this one.
-    if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
-      if (opStmt->getNumResults()) {
-        auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
-        for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
-          // Store old/new result pairs.
-          // TODO(bondhugula) *only* if needed later: storing of old/new
-          // results can be avoided by cloning the statement list in the
-          // reverse direction (and running the IR builder in the reverse
-          // (iplist.insertAfter()). That way, a newly created result can be
-          // immediately propagated to all its uses.
-          oldNewResultPairs.push_back(std::make_pair(
-              const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
-              &cloneOpStmt->getStmtResult(i)));
-        }
-      }
-    }
-  }
-
-  // Replace uses of old op results' with the new results.
-  StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
-  StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
-
-  // Replace uses of old op results' with the newly created ones.
-  for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
-    replaceUses(startOfUnrolledBody, endOfUnrolledBody,
-                oldNewResultPairs[i].first, oldNewResultPairs[i].second);
-  }
-}
-
-/// Unroll this 'for stmt' / loop completely.
+/// Unroll this For loop completely.
 void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
   auto lb = forStmt->getLowerBound()->getValue();
   auto ub = forStmt->getUpperBound()->getValue();
   auto step = forStmt->getStep()->getValue();
 
   // Builder to add constants need for the unrolled iterator.
-  auto *mlFunc = forStmt->Statement::findFunction();
-  MLFuncBuilder funcTopBuilder(mlFunc);
-  funcTopBuilder.setInsertionPointAtStart(mlFunc);
+  auto *mlFunc = forStmt->findFunction();
+  MLFuncBuilder funcTopBuilder(&mlFunc->front());
 
-  // Builder to insert the unrolled bodies.
-  MLFuncBuilder builder(forStmt->getBlock());
-  // Set insertion point to right after where the for stmt ends.
-  builder.setInsertionPoint(forStmt->getBlock(),
-                            ++StmtBlock::iterator(forStmt));
+  // Builder to insert the unrolled bodies.  We insert right after the
+  // ForStmt we're unrolling.
+  MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
 
   // Unroll the contents of 'forStmt'.
   for (int64_t i = lb; i <= ub; i += step) {
-    MLValue *ivConst = nullptr;
+    DenseMap<const MLValue *, MLValue *> operandMapping;
+
+    // If the induction variable is used, create a constant for this unrolled
+    // value and add an operand mapping for it.
     if (!forStmt->use_empty()) {
-      auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
-      ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+      auto *ivConst =
+          funcTopBuilder.create<ConstantAffineIntOp>(i)->getResult();
+      operandMapping[forStmt] = cast<MLValue>(ivConst);
     }
-    StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
 
-    // Clone the loop body and insert it right after the loop - the latter will
-    // be erased after all unrolling has been done.
-    cloneStmtListFromBlock(&builder, *forStmt);
-
-    // Replace unrolled loop IV with the unrolled constant.
-    if (ivConst) {
-      StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
-      StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
-      replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
+    // Clone the body of the loop.
+    for (auto &childStmt : *forStmt) {
+      (void)builder.clone(childStmt, operandMapping);
     }
   }
   // Erase the original 'for' stmt from the block.