Loop unrolling pass update

- fix/complete forStmt cloning for unrolling to work for outer loops
- create IV constants only when needed (i.e., when the loop IV has uses)
- test outer loop unrolling by creating a short trip count unroll pass for
  loops with trip counts <= <parameter>
- add unrolling test cases for multiple op results, outer loop unrolling
- fix/clean up StmtWalker class while at it
- switch unroll loop iterator values from i32 to affineint

PiperOrigin-RevId: 207645967
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 27bb43f..eea3bf7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -39,14 +39,22 @@
   void runOnMLFunction(MLFunction *f) override;
   void runOnForStmt(ForStmt *forStmt);
 };
+struct ShortLoopUnroll : public LoopUnroll {
+  const unsigned minTripCount; // Loops with trip count <= this are unrolled.
+  void runOnMLFunction(MLFunction *f) override;
+  ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+};
 } // end anonymous namespace
 
 MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
 
+MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
+  return new ShortLoopUnroll(minTripCount);
+}
+
 /// Unrolls all the innermost loops of this MLFunction.
 void LoopUnroll::runOnMLFunction(MLFunction *f) {
   // Gathers all innermost loops through a post order pruned walk.
-  // TODO: figure out the right reusable template here to better refactor code.
   class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
   public:
     // Store innermost loops as we walk.
@@ -65,11 +73,6 @@
       return hasInnerLoops;
     }
 
-    // FIXME: can't use base class method for this because that in turn would
-    // need to use the derived class method above. CRTP doesn't allow it, and
-    // the compiler error resulting from it is also very misleading!
-    void walkPostOrder(MLFunction *f) { walkPostOrder(f->begin(), f->end()); }
-
     bool walkForStmtPostOrder(ForStmt *forStmt) {
       bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
       if (!hasInnerLoops)
@@ -85,8 +88,11 @@
       return hasInnerLoops;
     }
 
-    bool walkOpStmt(OperationStmt *opStmt) { return false; }
+    bool visitOperationStmt(OperationStmt *opStmt) { return false; }
 
+    // FIXME: can't use base class method for this because that in turn would
+    // need to use the derived class method above. CRTP doesn't allow it, and
+    // the compiler error resulting from it is also misleading.
     using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
   };
 
@@ -97,28 +103,96 @@
     runOnForStmt(forStmt);
 }
 
-/// Replace all uses of 'oldVal' with 'newVal' in 'stmt'
-static void replaceAllStmtUses(Statement *stmt, MLValue *oldVal,
-                               MLValue *newVal) {
-  struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
-    // Value to be replaced.
-    MLValue *oldVal;
-    // Value to be replaced with.
-    MLValue *newVal;
+/// Unrolls all loops with trip count <= minTripCount.
+void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
+  // Gathers loops with trip count ((ub - lb) / step + 1) <= minTripCount.
+  class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
+  public:
+    // Store short loops as we walk.
+    std::vector<ForStmt *> loops;
+    const unsigned minTripCount;
+    ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
 
-    ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
-        : oldVal(oldVal), newVal(newVal){};
+    void visitForStmt(ForStmt *forStmt) {
+      auto lb = forStmt->getLowerBound()->getValue();
+      auto ub = forStmt->getUpperBound()->getValue();
+      auto step = forStmt->getStep()->getValue();
 
-    void visitOperationStmt(OperationStmt *os) {
-      for (auto &operand : os->getStmtOperands()) {
-        if (operand.get() == oldVal)
-          operand.set(newVal);
-      }
+      if ((ub - lb) / step + 1 <= minTripCount)
+        loops.push_back(forStmt);
     }
   };
 
-  ReplaceUseWalker ri(oldVal, newVal);
-  ri.walk(stmt);
+  ShortLoopGatherer slg(minTripCount);
+  slg.walk(f);
+  auto &loops = slg.loops;
+  for (auto *forStmt : loops)
+    runOnForStmt(forStmt);
+}
+
+/// Replace uses of 'oldVal' with 'newVal' in the statements of [begin, end).
+static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
+                        MLValue *oldVal, MLValue *newVal) {
+  // TODO(bondhugula,clattner): do this more efficiently by walking those uses
+  // of oldVal that fall within this list of statements (instead of iterating
+  // through all statements / through all operands of operations found).
+  for (auto it = begin; it != end; it++) {
+    it->replaceUses(oldVal, newVal);
+  }
+}
+
+/// Replace all uses of 'oldVal' with 'newVal' in every statement of 'block'.
+void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
+  // TODO(bondhugula,clattner): do this more efficiently by walking those uses
+  // of oldVal that fall within this StmtBlock (instead of iterating through
+  // all statements / through all operands of operations found).
+  for (auto it = block->begin(); it != block->end(); it++) {
+    it->replaceUses(oldVal, newVal);
+  }
+}
+
+/// Clone the statements of 'block' at the builder's current insertion point;
+/// uses of old op results inside the clones are redirected to the new results.
+// TODO(bondhugula,clattner): replace this with a parameterizable clone.
+void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
+  // Pairs of <old op stmt result whose uses need to be replaced,
+  // new result generated by the corresponding cloned op stmt>.
+  SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+
+  // Iterator to the statement just before the builder's insertion point.
+  StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
+
+  for (auto &stmt : block.getStatements()) {
+    auto *cloneStmt = builder->clone(stmt);
+    // Whenever we have an op stmt, we'll have a new ML Value defined: record
+    // it so that uses of the old result can be replaced with this one.
+    if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+      if (opStmt->getNumResults()) {
+        auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+        for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
+          // Store old/new result pairs.
+          // TODO(bondhugula) *only* if needed later: storing of old/new
+          // results can be avoided by cloning the statement list in the
+          // reverse direction (and running the IR builder in reverse:
+          // iplist.insertAfter()). That way, a newly created result can be
+          // immediately propagated to all its uses.
+          oldNewResultPairs.push_back(std::make_pair(
+              const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
+              &cloneOpStmt->getStmtResult(i)));
+        }
+      }
+    }
+  }
+
+  // Delimit the range of statements inserted by the cloning loop above.
+  StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+  StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
+
+  // Replace uses of the old op results with the newly created ones.
+  for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
+    replaceUses(startOfUnrolledBody, endOfUnrolledBody,
+                oldNewResultPairs[i].first, oldNewResultPairs[i].second);
+  }
+}
 
 /// Unroll this 'for stmt' / loop completely.
