MLStmt cloning and IV replacement for loop unrolling, add constant pool to
MLFunctions.

- MLStmt cloning and IV replacement
- While at this, fix the innermostLoopGatherer to actually gather all the
  innermost loops (it was stopping its walk at the first innermost loop it
  found)
- Improve comments for MLFunction statement classes, fix inheritance order.

- Fixed StmtBlock destructor.

PiperOrigin-RevId: 207049173
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 67cc7b8..4f7161b 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -18,6 +18,7 @@
 #ifndef MLIR_IR_BUILDERS_H
 #define MLIR_IR_BUILDERS_H
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Statements.h"
@@ -162,6 +163,12 @@
     return op;
   }
 
+  OperationInst *cloneOperation(const OperationInst &srcOpInst) {
+    auto *op = srcOpInst.clone();
+    block->getOperations().insert(insertPoint, op);
+    return op;
+  }
+
   // Terminators.
 
   ReturnInst *createReturnInst(ArrayRef<CFGValue *> operands) {
@@ -232,6 +239,12 @@
     insertPoint = block->end();
   }
 
+  /// Set the insertion point at the beginning of the specified block.
+  void setInsertionPointAtStart(StmtBlock *block) {
+    this->block = block;
+    insertPoint = block->begin();
+  }
+
   OperationStmt *createOperation(Identifier name, ArrayRef<MLValue *> operands,
                                  ArrayRef<Type *> resultTypes,
                                  ArrayRef<NamedAttribute> attributes) {
@@ -241,6 +254,12 @@
     return op;
   }
 
+  OperationStmt *cloneOperation(const OperationStmt &srcOpStmt) {
+    auto *op = srcOpStmt.clone();
+    block->getStatements().insert(insertPoint, op);
+    return op;
+  }
+
   // Creates for statement. When step is not specified, it is set to 1.
   ForStmt *createFor(AffineConstantExpr *lowerBound,
                      AffineConstantExpr *upperBound,
@@ -252,6 +271,15 @@
     return stmt;
   }
 
+  // TODO: subsume with a generate create<ConstantInt>() method.
+  OperationStmt *createConstInt32Op(int value) {
+    std::pair<Identifier, Attribute *> namedAttr(
+        Identifier::get("value", context), getIntegerAttr(value));
+    auto *mlconst = createOperation(Identifier::get("constant", context), {},
+                                    {getIntegerType(32)}, {namedAttr});
+    return mlconst;
+  }
+
 private:
   StmtBlock *block = nullptr;
   StmtBlock::iterator insertPoint;
diff --git a/include/mlir/IR/Instructions.h b/include/mlir/IR/Instructions.h
index a18a97f..1eac36c 100644
--- a/include/mlir/IR/Instructions.h
+++ b/include/mlir/IR/Instructions.h
@@ -158,6 +158,8 @@
   /// Return the context this operation is associated with.
   MLIRContext *getContext() const { return Instruction::getContext(); }
 
+  OperationInst *clone() const;
+
   //===--------------------------------------------------------------------===//
   // Operands
   //===--------------------------------------------------------------------===//
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index 65cad98..b5d2371 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -34,8 +34,8 @@
 
 /// Statement is a basic unit of execution within an ML function.
 /// Statements can be nested within for and if statements effectively
-/// forming a tree. Statements are organized into statement blocks
-/// represented by StmtBlock class.
+/// forming a tree. Child statements are organized into statement blocks
+/// represented by a 'StmtBlock' class.
 class Statement : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
 public:
   enum class Kind {
@@ -77,6 +77,7 @@
 
 private:
   Kind kind;
+  /// The statement block that containts this statement.
   StmtBlock *block = nullptr;
 
   // allow ilist_traits access to 'block' field.
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 306702f..ebecf58 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -47,6 +47,8 @@
   /// Return the context this operation is associated with.
   MLIRContext *getContext() const;
 
+  OperationStmt *clone() const;
+
   //===--------------------------------------------------------------------===//
   // Operands
   //===--------------------------------------------------------------------===//
@@ -190,7 +192,7 @@
 };
 
 /// For statement represents an affine loop nest.
