Loop unrolling update.

- deal with non-operation statements (if/for statements) in loops being
  unrolled, so that unrolling of non-innermost loops works.
- update uses in unrolled bodies to use results of new operations that may be
  introduced in the unrolled bodies.

Unrolling now works for all kinds of loop nests - perfect nests, imperfect
nests, loops at any depth, and with any kind of operation in the body. (IfStmt
support not done, hence untested there).

Added missing dump/print method for StmtBlock.

TODO: add test case for outer loop unrolling.
PiperOrigin-RevId: 207314286
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index fe110d2..27bb43f 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -26,6 +26,7 @@
 #include "mlir/IR/Module.h"
 #include "mlir/IR/OperationSet.h"
 #include "mlir/IR/Pass.h"
+#include "mlir/IR/StandardOps.h"
 #include "mlir/IR/Statements.h"
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/Transforms/Passes.h"
@@ -96,61 +97,94 @@
     runOnForStmt(forStmt);
 }
 
-/// Replace an IV with a constant value.
-static void replaceIterator(Statement *stmt, const ForStmt &iv,
-                            MLValue *constVal) {
-  struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
-    // IV to be replaced.
-    const ForStmt *iv;
-    // Constant to be replaced with.
-    MLValue *constVal;
+/// Replace all uses of 'oldVal' with 'newVal' in 'stmt'
+static void replaceAllStmtUses(Statement *stmt, MLValue *oldVal,
+                               MLValue *newVal) {
+  struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
+    // Value to be replaced.
+    MLValue *oldVal;
+    // Value to be replaced with.
+    MLValue *newVal;
 
-    ReplaceIterator(const ForStmt &iv, MLValue *constVal)
-        : iv(&iv), constVal(constVal){};
+    ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
+        : oldVal(oldVal), newVal(newVal){};
 
     void visitOperationStmt(OperationStmt *os) {
       for (auto &operand : os->getStmtOperands()) {
-        if (operand.get() == static_cast<const MLValue *>(iv)) {
-          operand.set(constVal);
-        }
+        if (operand.get() == oldVal)
+          operand.set(newVal);
       }
     }
   };
 
-  ReplaceIterator ri(iv, constVal);
+  ReplaceUseWalker ri(oldVal, newVal);
   ri.walk(stmt);
 }
 
-/// Unrolls this loop completely.
+/// Unroll this 'for stmt' / loop completely.
 void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
   auto lb = forStmt->getLowerBound()->getValue();
   auto ub = forStmt->getUpperBound()->getValue();
   auto step = forStmt->getStep()->getValue();
-  auto trip_count = (ub - lb + 1) / step;
 
+  // Builder to add constants needed for the unrolled iterator.
   auto *mlFunc = forStmt->Statement::findFunction();
   MLFuncBuilder funcTopBuilder(mlFunc);
   funcTopBuilder.setInsertionPointAtStart(mlFunc);
 
+  // Builder to insert the unrolled bodies.
   MLFuncBuilder builder(forStmt->getBlock());
-  for (int i = 0; i < trip_count; i++) {
-    auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
-    for (auto &stmt : forStmt->getStatements()) {
-      switch (stmt.getKind()) {
-      case Statement::Kind::For:
-        llvm_unreachable("unrolling loops that have only operations");
-        break;
-      case Statement::Kind::If:
-        llvm_unreachable("unrolling loops that have only operations");
-        break;
-      case Statement::Kind::Operation:
-        auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
-        // TODO(bondhugula): only generate constants when the IV actually
-        // appears in the body.
-        replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
-        break;
+  // Set insertion point to right after where the for stmt ends.
+  builder.setInsertionPoint(forStmt->getBlock(),
+                            ++StmtBlock::iterator(forStmt));
+
+  // Unroll the contents of 'forStmt'.
+  for (int i = lb; i <= ub; i += step) {
+    // TODO(bondhugula): generate constants only when IV actually appears.
+    auto constOp = funcTopBuilder.create<ConstantIntOp>(i, 32);
+    auto *ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+
+    // Iterator pointing to just before 'this' (i^th) unrolled iteration.
+    StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
+
+    // Pairs of <old op stmt result whose uses need to be replaced,
+    // new result generated by the corresponding cloned op stmt>.
+    SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+
+    for (auto &loopBodyStmt : forStmt->getStatements()) {
+      auto *cloneStmt = builder.clone(loopBodyStmt);
+      // Replace all uses of the IV in the clone with constant iteration value.
+      replaceAllStmtUses(cloneStmt, forStmt, ivConst);
+
+      // Whenever we have an op stmt, we'll have a new ML Value defined: replace
+      // uses of the old result with this one.
+      if (auto *opStmt = dyn_cast<OperationStmt>(&loopBodyStmt)) {
+        if (opStmt->getNumResults()) {
+          auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+          for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
+            // Store old/new result pairs.
+            // TODO *only* if needed later: storing of old/new results can be
+            // avoided by cloning the statement list in the reverse direction
+            // (and running the IR builder in reverse, using
+            // iplist.insertAfter()). That way, a newly created result can be
+            // immediately propagated to all its uses, which would already have
+            // been cloned/inserted.
+            oldNewResultPairs.push_back(std::make_pair(
+                &opStmt->getStmtResult(i), &cloneOpStmt->getStmtResult(i)));
+          }
+        }
+      }
+    }
+    // Replace uses of the old op results with the new results in the
+    // just-unrolled body.
+    StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
+    for (auto it = ++beforeUnrolledBody; it != endOfUnrolledBody; it++) {
+      for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
+        replaceAllStmtUses(&(*it), oldNewResultPairs[i].first,
+                           oldNewResultPairs[i].second);
       }
     }
   }
+  // Erase the original for stmt from the block.
   forStmt->eraseFromBlock();
 }