Sketch out loop unrolling transformation.

- Implement a full loop unroll for innermost loops.
- Use it to implement a pass that unroll all the innermost loops of all
  mlfunction's in a module. ForStmt's parsed currently have constant trip
  counts (and constant loop bounds).
- Implement StmtVisitor based (Visitor pattern)

Loop IVs aren't currently parsed and represented as SSA values. Replacing uses
of loop IVs in unrolled bodies is thus a TODO. Class comments are sparse at some places - will add them after one round of comments.

A cmd-line flag triggers this for now.

Original:

mlfunc @loops() {
  for x = 1 to 100 step 2 {
    for x = 1 to 4 {
      "Const"(){value: 1} : () -> ()
    }
  }
  return
}

After unrolling:

mlfunc @loops() {
  for x = 1 to 100 step 2 {
    "Const"(){value: 1} : () -> ()
    "Const"(){value: 1} : () -> ()
    "Const"(){value: 1} : () -> ()
    "Const"(){value: 1} : () -> ()
  }
  return
}

PiperOrigin-RevId: 205933235
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index 5fae348..0df9ea1 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -27,15 +27,15 @@
 #include "llvm/ADT/ilist_node.h"
 
 namespace mlir {
-  class MLFunction;
-  class StmtBlock;
+class MLFunction;
+class StmtBlock;
+class ForStmt;
 
 /// Statement is a basic unit of execution within an ML function.
 /// Statements can be nested within for and if statements effectively
 /// forming a tree. Statements are organized into statement blocks
 /// represented by StmtBlock class.
-class Statement
-  : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
+class Statement : public llvm::ilist_node_with_parent<Statement, StmtBlock> {
 public:
   enum class Kind {
     Operation,
@@ -45,12 +45,19 @@
 
   Kind getKind() const { return kind; }
 
+  /// Remove this statement from its block and delete it.
+  void eraseFromBlock();
+
   /// Returns the statement block that contains this statement.
   StmtBlock *getBlock() const { return block; }
 
   /// Returns the function that this statement is part of.
   MLFunction *getFunction() const;
 
+  /// Returns the number of nested loops starting from (i.e., inclusive of) this
+  /// statement.
+  unsigned getNumNestedLoops() const;
+
   /// Destroys this statement and its subclass data.
   void destroy();
 
@@ -63,8 +70,6 @@
   // does not have a virtual destructor.
   ~Statement();
 
-  /// Remove this statement from its block and delete it.
-  void eraseFromBlock();
 private:
   Kind kind;
   StmtBlock *block = nullptr;
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 20099e4..4a2dd33 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -75,6 +75,9 @@
     return block->getStmtBlockKind() == StmtBlockKind::For;
   }
 
+  /// Returns true if there are no more for stmt's nested under this for stmt.
+  bool isInnermost() const { return 1 == getNumNestedLoops(); }
+
 private:
   AffineConstantExpr *lowerBound;
   AffineConstantExpr *upperBound;
diff --git a/include/mlir/IR/StmtVisitor.h b/include/mlir/IR/StmtVisitor.h
new file mode 100644
index 0000000..4cb3a5c
--- /dev/null
+++ b/include/mlir/IR/StmtVisitor.h
@@ -0,0 +1,140 @@
+//===- StmtVisitor.h - MLIR Instruction Visitor Class -----------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines the Statement visitor class.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_IR_STMTVISITOR_H
+#define MLIR_IR_STMTVISITOR_H
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Statements.h"
+
+namespace mlir {
+
+/// Base class for statement visitors.
+///
+/// Statement visitors are used when you want to perform different actions
+/// for different kinds of statements without having to use lots of casts
+/// and a big switch statement.
+///
+/// To define your own visitor, inherit from this class, specifying your
+/// new type for the 'SubClass' template parameter, and "override" visitXXX
+/// functions in your class. This class is defined in terms of statically
+/// resolved overloading, not virtual functions.
+///
+/// For example, here is a visitor that counts the number of for loops in an
+/// MLFunction.
+///
+///  /// Declare the class.  Note that we derive from StmtVisitor instantiated
+///  /// with _our new subclasses_ type.
+///  struct LoopCounter : public StmtVisitor<LoopCounter> {
+///    unsigned numLoops;
+///    LoopCounter() : numLoops(0) {}
+///    void visitForStmt(ForStmt &fs) { ++numLoops; }
+///  };
+///
+///  And this class would be used like this:
+///    LoopCounter lc;
+///    lc.visit(function);
+///    numLoops = lc.numLoops;
+///
+/// There  are 'visit' methods for Operation, ForStmt, IfStmt, and
+/// MLFunction, which recursively process all contained statements.
+///
+/// Note that if you don't implement visitXXX for some statement type,
+/// the visitXXX method for Statement superclass will be invoked.
+///
+/// The optional second template argument specifies the type that statement
+/// visitation functions should return. If you specify this, you *MUST* provide
+/// an implementation of visitStatement.
+///
+/// Note that this class is specifically designed as a template to avoid
+/// virtual function call overhead.  Defining and using a StmtVisitor is just
+/// as efficient as having your own switch statement over the statement
+/// opcode.
+template <typename SubClass> class StmtVisitor {
+  //===--------------------------------------------------------------------===//
+  // Interface code - This is the public interface of the StmtVisitor that you
+  // use to visit statements...
+
+public:
+  // Generic visit method - allow visitation to all statements in a range.
+  template <class Iterator> void visit(Iterator Start, Iterator End) {
+    while (Start != End) {
+      static_cast<SubClass *>(this)->visit(&(*Start++));
+    }
+  }
+
+  // Define visitors for MLFunction and all MLFunction statement kinds.
+  void visit(MLFunction *f) {
+    static_cast<SubClass *>(this)->visitMLFunction(f);
+    visit(f->begin(), f->end());
+  }
+
+  void visit(OperationStmt *opStmt) {
+    static_cast<SubClass *>(this)->visitOperationStmt(opStmt);
+  }
+
+  void visit(ForStmt *forStmt) {
+    static_cast<SubClass *>(this)->visitForStmt(forStmt);
+    visit(forStmt->begin(), forStmt->end());
+  }
+
+  void visit(IfStmt *ifStmt) {
+    static_cast<SubClass *>(this)->visitIfStmt(ifStmt);
+    visit(ifStmt->getThenClause()->begin(), ifStmt->getThenClause()->end());
+    visit(ifStmt->getElseClause()->begin(), ifStmt->getElseClause()->end());
+  }
+
+  // Function to visit a statement.
+  void visit(Statement *s) {
+    static_assert(std::is_base_of<StmtVisitor, SubClass>::value,
+                  "Must pass the derived type to this template!");
+
+    switch (s->getKind()) {
+    default:
+      llvm_unreachable("Unknown statement type encountered!");
+    case Statement::Kind::For:
+      return static_cast<SubClass *>(this)->visit(cast<ForStmt>(s));
+    case Statement::Kind::If:
+      return static_cast<SubClass *>(this)->visit(cast<IfStmt>(s));
+    case Statement::Kind::Operation:
+      return static_cast<SubClass *>(this)->visit(cast<OperationStmt>(s));
+    }
+  }
+
+  //===--------------------------------------------------------------------===//
+  // Visitation functions... these functions provide default fallbacks in case
+  // the user does not specify what to do for a particular statement type.
+  // The default behavior is to generalize the statement type to its subtype
+  // and try visiting the subtype.  All of this should be inlined perfectly,
+  // because there are no virtual functions to get in the way.
+  //
+
+  // When visiting a for stmt, if stmt, or an operation stmt directly, these
+  // methods get called to indicate when transitioning into a new unit.
+  void visitForStmt(ForStmt *forStmt) {}
+  void visitIfStmt(IfStmt *ifStmt) {}
+  void visitOperationStmt(OperationStmt *opStmt) {}
+  void visitMLFunction(MLFunction *f) {}
+};
+
+} // end namespace mlir
+
+#endif // MLIR_IR_STMTVISITOR_H
diff --git a/include/mlir/Pass.h b/include/mlir/Pass.h
new file mode 100644
index 0000000..f93061b
--- /dev/null
+++ b/include/mlir/Pass.h
@@ -0,0 +1,47 @@
+//===- mlir/Pass.h - Base class for passes ----------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file defines a base class that indicates that a specified class is a
+// transformation pass implementation.
+//
+//===----------------------------------------------------------------------===//
+#ifndef MLIR_PASS_H
+#define MLIR_PASS_H
+
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
+
+namespace mlir {
+
+class Pass {
+protected:
+  virtual ~Pass() = default;
+};
+
+class FunctionPass : public Pass {};
+
+class CFGFunctionPass : public FunctionPass {};
+
+class MLFunctionPass : public FunctionPass {
+public:
+  virtual bool runOnMLFunction(MLFunction *f) = 0;
+  virtual bool runOnModule(Module *m);
+};
+
+} // end namespace mlir
+
+#endif // MLIR_PASS_H
diff --git a/include/mlir/Transforms/Loop.h b/include/mlir/Transforms/Loop.h
new file mode 100644
index 0000000..cff6609
--- /dev/null
+++ b/include/mlir/Transforms/Loop.h
@@ -0,0 +1,35 @@
+//===- Loop.h - Loop Transformations ----------------------------*- C++ -*-===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This header file defines prototypes that expose passes in the loop
+// transformation library.
+//
+//===----------------------------------------------------------------------===//
+
+#ifndef MLIR_TRANSFORMS_LOOP_H
+#define MLIR_TRANSFORMS_LOOP_H
+
+namespace mlir {
+
+class MLFunctionPass;
+
+/// A loop unrolling pass.
+MLFunctionPass *createLoopUnrollPass();
+
+} // end namespace mlir
+
+#endif // MLIR_TRANSFORMS_LOOP_H
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index b469482..6b272e0 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -17,6 +17,7 @@
 
 #include "mlir/IR/MLFunction.h"
 #include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
 using namespace mlir;
 
 //===----------------------------------------------------------------------===//
@@ -48,6 +49,18 @@
   return this->getBlock()->getFunction();
 }
 
+unsigned Statement::getNumNestedLoops() const {
+  struct NestedLoopCounter : public StmtVisitor<NestedLoopCounter> {
+    unsigned numNestedLoops;
+    NestedLoopCounter() : numNestedLoops(0) {}
+    void visitForStmt(const ForStmt *fs) { numNestedLoops++; }
+  };
+
+  NestedLoopCounter nlc;
+  nlc.visit(const_cast<Statement *>(this));
+  return nlc.numNestedLoops;
+}
+
 //===----------------------------------------------------------------------===//
 // ilist_traits for Statement
 //===----------------------------------------------------------------------===//
@@ -91,7 +104,9 @@
     first->block = curParent;
 }
 