-class ForStmt : public Statement, public StmtBlock, public MLValue {
+class ForStmt : public Statement, public MLValue, public StmtBlock {
 public:
   // TODO: lower and upper bounds should be affine maps with
   // dimension and symbol use lists.
@@ -199,6 +201,10 @@
                    MLIRContext *context);
 
   // Loop bounds and step are immortal objects and don't need to be deleted.
+  // With this dtor, ForStmt needs to inherit from MLValue before it does from
+  // StmtBlock since an MLValue can't be destroyed before the StmtBlock is ---
+  // the latter has uses for the induction variables, which is actually the
+  // MLValue here. FIXME: this dtor.
   ~ForStmt() {}
 
   AffineConstantExpr *getLowerBound() const { return lowerBound; }
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index 8b03bf1..77f490a 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -26,10 +26,12 @@
 #include "mlir/IR/Statement.h"
 
 namespace mlir {
-  class MLFunction;
-  class IfStmt;
+class MLFunction;
+class IfStmt;
 
-/// Statement block represents an ordered list of statements.
+/// Statement block represents an ordered list of statements, with the order
+/// being the contiguous lexical order in which the statements appear as
+/// children of a parent statement in the ML Function.
 class StmtBlock {
 public:
   enum class StmtBlockKind {
@@ -54,7 +56,7 @@
 
   /// This is the list of statements in the block.
   typedef llvm::iplist<Statement> StmtListType;
-  StmtListType       &getStatements() { return statements; }
+  StmtListType &getStatements() { return statements; }
   const StmtListType &getStatements() const { return statements; }
 
   // Iteration over the statements in the block.
@@ -82,14 +84,14 @@
   }
   Statement       &front() { return statements.front(); }
   const Statement &front() const {
-    return const_cast<StmtBlock*>(this)->front();
+    return const_cast<StmtBlock *>(this)->front();
   }
 
   void print(raw_ostream &os) const;
   void dump() const;
 
   /// getSublistAccess() - Returns pointer to member of statement list
-  static StmtListType StmtBlock::*getSublistAccess(Statement*) {
+  static StmtListType StmtBlock::*getSublistAccess(Statement *) {
     return &StmtBlock::statements;
   }
 
@@ -101,9 +103,8 @@
   /// This is the list of statements in the block.
   StmtListType statements;
 
-  StmtBlock(const StmtBlock&) = delete;
-  void operator=(const StmtBlock&) = delete;
-
+  StmtBlock(const StmtBlock &) = delete;
+  void operator=(const StmtBlock &) = delete;
 };
 
 } //end namespace mlir
diff --git a/lib/IR/Instructions.cpp b/lib/IR/Instructions.cpp
index a10cb3a..847907b 100644
--- a/lib/IR/Instructions.cpp
+++ b/lib/IR/Instructions.cpp
@@ -145,6 +145,21 @@
   return inst;
 }
 
+OperationInst *OperationInst::clone() const {
+  SmallVector<CFGValue *, 8> operands;
+  SmallVector<Type *, 8> resultTypes;
+
+  // TODO(clattner): switch to iterator logic.
+  // Put together the operands and results.
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
+    operands.push_back(getInstOperand(i).get());
+
+  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+    resultTypes.push_back(getInstResult(i).getType());
+
+  return create(getName(), operands, resultTypes, getAttrs(), getContext());
+}
+
 OperationInst::OperationInst(Identifier name, unsigned numOperands,
                              unsigned numResults,
                              ArrayRef<NamedAttribute> attributes,
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 5bc99e0..4e2a7b0 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -120,7 +120,6 @@
 
 /// Remove this statement (and its descendants) from its StmtBlock and delete
 /// all of them.
-/// TODO: erase all descendents for ForStmt/IfStmt.
 void Statement::eraseFromBlock() {
   assert(getBlock() && "Statement has no block");
   getBlock()->getStatements().erase(this);
@@ -155,6 +154,22 @@
   return stmt;
 }
 
+/// Clone an existing OperationStmt.
+OperationStmt *OperationStmt::clone() const {
+  SmallVector<MLValue *, 8> operands;
+  SmallVector<Type *, 8> resultTypes;
+
+  // TODO(clattner): switch this to iterator logic.
+  // Put together operands and results.
+  for (unsigned i = 0, e = getNumOperands(); i != e; ++i)
+    operands.push_back(getStmtOperand(i).get());
+
+  for (unsigned i = 0, e = getNumResults(); i != e; ++i)
+    resultTypes.push_back(getStmtResult(i).getType());
+
+  return create(getName(), operands, resultTypes, getAttrs(), getContext());
+}
+
 OperationStmt::OperationStmt(Identifier name, unsigned numOperands,
                              unsigned numResults,
                              ArrayRef<NamedAttribute> attributes,
@@ -205,9 +220,10 @@
 
 ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
                  AffineConstantExpr *step, MLIRContext *context)
-    : Statement(Kind::For), StmtBlock(StmtBlockKind::For),
+    : Statement(Kind::For),
       MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
-      lowerBound(lowerBound), upperBound(upperBound), step(step) {}
+      StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
+      upperBound(upperBound), step(step) {}
 
 //===----------------------------------------------------------------------===//
 // IfStmt
@@ -215,6 +231,6 @@
 
 IfStmt::~IfStmt() {
   delete thenClause;
-  if (elseClause != nullptr)
+  if (elseClause)
     delete elseClause;
 }
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 160a463..fe110d2 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -19,6 +19,7 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/Attributes.h"
 #include "mlir/IR/Builders.h"
 #include "mlir/IR/CFGFunction.h"
 #include "mlir/IR/MLFunction.h"
@@ -54,10 +55,13 @@
     typedef llvm::iplist<Statement> StmtListType;
     bool walkPostOrder(StmtListType::iterator Start,
                        StmtListType::iterator End) {
+      bool hasInnerLoops = false;
+      // We need to walk all elements since all innermost loops need to be
+      // gathered as opposed to determining whether this list has any inner
+      // loops or not.
       while (Start != End)
-        if (walkPostOrder(&(*Start++)))
-          return true;
-      return false;
+        hasInnerLoops |= walkPostOrder(&(*Start++));
+      return hasInnerLoops;
     }
 
     // FIXME: can't use base class method for this because that in turn would
@@ -73,12 +77,11 @@
     }
 
     bool walkIfStmtPostOrder(IfStmt *ifStmt) {
-      if (walkPostOrder(ifStmt->getThenClause()->begin(),
-                        ifStmt->getThenClause()->end()) ||
-          walkPostOrder(ifStmt->getElseClause()->begin(),
-                        ifStmt->getElseClause()->end()))
-        return true;
-      return false;
+      bool hasInnerLoops = walkPostOrder(ifStmt->getThenClause()->begin(),
+                                         ifStmt->getThenClause()->end());
+      hasInnerLoops |= walkPostOrder(ifStmt->getElseClause()->begin(),
+                                     ifStmt->getElseClause()->end());
+      return hasInnerLoops;
     }
 
     bool walkOpStmt(OperationStmt *opStmt) { return false; }
@@ -93,17 +96,45 @@
     runOnForStmt(forStmt);
 }
 
