Add MLStmt cloning and IV replacement for loop unrolling; add a constant pool to
MLFunctions.

- MLStmt cloning and IV replacement
- While at it, fix innermostLoopGatherer to actually gather all the innermost
  loops (it was stopping its walk at the first innermost loop it found).
- Improve comments for MLFunction statement classes; fix their inheritance order.

- Fix StmtBlock destructor.

PiperOrigin-RevId: 207049173
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 160a463..fe110d2 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -19,6 +19,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
@@ -54,10 +55,13 @@
     typedef llvm::iplist<Statement> StmtListType;
     bool walkPostOrder(StmtListType::iterator Start,
                        StmtListType::iterator End) {
+      bool hasInnerLoops = false;
+      // We need to walk all elements since all innermost loops need to be
+      // gathered as opposed to determining whether this list has any inner
+      // loops or not.
       while (Start != End)
-        if (walkPostOrder(&(*Start++)))
-          return true;
-      return false;
+        hasInnerLoops |= walkPostOrder(&(*Start++));
+      return hasInnerLoops;
     }
 
     // FIXME: can't use base class method for this because that in turn would
@@ -73,12 +77,11 @@
     }
 
     bool walkIfStmtPostOrder(IfStmt *ifStmt) {
-      if (walkPostOrder(ifStmt->getThenClause()->begin(),
-                        ifStmt->getThenClause()->end()) ||
-          walkPostOrder(ifStmt->getElseClause()->begin(),
-                        ifStmt->getElseClause()->end()))
-        return true;
-      return false;
+      bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
+                                         ifStmt->getThenClause()->end());
+      hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
+                                     ifStmt->getElseClause()->end());
+      return hasInnerLoops;
     }
 
     bool walkOpStmt(OperationStmt *opStmt) { return false; }
@@ -93,17 +96,45 @@
     runOnForStmt(forStmt);
 }
 
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Replace an IV with a constant value.
+static void replaceIterator(Statement *stmt, const ForStmt &iv,
+                            MLValue *constVal) {
+  struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
+    // IV to be replaced.
+    const ForStmt *iv;
+    // Constant to be replaced with.
+    MLValue *constVal;
+
+    ReplaceIterator(const ForStmt &iv, MLValue *constVal)
+        : iv(&iv), constVal(constVal){};
+
+    void visitOperationStmt(OperationStmt *os) {
+      for (auto &operand : os->getStmtOperands()) {
+        if (operand.get() == static_cast<const MLValue *>(iv)) {
+          operand.set(constVal);
+        }
+      }
+    }
+  };
+
+  ReplaceIterator ri(iv, constVal);
+  ri.walk(stmt);
+}
+
+/// Unrolls this loop completely.
 void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
   auto lb = forStmt->getLowerBound()->getValue();
   auto ub = forStmt->getUpperBound()->getValue();
   auto step = forStmt->getStep()->getValue();
   auto trip_count = (ub - lb + 1) / step;
 
-  auto *block = forStmt->getBlock();
-  MLFuncBuilder builder(block);
+  auto *mlFunc = forStmt->Statement::findFunction();
+  MLFuncBuilder funcTopBuilder(mlFunc);
+  funcTopBuilder.setInsertionPointAtStart(mlFunc);
 
+  MLFuncBuilder builder(forStmt->getBlock());
   for (int i = 0; i < trip_count; i++) {
+    auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
     for (auto &stmt : forStmt->getStatements()) {
       switch (stmt.getKind()) {
       case Statement::Kind::For:
@@ -113,16 +144,13 @@
         llvm_unreachable("unrolling loops that have only operations");
         break;
       case Statement::Kind::Operation:
-        auto *op = cast<OperationStmt>(&stmt);
-        // TODO: clone operands and result types.
-        builder.createOperation(op->getName(), /*operands*/ {},
-                                /*resultTypes*/ {}, op->getAttrs());
-        // TODO: loop iterator parsing not yet implemented; replace loop
-        // iterator uses in unrolled body appropriately.
+        auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
+        // TODO(bondhugula): only generate constants when the IV actually
+        // appears in the body.
+        replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
         break;
       }
     }
   }
-
   forStmt->eraseFromBlock();
 }