Extend loop unrolling to unroll by a given factor; add builder for affine
apply op.

- add builder for AffineApplyOp (the first builder for an operation that
  takes one or more operands)
- add support for loop unrolling by a given factor; uses the affine apply op
  builder.

While at it, change the 'step' of ForStmt to be 'unsigned' instead of
'AffineConstantExpr *', and add setters for the ForStmt lb, ub, and step.

Sample Input:

// CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() {
mlfunc @loop_nest_unroll_cleanup() {
  for %i = 1 to 100 {
    for %j = 0 to 17 {
      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
      %y = "addi32"(%x, %x) : (i32, i32) -> i32
    }
  }
  return
}

Output:

$ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir
#map0 = (d0) -> (d0 + 1)
#map1 = (d0) -> (d0 + 2)
#map2 = (d0) -> (d0 + 3)
mlfunc @loop_nest_unroll_cleanup() {
  for %i0 = 1 to 100 {
    for %i1 = 0 to 17 step 4 {
      %0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32
      %1 = "addi32"(%0, %0) : (i32, i32) -> i32
      %2 = affine_apply #map0(%i1)
      %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
      %4 = affine_apply #map1(%i1)
      %5 = "addi32"(%4, %4) : (affineint, affineint) -> i32
      %6 = affine_apply #map2(%i1)
      %7 = "addi32"(%6, %6) : (affineint, affineint) -> i32
    }
    for %i2 = 16 to 17 {
      %8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
      %9 = "addi32"(%8, %8) : (i32, i32) -> i32
    }
  }
  return
}

PiperOrigin-RevId: 209676220
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 0f6d428..a4a11a7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -31,30 +31,49 @@
 #include "mlir/Transforms/Pass.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
+using namespace llvm;
+
+// Loop unrolling factor; takes priority over the constructor-supplied factor.
+static llvm::cl::opt<unsigned>
+    clUnrollFactor("unroll-factor", cl::Hidden,
+                   cl::desc("Use this unroll factor for all loops"));
+// Full unrolling. The pass checks whether the flag was given, not its value.
+static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
+                                        cl::desc("Fully unroll loops"));
+// Trip count limit for full unrolling; only used along with -unroll-full.
+static llvm::cl::opt<unsigned> clUnrollFullThreshold(
+    "unroll-full-threshold", cl::Hidden,
+    cl::desc("Unroll all loops with trip count less than or equal to this"));
 
 namespace {
-/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
-/// MLFunction.
+/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
+/// full unroll threshold were specified, in which case it fully unrolls all
+/// loops with trip count less than or equal to the specified threshold. The
+/// latter is for testing purposes, especially for testing outer loop unrolling.
 struct LoopUnroll : public MLFunctionPass {
-  void runOnMLFunction(MLFunction *f) override;
-  void runOnForStmt(ForStmt *forStmt);
-};
+  Optional<unsigned> unrollFactor;
+  Optional<bool> unrollFull;
 
-/// Unrolls all loops with trip count <= minTripCount.
-struct ShortLoopUnroll : public LoopUnroll {
-  const unsigned minTripCount;
+  explicit LoopUnroll(Optional<unsigned> unrollFactor,
+                      Optional<bool> unrollFull)
+      : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+
   void runOnMLFunction(MLFunction *f) override;
-  ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+  /// Unroll this for stmt. Returns false if nothing was done.
+  bool runOnForStmt(ForStmt *forStmt);
+  bool loopUnrollFull(ForStmt *forStmt);
+  bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
 };
 } // end anonymous namespace
 
-MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-
-MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
-  return new ShortLoopUnroll(minTripCount);
+MLFunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
+  return new LoopUnroll(unrollFactor == -1 ? None
+                                           : Optional<unsigned>(unrollFactor),
+                        unrollFull == -1 ? None : Optional<bool>(unrollFull));
 }
 
 void LoopUnroll::runOnMLFunction(MLFunction *f) {
@@ -81,7 +100,6 @@
       bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
       if (!hasInnerLoops)
         loops.push_back(forStmt);
-
       return true;
     }
 
@@ -101,14 +119,6 @@
     using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
   };
 
-  InnermostLoopGatherer ilg;
-  ilg.walkPostOrder(f);
-  auto &loops = ilg.loops;
-  for (auto *forStmt : loops)
-    runOnForStmt(forStmt);
-}
-
-void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
   // Gathers all loops with trip count <= minTripCount.
   class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
   public:
@@ -120,27 +130,55 @@
     void visitForStmt(ForStmt *forStmt) {
       auto lb = forStmt->getLowerBound()->getValue();
       auto ub = forStmt->getUpperBound()->getValue();
-      auto step = forStmt->getStep()->getValue();
+      auto step = forStmt->getStep();
 
       if ((ub - lb) / step + 1 <= minTripCount)
         loops.push_back(forStmt);
     }
   };
 