-/// Unrolls this loop completely. Returns true if the unrolling happens.
+/// Replace an IV with a constant value.
+static void replaceIterator(Statement *stmt, const ForStmt &iv,
+                            MLValue *constVal) {
+  struct ReplaceIterator : public StmtWalker<ReplaceIterator> {
+    // IV to be replaced.
+    const ForStmt *iv;
+    // Constant to be replaced with.
+    MLValue *constVal;
+
+    ReplaceIterator(const ForStmt &iv, MLValue *constVal)
+        : iv(&iv), constVal(constVal){};
+
+    void visitOperationStmt(OperationStmt *os) {
+      for (auto &operand : os->getStmtOperands()) {
+        if (operand.get() == static_cast<const MLValue *>(iv)) {
+          operand.set(constVal);
+        }
+      }
+    }
+  };
+
+  ReplaceIterator ri(iv, constVal);
+  ri.walk(stmt);
+}
+
+/// Unrolls this loop completely.
 void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
   auto lb = forStmt->getLowerBound()->getValue();
   auto ub = forStmt->getUpperBound()->getValue();
   auto step = forStmt->getStep()->getValue();
   auto trip_count = (ub - lb + 1) / step;
 
-  auto *block = forStmt->getBlock();
-  MLFuncBuilder builder(block);
+  auto *mlFunc = forStmt->Statement::findFunction();
+  MLFuncBuilder funcTopBuilder(mlFunc);
+  funcTopBuilder.setInsertionPointAtStart(mlFunc);
 