@@ -139,52 +213,25 @@
                             ++StmtBlock::iterator(forStmt));
 
   // Unroll the contents of 'forStmt'.
-  for (int i = lb; i <= ub; i += step) {
-    // TODO(bondhugula): generate constants only when IV actually appears.
-    auto constOp = funcTopBuilder.create<ConstantIntOp>(i, 32);
-    auto *ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
-
-    // Iterator pointing to just before 'this' (i^th) unrolled iteration.
+  for (int64_t i = lb; i <= ub; i += step) {
+    MLValue *ivConst = nullptr;
+    if (!forStmt->use_empty()) {
+      auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
+      ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+    }
     StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
 
-    // Pairs of <old op stmt result whose uses need to be replaced,
-    // new result generated by the corresponding cloned op stmt>.
-    SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+    // Clone the loop body and insert it right after the loop - the latter will
+    // be erased after all unrolling has been done.
+    cloneStmtListFromBlock(&builder, *forStmt);
 
-    for (auto &loopBodyStmt : forStmt->getStatements()) {
-      auto *cloneStmt = builder.clone(loopBodyStmt);
-      // Replace all uses of the IV in the clone with constant iteration value.
-      replaceAllStmtUses(cloneStmt, forStmt, ivConst);
-
-      // Whenever we have an op stmt, we'll have a new ML Value defined: replace
-      // uses of the old result with this one.
-      if (auto *opStmt = dyn_cast<OperationStmt>(&loopBodyStmt)) {
-        if (opStmt->getNumResults()) {
-          auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
-          for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
-            // Store old/new result pairs.
-            // TODO *only* if needed later: storing of old/new results can be
-            // avoided, by cloning the statement list in the reverse direction
-            // (and running the IR builder in the reverse
-            // (iplist.insertAfter()). That way, a newly created result can be
-            // immediately propagated to all its uses, which would already  been
-            // cloned/inserted.
-            oldNewResultPairs.push_back(std::make_pair(
-                &opStmt->getStmtResult(i), &cloneOpStmt->getStmtResult(i)));
-          }
-        }
-      }
-    }
-    // Replace uses of old op results' with the results in the just
-    // unrolled body.
-    StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
-    for (auto it = ++beforeUnrolledBody; it != endOfUnrolledBody; it++) {
-      for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
-        replaceAllStmtUses(&(*it), oldNewResultPairs[i].first,
-                           oldNewResultPairs[i].second);
-      }
+    // Replace unrolled loop IV with the unrolled constant.
+    if (ivConst) {
+      StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+      StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
+      replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
     }
   }
-  // Erase the original for stmt from the block.
+  // Erase the original 'for' stmt from the block.
   forStmt->eraseFromBlock();
 }