-  ShortLoopGatherer slg(minTripCount);
-  // Do a post order walk so that loops are gathered from innermost to
-  // outermost (or else unrolling an outer one may delete gathered inner ones).
-  slg.walkPostOrder(f);
-  auto &loops = slg.loops;
+  if (clUnrollFull.getNumOccurrences() > 0 &&
+      clUnrollFullThreshold.getNumOccurrences() > 0) {
+    ShortLoopGatherer slg(clUnrollFullThreshold);
+    // Do a post order walk so that loops are gathered from innermost to
+    // outermost (or else unrolling an outer one may delete gathered inner
+    // ones).
+    slg.walkPostOrder(f);
+    auto &loops = slg.loops;
+    for (auto *forStmt : loops)
+      loopUnrollFull(forStmt);
+    return;
+  }
+
+  InnermostLoopGatherer ilg;
+  ilg.walkPostOrder(f);
+  auto &loops = ilg.loops;
   for (auto *forStmt : loops)
     runOnForStmt(forStmt);
 }
 
-/// Unroll this For loop completely.
-void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+/// Unrolls 'forStmt': fully if requested, else by a factor (4 by default).
+bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+  // A full unroll request (command line or constructor) is checked first.
+  bool fullByFlag = clUnrollFull.getNumOccurrences() > 0;
+  if (fullByFlag || (unrollFull.hasValue() && unrollFull.getValue()))
+    return loopUnrollFull(forStmt);
+
+  // The command-line factor wins over the constructor-supplied one; fall back
+  // to four if neither was specified.
+  unsigned factor = 4;
+  if (clUnrollFactor.getNumOccurrences() > 0)
+    factor = clUnrollFactor;
+  else if (unrollFactor.hasValue())
+    factor = unrollFactor.getValue();
+  return loopUnrollByFactor(forStmt, factor);
+}
+
+// Unrolls this loop completely.
+bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
   auto lb = forStmt->getLowerBound()->getValue();
   auto ub = forStmt->getUpperBound()->getValue();
-  auto step = forStmt->getStep()->getValue();
+  auto step = forStmt->getStep();
 
   // Builder to add constants need for the unrolled iterator.
   auto *mlFunc = forStmt->findFunction();
@@ -164,9 +202,75 @@
 
     // Clone the body of the loop.
     for (auto &childStmt : *forStmt) {
-      (void)builder.clone(childStmt, operandMapping);
+      builder.clone(childStmt, operandMapping);
     }
   }
   // Erase the original 'for' stmt from the block.
   forStmt->eraseFromBlock();
+  return true;
+}
+
+/// Unrolls this loop by the specified unroll factor, emitting a cleanup loop
+/// for leftover iterations. Returns false if the loop was left untouched.
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+  assert(unrollFactor >= 1 && "unroll factor should be >= 1");
+  if (unrollFactor == 1 || forStmt->getStatements().empty())
+    return false;
+
+  auto lb = forStmt->getLowerBound()->getValue();
+  auto ub = forStmt->getUpperBound()->getValue();
+  auto step = forStmt->getStep();
+
+  // Inclusive bounds. Integer ceil-division: float is inexact on big ranges.
+  int64_t range = ub - lb + 1;
+  int64_t tripCount = range <= 0 ? 0 : (range + step - 1) / step;
+
+  // If the trip count is lower than the unroll factor, no unrolled body.
+  // TODO(bondhugula): option to specify cleanup loop unrolling.
+  if (tripCount < unrollFactor)
+    return false;
+
+  // Iterations left to the cleanup loop vs. covered by the unrolled loop.
+  int64_t numCleanupIters = tripCount % unrollFactor;
+  int64_t numUnrolledIters = tripCount - numCleanupIters;
+
+  // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
+  if (numCleanupIters != 0) {
+    DenseMap<const MLValue *, MLValue *> operandMap;
+    MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
+    auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
+    cleanupForStmt->setLowerBound(
+        builder.getConstantExpr(lb + numUnrolledIters * step));
+  }
+
+  // Builder to insert unrolled bodies at the end of the body of 'forStmt'.
+  MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+  forStmt->setStep(step * unrollFactor);
+  forStmt->setUpperBound(
+      builder.getConstantExpr(lb + (numUnrolledIters - 1) * step));
+
+  // Remember the last original statement; copies are appended in place.
+  StmtBlock::iterator srcBlockEnd = --forStmt->end();
+
+  // Append unrollFactor-1 additional copies of the original body.
+  for (unsigned i = 1; i < unrollFactor; i++) {
+    DenseMap<const MLValue *, MLValue *> operandMapping;
+    // If the induction variable is used, remap it for this unrolled instance.
+    if (!forStmt->use_empty()) {
+      // iv' = iv + i * step.
+      auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0),
+                                          builder.getConstantExpr(i * step));
+      auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
+      auto *ivUnroll =
+          builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
+      operandMapping[forStmt] = cast<MLValue>(ivUnroll);
+    }
+
+    // Clone the original body of the loop (this doesn't include the last stmt).
+    for (auto it = forStmt->begin(); it != srcBlockEnd; ++it)
+      builder.clone(*it, operandMapping);
+    // Clone the last statement in the original body.
+    builder.clone(*srcBlockEnd, operandMapping);
+  }
+  return true;
+}