+  MLFuncBuilder builder(forStmt->getBlock());
   for (int i = 0; i < trip_count; i++) {
+    auto *ivUnrolledVal = funcTopBuilder.createConstInt32Op(i)->getResult(0);
     for (auto &stmt : forStmt->getStatements()) {
       switch (stmt.getKind()) {
       case Statement::Kind::For:
@@ -113,16 +144,13 @@
         llvm_unreachable("unrolling loops that have only operations");
         break;
       case Statement::Kind::Operation:
-        auto *op = cast<OperationStmt>(&stmt);
-        // TODO: clone operands and result types.
-        builder.createOperation(op->getName(), /*operands*/ {},
-                                /*resultTypes*/ {}, op->getAttrs());
-        // TODO: loop iterator parsing not yet implemented; replace loop
-        // iterator uses in unrolled body appropriately.
+        auto *cloneOp = builder.cloneOperation(*cast<OperationStmt>(&stmt));
+        // TODO(bondhugula): only generate constants when the IV actually
+        // appears in the body.
+        replaceIterator(cloneOp, *forStmt, ivUnrolledVal);
         break;
       }
     }
   }
-
   forStmt->eraseFromBlock();
 }
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 2d42e4a..6a7d9cf 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,16 +1,64 @@
 // RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
 
-// CHECK-LABEL: mlfunc @loops() {
-mlfunc @loops() {
-  // CHECK: for %i0 = 1 to 100 step 2 {
+// CHECK-LABEL: mlfunc @loops1() {
+mlfunc @loops1() {
+  // CHECK: %c0_i32 = constant 0 : i32
+  // CHECK-NEXT: %c1_i32 = constant 1 : i32
+  // CHECK-NEXT: %c2_i32 = constant 2 : i32
+  // CHECK-NEXT: %c3_i32 = constant 3 : i32
+  // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
-    // CHECK: "custom"(){value: 1} : () -> ()
-    // CHECK-NEXT: "custom"(){value: 1} : () -> ()
-    // CHECK-NEXT: "custom"(){value: 1} : () -> ()
-    // CHECK-NEXT: "custom"(){value: 1} : () -> ()
+    // CHECK: %c1_i32_0 = constant 1 : i32
+    // CHECK-NEXT: %c1_i32_1 = constant 1 : i32
+    // CHECK-NEXT: %c1_i32_2 = constant 1 : i32
+    // CHECK-NEXT: %c1_i32_3 = constant 1 : i32
     for %j = 1 to 4 {
-      "custom"(){value: 1} : () -> f32
+      %x = constant 1 : i32
     }
   }       // CHECK:  }
   return  // CHECK:  return
 }         // CHECK }
+
+// CHECK-LABEL: mlfunc @loops2() {
+mlfunc @loops2() {
+  // CHECK: %c0_i32 = constant 0 : i32
+  // CHECK-NEXT: %c1_i32 = constant 1 : i32
+  // CHECK-NEXT: %c2_i32 = constant 2 : i32
+  // CHECK-NEXT: %c3_i32 = constant 3 : i32
+  // CHECK-NEXT: %c0_i32_0 = constant 0 : i32
+  // CHECK-NEXT: %c1_i32_1 = constant 1 : i32
+  // CHECK-NEXT: %c2_i32_2 = constant 2 : i32
+  // CHECK-NEXT: %c3_i32_3 = constant 3 : i32
+  // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
+  for %i = 1 to 100 step 2 {
+     // CHECK: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32_0)
+    // CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32_1)
+    // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_i32_2)
+    // CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_i32_3)
+    for %j = 1 to 4 {
+      %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
+        (affineint) -> (affineint)
+    }
+  }    // CHECK:  }
+
+  // CHECK: %c99 = constant 99 : affineint
+  %k = "constant"(){value: 99} : () -> affineint
+  // CHECK: for %i1 = 1 to 100 step 2 {
+  for %m = 1 to 100 step 2 {
+    // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
+    // CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c0_i32)[%c99]
+    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
+    // CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1_i32)[%c99]
+    // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
+    // CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2_i32)[%c99]
+    // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
+    // CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3_i32)[%c99]
+    for %n = 1 to 4 {
+      %y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
+        (affineint) -> (affineint)
+      %z = "affine_apply" (%n, %k) { map: (d0) [s0] -> (d0 + s0 + 1) } :
+        (affineint, affineint) -> (affineint)
+    }     // CHECK }
+  }       // CHECK }
+  return  // CHECK:  return
+}         // CHECK }