Stmt visitors and walkers.

- Update InnermostLoopGatherer to use a post-order traversal (linear time,
  single traversal).
- Drop getNumNestedLoops().
- Update isInnermost() to use the StmtWalker.

When return values are used in conjunction with walkers, the StmtWalker CRTP
pattern doesn't appear to be of much use: a derived walker ends up overriding
nearly all of the base methods, which is what InnermostLoopGatherer currently
does. Please see the FIXME/ENLIGHTENME comments. TODO: resolve this in the
review discussion for this CL.

Note:
- Comments on the visitor/walker base classes are out of date; they will be
  updated once this CL is finalized.

PiperOrigin-RevId: 206340901
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 70367e9..9592ef7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -28,6 +28,7 @@
 #include "mlir/IR/StmtVisitor.h"
 #include "mlir/Pass.h"
 #include "mlir/Transforms/Loop.h"
+#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
 
@@ -53,19 +54,50 @@
 
 /// Unrolls all the innermost loops of this MLFunction.
 bool LoopUnroll::runOnMLFunction(MLFunction *f) {
-  // Gathers all innermost loops. TODO: change the visitor to post order to make
-  // this linear time / single traversal.
-  struct InnermostLoopGatherer : public StmtVisitor<InnermostLoopGatherer> {
+  // Gathers all innermost loops through a single post-order walk.
+  // TODO: figure out the right reusable template here to better refactor code.
+  class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
+  public:
+    // Store innermost loops as we walk.
     std::vector<ForStmt *> loops;
-    InnermostLoopGatherer() {}
-    void visitForStmt(ForStmt *fs) {
-      if (fs->isInnermost())
-        loops.push_back(fs);
+
+    // Walk all of [Start, End) (no early exit); true if any is/has a loop.
+    typedef llvm::iplist<Statement> StmtListType;
+    bool walk(StmtListType::iterator Start, StmtListType::iterator End) {
+      bool hasInnerLoops = false;
+      while (Start != End)
+        hasInnerLoops |= walk(&(*Start++));
+      return hasInnerLoops;
     }
+
+    // FIXME: can't use base class method for this because that in turn would
+    // need to use the derived class method above. CRTP doesn't allow it, and
+    // the compiler error resulting from it is also very misleading!
+    void walkMLFunction(MLFunction *f) { walk(f->begin(), f->end()); }
+
+    bool walkForStmt(ForStmt *forStmt) {
+      bool hasInnerLoops = walk(forStmt->begin(), forStmt->end());
+      if (!hasInnerLoops)
+        loops.push_back(forStmt);
+      return true;
+    }
+
+    bool walkIfStmt(IfStmt *ifStmt) {
+      // Walk both clauses (no short-circuit) so no innermost loop is skipped.
+      bool hasInnerLoops = walk(ifStmt->getThenClause()->begin(),
+                                ifStmt->getThenClause()->end());
+      hasInnerLoops |= walk(ifStmt->getElseClause()->begin(),
+                            ifStmt->getElseClause()->end());
+      return hasInnerLoops;
+    }
+
+    bool walkOpStmt(OperationStmt *opStmt) { return false; }
+
+    using StmtWalker<InnermostLoopGatherer, bool>::walk;
   };
 
   InnermostLoopGatherer ilg;
-  ilg.visit(f);
+  ilg.walkMLFunction(f);
   auto &loops = ilg.loops;
   bool changed = false;
   for (auto *forStmt : loops)