-/// Remove this statement from its StmtBlock and delete it.
+/// Remove this statement (and its descendants) from its StmtBlock and delete
+/// all of them.
+/// TODO: erase all descendents for ForStmt/IfStmt.
 void Statement::eraseFromBlock() {
   assert(getBlock() && "Statement has no block");
   getBlock()->getStatements().erase(this);
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
new file mode 100644
index 0000000..af0c7ca
--- /dev/null
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -0,0 +1,109 @@
+//===- Unroll.cpp - Code to perform loop unrolling    ---------------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements loop unrolling.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/CFGFunction.h"
+#include "mlir/IR/MLFunction.h"
+#include "mlir/IR/Module.h"
+#include "mlir/IR/OperationSet.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/Pass.h"
+#include "mlir/Transforms/Loop.h"
+
+using namespace mlir;
+
+namespace {
+struct LoopUnroll : public MLFunctionPass {
+  bool runOnMLFunction(MLFunction *f);
+  bool runOnForStmt(ForStmt *forStmt);
+  bool runLoopUnroll(MLFunction *f);
+};
+} // namespace
+
+MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
+
+/// Unrolls all the innermost loops of this Module.
+bool MLFunctionPass::runOnModule(Module *m) {
+  bool changed = false;
+  for (auto fn : m->functionList) {
+    if (auto *mlFunc = dyn_cast<MLFunction>(fn))
+      changed |= runOnMLFunction(mlFunc);
+  }
+  return changed;
+}
+
+/// Unrolls all the innermost loops of this MLFunction.
+bool LoopUnroll::runOnMLFunction(MLFunction *f) {
+  // Gathers all innermost loops. TODO: change the visitor to post order to make
+  // this linear time / single traversal.
+  struct InnermostLoopGatherer : public StmtVisitor<InnermostLoopGatherer> {
+    std::vector<ForStmt *> loops;
+    InnermostLoopGatherer() {}
+    void visitForStmt(ForStmt *fs) {
+      if (fs->isInnermost())
+        loops.push_back(fs);
+    }
+  };
+
+  InnermostLoopGatherer ilg;
+  ilg.visit(f);
+  auto &loops = ilg.loops;
+  bool changed = false;
+  for (auto *forStmt : loops)
+    changed |= runOnForStmt(forStmt);
+  return changed;
+}
+
+/// Unrolls this loop completely. Returns true if the unrolling happens.
+bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+  auto lb = forStmt->getLowerBound()->getValue();
+  auto ub = forStmt->getUpperBound()->getValue();
+  auto step = forStmt->getStep()->getValue();
+  auto trip_count = (ub - lb + 1) / step;
+
+  auto *block = forStmt->getBlock();
+
+  MLFuncBuilder builder(forStmt->Statement::getFunction());
+  builder.setInsertionPoint(block);
+
+  for (int i = 0; i < trip_count; i++) {
+    for (auto &stmt : forStmt->getStatements()) {
+      switch (stmt.getKind()) {
+      case Statement::Kind::For:
+        llvm_unreachable("unrolling loops that have only operations");
+        break;
+      case Statement::Kind::If:
+        llvm_unreachable("unrolling loops that have only operations");
+        break;
+      case Statement::Kind::Operation:
+        auto *op = cast<OperationStmt>(&stmt);
+        builder.createOperation(op->getName(), op->getAttrs());
+        // TODO: loop iterator parsing not yet implemented; replace loop
+        // iterator uses in unrolled body appropriately.
+        break;
+      }
+    }
+  }
+
+  forStmt->eraseFromBlock();
+  return true;
+}
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
new file mode 100644
index 0000000..3c42142
--- /dev/null
+++ b/test/Transforms/unroll.mlir
@@ -0,0 +1,16 @@
+// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
+
+// CHECK-LABEL: mlfunc @loops() {
+mlfunc @loops() {
+  // CHECK: for x = 1 to 100 step 2 {
+  for %i = 1 to 100 step 2 {
+    // CHECK: "custom"(){value: 1} : () -> ()
+    // CHECK-NEXT: "custom"(){value: 1} : () -> ()
+    // CHECK-NEXT: "custom"(){value: 1} : () -> ()
+    // CHECK-NEXT: "custom"(){value: 1} : () -> ()
+    for %j = 1 to 4 {
+      "custom"(){value: 1} : () -> f32
+    }
+  }       // CHECK:  }
+  return  // CHECK:  return
+}         // CHECK }
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index f1ccddd..4f62b28 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -21,10 +21,13 @@
 //
 //===----------------------------------------------------------------------===//
 
+#include "mlir/IR/MLFunction.h"
 #include "mlir/IR/MLIRContext.h"
 #include "mlir/IR/Module.h"
 #include "mlir/Parser.h"
+#include "mlir/Pass.h"
 #include "mlir/Transforms/ConvertToCFG.h"
+#include "mlir/Transforms/Loop.h"
 #include "llvm/Support/CommandLine.h"
 #include "llvm/Support/FileUtilities.h"
 #include "llvm/Support/InitLLVM.h"
@@ -49,6 +52,10 @@
     "convert-to-cfg",
     cl::desc("Convert all ML functions in the module to CFG ones"));
 
+static cl::opt<bool> unrollInnermostLoops("unroll-innermost-loops",
+                                          cl::desc("Unroll innermost loops"),
+                                          cl::init(false));
+
 enum OptResult { OptSuccess, OptFailure };
 
 /// Open the specified output file and return it, exiting if there is any I/O or
@@ -83,6 +90,11 @@
   if (convertToCFGOpt)
     convertToCFG(module.get());
 
+  if (unrollInnermostLoops) {
+    MLFunctionPass *loopUnroll = createLoopUnrollPass();
+    loopUnroll->runOnModule(module.get());
+  }
+
   // Print the output.
   auto output = getOutputStream();
   module->print(output->os());