Extend loop unrolling to unroll by a given factor; add builder for affine
apply op.

- add builder for AffineApplyOp (first one for an operation that has
  non-zero operands)
- add support for loop unrolling by a given factor; uses the affine apply op
  builder.

While on this, change 'step' of ForStmt to be 'unsigned' instead of
AffineConstantExpr *. Add setters for ForStmt lb, ub, step.

Sample Input:

// CHECK-LABEL: mlfunc @loop_nest_unroll_cleanup() {
mlfunc @loop_nest_unroll_cleanup() {
  for %i = 1 to 100 {
    for %j = 0 to 17 {
      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
      %y = "addi32"(%x, %x) : (i32, i32) -> i32
    }
  }
  return
}

Output:

$ mlir-opt -loop-unroll -unroll-factor=4 /tmp/single2.mlir
#map0 = (d0) -> (d0 + 1)
#map1 = (d0) -> (d0 + 2)
#map2 = (d0) -> (d0 + 3)
mlfunc @loop_nest_unroll_cleanup() {
  for %i0 = 1 to 100 {
    for %i1 = 0 to 17 step 4 {
      %0 = "addi32"(%i1, %i1) : (affineint, affineint) -> i32
      %1 = "addi32"(%0, %0) : (i32, i32) -> i32
      %2 = affine_apply #map0(%i1)
      %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
      %4 = affine_apply #map1(%i1)
      %5 = "addi32"(%4, %4) : (affineint, affineint) -> i32
      %6 = affine_apply #map2(%i1)
      %7 = "addi32"(%6, %6) : (affineint, affineint) -> i32
    }
    for %i2 = 16 to 17 {
      %8 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
      %9 = "addi32"(%8, %8) : (i32, i32) -> i32
    }
  }
  return
}

PiperOrigin-RevId: 209676220
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index 62c0dcb..e8d3894 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -252,7 +252,8 @@
     this->insertPoint = insertPoint;
   }
 
-  /// Set the insertion point to the specified operation.
+  /// Set the insertion point to the specified operation, which will cause
+  /// subsequent insertions to go right before it.
   void setInsertionPoint(Statement *stmt) {
     setInsertionPoint(stmt->getBlock(), StmtBlock::iterator(stmt));
   }
@@ -298,8 +299,7 @@
 
   // Creates for statement. When step is not specified, it is set to 1.
   ForStmt *createFor(AffineConstantExpr *lowerBound,
-                     AffineConstantExpr *upperBound,
-                     AffineConstantExpr *step = nullptr);
+                     AffineConstantExpr *upperBound, int64_t step = 1);
 
   IfStmt *createIf(IntegerSet *condition) {
     auto *stmt = new IfStmt(condition);
diff --git a/include/mlir/IR/Operation.h b/include/mlir/IR/Operation.h
index fed0d4e..af01a1d 100644
--- a/include/mlir/IR/Operation.h
+++ b/include/mlir/IR/Operation.h
@@ -47,6 +47,7 @@
 struct OperationState {
   Identifier name;
   SmallVector<SSAValue *, 4> operands;
+  /// Types of the results of this operation.
   SmallVector<Type *, 4> types;
   SmallVector<NamedAttribute, 4> attributes;
 
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 3b615cc..63921ca 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -77,6 +77,10 @@
 class AffineApplyOp : public OpBase<AffineApplyOp, OpTrait::VariadicOperands,
                                     OpTrait::VariadicResults> {
 public:
+  /// Builds an affine apply op with the specified map and operands.
+  static OperationState build(Builder *builder, AffineMap *map,
+                              ArrayRef<SSAValue *> operands);
+
   // Returns the affine map to be applied by this operation.
   AffineMap *getAffineMap() const {
     return getAttrOfType<AffineMapAttr>("map")->getValue();
@@ -163,6 +167,7 @@
 ///
 class ConstantFloatOp : public ConstantOp {
 public:
+  /// Builds a constant float op producing a float of the specified type.
   static OperationState build(Builder *builder, double value, FloatType *type);
 
   double getValue() const {
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index f851d9f..1a68aab 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -199,7 +199,7 @@
   // TODO: lower and upper bounds should be affine maps with
   // dimension and symbol use lists.
   explicit ForStmt(AffineConstantExpr *lowerBound,
-                   AffineConstantExpr *upperBound, AffineConstantExpr *step,
+                   AffineConstantExpr *upperBound, int64_t step,
                    MLIRContext *context);
 
   ~ForStmt() {
@@ -216,7 +216,11 @@
 
   AffineConstantExpr *getLowerBound() const { return lowerBound; }
   AffineConstantExpr *getUpperBound() const { return upperBound; }
-  AffineConstantExpr *getStep() const { return step; }
+  int64_t getStep() const { return step; }
+
+  void setLowerBound(AffineConstantExpr *lb) { lowerBound = lb; }
+  void setUpperBound(AffineConstantExpr *ub) { upperBound = ub; }
+  void setStep(unsigned s) { step = s; }
 
   using Statement::dump;
   using Statement::print;
@@ -242,7 +246,7 @@
   // an affinemap and its operands as AffineBound.
   AffineConstantExpr *lowerBound;
   AffineConstantExpr *upperBound;
-  AffineConstantExpr *step;
+  int64_t step;
 };
 
 /// An if clause represents statements contained within a then or an else clause
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index 3944c13..3f983d8 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -28,9 +28,9 @@
 class MLFunctionPass;
 class ModulePass;
 
-/// Loop unrolling passes.
-MLFunctionPass *createLoopUnrollPass();
-MLFunctionPass *createLoopUnrollPass(unsigned);
+// Loop unrolling passes.
+/// Creates a loop unrolling pass.
+MLFunctionPass *createLoopUnrollPass(int unrollFactor, int unrollFull);
 
 /// Replaces all ML functions in the module with equivalent CFG functions.
 /// Function references are appropriately patched to refer to the newly
diff --git a/lib/IR/AsmPrinter.cpp b/lib/IR/AsmPrinter.cpp
index 8303f57..6bcaad4 100644
--- a/lib/IR/AsmPrinter.cpp
+++ b/lib/IR/AsmPrinter.cpp
@@ -1316,8 +1316,8 @@
   printOperand(stmt);
   os << " = " << *stmt->getLowerBound();
   os << " to " << *stmt->getUpperBound();
-  if (stmt->getStep()->getValue() != 1)
-    os << " step " << *stmt->getStep();
+  if (stmt->getStep() != 1)
+    os << " step " << stmt->getStep();
 
   os << " {\n";
   print(static_cast<const StmtBlock *>(stmt));
diff --git a/lib/IR/Builders.cpp b/lib/IR/Builders.cpp
index 5fce94c..b1dce25 100644
--- a/lib/IR/Builders.cpp
+++ b/lib/IR/Builders.cpp
@@ -103,8 +103,8 @@
   return ArrayAttr::get(value, context);
 }
 
-AffineMapAttr *Builder::getAffineMapAttr(AffineMap *value) {
-  return AffineMapAttr::get(value, context);
+AffineMapAttr *Builder::getAffineMapAttr(AffineMap *map) {
+  return AffineMapAttr::get(map, context);
 }
 
 TypeAttr *Builder::getTypeAttr(Type *type) {
@@ -207,9 +207,7 @@
 
 ForStmt *MLFuncBuilder::createFor(AffineConstantExpr *lowerBound,
                                   AffineConstantExpr *upperBound,
-                                  AffineConstantExpr *step) {
-  if (!step)
-    step = getConstantExpr(1);
+                                  int64_t step) {
   auto *stmt = new ForStmt(lowerBound, upperBound, step, context);
   block->getStatements().insert(insertPoint, stmt);
   return stmt;
diff --git a/lib/IR/StandardOps.cpp b/lib/IR/StandardOps.cpp
index 6d8c366..cec23e0 100644
--- a/lib/IR/StandardOps.cpp
+++ b/lib/IR/StandardOps.cpp
@@ -305,6 +305,22 @@
 }
 
 //===----------------------------------------------------------------------===//
+// AffineApplyOp
+//===----------------------------------------------------------------------===//
+
+OperationState AffineApplyOp::build(Builder *builder, AffineMap *map,
+                                    ArrayRef<SSAValue *> operands) {
+  SmallVector<Type *, 4> resultTypes(map->getNumResults(),
+                                     builder->getAffineIntType());
+
+  OperationState result(
+      builder->getIdentifier("affine_apply"), operands, resultTypes,
+      {{builder->getIdentifier("map"), builder->getAffineMapAttr(map)}});
+
+  return result;
+}
+
+//===----------------------------------------------------------------------===//
 // DeallocOp
 //===----------------------------------------------------------------------===//
 
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index d20b45f..7da08c2 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -198,7 +198,7 @@
 //===----------------------------------------------------------------------===//
 
 ForStmt::ForStmt(AffineConstantExpr *lowerBound, AffineConstantExpr *upperBound,
-                 AffineConstantExpr *step, MLIRContext *context)
+                 int64_t step, MLIRContext *context)
     : Statement(Kind::For),
       MLValue(MLValueKind::ForStmt, Type::getAffineInt(context)),
       StmtBlock(StmtBlockKind::For), lowerBound(lowerBound),
diff --git a/lib/Parser/Parser.cpp b/lib/Parser/Parser.cpp
index 174d6ca..0debe2d 100644
--- a/lib/Parser/Parser.cpp
+++ b/lib/Parser/Parser.cpp
@@ -2164,7 +2164,7 @@
 ParseResult MLFunctionParser::parseForStmt() {
   consumeToken(Token::kw_for);
 
-  // Parse induction variable
+  // Parse induction variable.
   if (getToken().isNot(Token::percent_identifier))
     return emitError("expected SSA identifier for the loop variable");
 
@@ -2175,7 +2175,7 @@
   if (parseToken(Token::equal, "expected '='"))
     return ParseFailure;
 
-  // Parse loop bounds
+  // Parse loop bounds.
   AffineConstantExpr *lowerBound = parseIntConstant();
   if (!lowerBound)
     return ParseFailure;
@@ -2187,12 +2187,13 @@
   if (!upperBound)
     return ParseFailure;
 
-  // Parse step
-  AffineConstantExpr *step = nullptr;
+  // Parse step.
+  int64_t step = 1;
   if (consumeIf(Token::kw_step)) {
-    step = parseIntConstant();
-    if (!step)
+    AffineConstantExpr *stepExpr = parseIntConstant();
+    if (!stepExpr)
       return ParseFailure;
+    step = stepExpr->getValue();
   }
 
   // Create for statement.
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 0f6d428..a4a11a7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -31,30 +31,49 @@
 #include "mlir/Transforms/Pass.h"
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/Support/CommandLine.h"
 #include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
+using namespace llvm;
+
+// Loop unrolling factor.
+static llvm::cl::opt<unsigned>
+    clUnrollFactor("unroll-factor", cl::Hidden,
+                   cl::desc("Use this unroll factor for all loops"));
+
+static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
+                                        cl::desc("Fully unroll loops"));
+
+static llvm::cl::opt<unsigned> clUnrollFullThreshold(
+    "unroll-full-threshold", cl::Hidden,
+    cl::desc("Unroll all loops with trip count less than or equal to this"));
 
 namespace {
-/// Loop unrolling pass. For now, this unrolls all the innermost loops of this
-/// MLFunction.
+/// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
+/// full unroll threshold was specified, in which case, fully unrolls all loops
+/// with trip count less than the specified threshold. The latter is for testing
+/// purposes, especially for testing outer loop unrolling.
 struct LoopUnroll : public MLFunctionPass {
-  void runOnMLFunction(MLFunction *f) override;
-  void runOnForStmt(ForStmt *forStmt);
-};
+  Optional<unsigned> unrollFactor;
+  Optional<bool> unrollFull;
 
-/// Unrolls all loops with trip count <= minTripCount.
-struct ShortLoopUnroll : public LoopUnroll {
-  const unsigned minTripCount;
+  explicit LoopUnroll(Optional<unsigned> unrollFactor,
+                      Optional<bool> unrollFull)
+      : unrollFactor(unrollFactor), unrollFull(unrollFull) {}
+
   void runOnMLFunction(MLFunction *f) override;
-  ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+  /// Unroll this for stmt. Returns false if nothing was done.
+  bool runOnForStmt(ForStmt *forStmt);
+  bool loopUnrollFull(ForStmt *forStmt);
+  bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
 };
 } // end anonymous namespace
 
-MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
-
-MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
-  return new ShortLoopUnroll(minTripCount);
+MLFunctionPass *mlir::createLoopUnrollPass(int unrollFactor, int unrollFull) {
+  return new LoopUnroll(unrollFactor == -1 ? None
+                                           : Optional<unsigned>(unrollFactor),
+                        unrollFull == -1 ? None : Optional<bool>(unrollFull));
 }
 
 void LoopUnroll::runOnMLFunction(MLFunction *f) {
@@ -81,7 +100,6 @@
       bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
       if (!hasInnerLoops)
         loops.push_back(forStmt);
-
       return true;
     }
 
@@ -101,14 +119,6 @@
     using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
   };
 
-  InnermostLoopGatherer ilg;
-  ilg.walkPostOrder(f);
-  auto &loops = ilg.loops;
-  for (auto *forStmt : loops)
-    runOnForStmt(forStmt);
-}
-
-void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
   // Gathers all loops with trip count <= minTripCount.
   class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
   public:
@@ -120,27 +130,55 @@
     void visitForStmt(ForStmt *forStmt) {
       auto lb = forStmt->getLowerBound()->getValue();
       auto ub = forStmt->getUpperBound()->getValue();
-      auto step = forStmt->getStep()->getValue();
+      auto step = forStmt->getStep();
 
       if ((ub - lb) / step + 1 <= minTripCount)
         loops.push_back(forStmt);
     }
   };
 
-  ShortLoopGatherer slg(minTripCount);
-  // Do a post order walk so that loops are gathered from innermost to
-  // outermost (or else unrolling an outer one may delete gathered inner ones).
-  slg.walkPostOrder(f);
-  auto &loops = slg.loops;
+  if (clUnrollFull.getNumOccurrences() > 0 &&
+      clUnrollFullThreshold.getNumOccurrences() > 0) {
+    ShortLoopGatherer slg(clUnrollFullThreshold);
+    // Do a post order walk so that loops are gathered from innermost to
+    // outermost (or else unrolling an outer one may delete gathered inner
+    // ones).
+    slg.walkPostOrder(f);
+    auto &loops = slg.loops;
+    for (auto *forStmt : loops)
+      loopUnrollFull(forStmt);
+    return;
+  }
+
+  InnermostLoopGatherer ilg;
+  ilg.walkPostOrder(f);
+  auto &loops = ilg.loops;
   for (auto *forStmt : loops)
     runOnForStmt(forStmt);
 }
 
-/// Unroll this For loop completely.
-void LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+/// Unroll a for stmt. Default unroll factor is 4.
+bool LoopUnroll::runOnForStmt(ForStmt *forStmt) {
+  // Unroll completely if full loop unroll was specified.
+  if (clUnrollFull.getNumOccurrences() > 0 ||
+      (unrollFull.hasValue() && unrollFull.getValue()))
+    return loopUnrollFull(forStmt);
+
+  // Unroll by the specified factor if one was specified.
+  if (clUnrollFactor.getNumOccurrences() > 0)
+    return loopUnrollByFactor(forStmt, clUnrollFactor);
+  else if (unrollFactor.hasValue())
+    return loopUnrollByFactor(forStmt, unrollFactor.getValue());
+
+  // Unroll by four otherwise.
+  return loopUnrollByFactor(forStmt, 4);
+}
+
+// Unrolls this loop completely.
+bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
   auto lb = forStmt->getLowerBound()->getValue();
   auto ub = forStmt->getUpperBound()->getValue();
-  auto step = forStmt->getStep()->getValue();
+  auto step = forStmt->getStep();
 
   // Builder to add constants need for the unrolled iterator.
   auto *mlFunc = forStmt->findFunction();
@@ -164,9 +202,75 @@
 
     // Clone the body of the loop.
     for (auto &childStmt : *forStmt) {
-      (void)builder.clone(childStmt, operandMapping);
+      builder.clone(childStmt, operandMapping);
     }
   }
   // Erase the original 'for' stmt from the block.
   forStmt->eraseFromBlock();
+  return true;
+}
+
+/// Unrolls this loop by the specified unroll factor.
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+  assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
+
+  if (unrollFactor == 1 || forStmt->getStatements().empty())
+    return false;
+
+  auto lb = forStmt->getLowerBound()->getValue();
+  auto ub = forStmt->getUpperBound()->getValue();
+  auto step = forStmt->getStep();
+
+  int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
+
+  // If the trip count is lower than the unroll factor, no unrolled body.
+  // TODO(bondhugula): option to specify cleanup loop unrolling.
+  if (tripCount < unrollFactor)
+    return true;
+
+  // Generate the cleanup loop if trip count isn't a multiple of unrollFactor.
+  if (tripCount % unrollFactor) {
+    DenseMap<const MLValue *, MLValue *> operandMap;
+    MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
+    auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
+    cleanupForStmt->setLowerBound(builder.getConstantExpr(
+        lb + (tripCount - tripCount % unrollFactor) * step));
+  }
+
+  // Builder to insert unrolled bodies right after the last statement in the
+  // body of 'forStmt'.
+  MLFuncBuilder builder(forStmt, StmtBlock::iterator(forStmt->end()));
+  forStmt->setStep(step * unrollFactor);
+  forStmt->setUpperBound(builder.getConstantExpr(
+      lb + (tripCount - tripCount % unrollFactor - 1) * step));
+
+  // Keep a pointer to the last statement in the original block so that we know
+  // what to clone (since we are doing this in-place).
+  StmtBlock::iterator srcBlockEnd = --forStmt->end();
+
+  // Unroll the contents of 'forStmt' (unrollFactor-1 additional copies
+  // appended).
+  for (unsigned i = 1; i < unrollFactor; i++) {
+    DenseMap<const MLValue *, MLValue *> operandMapping;
+
+    // If the induction variable is used, create a remapping to the value for
+    // this unrolled instance.
+    if (!forStmt->use_empty()) {
+      // iv' = iv + 1/2/3...unrollFactor-1;
+      auto *bumpExpr = builder.getAddExpr(builder.getDimExpr(0),
+                                          builder.getConstantExpr(i * step));
+      auto *bumpMap = builder.getAffineMap(1, 0, {bumpExpr}, {});
+      auto *ivUnroll =
+          builder.create<AffineApplyOp>(bumpMap, forStmt)->getResult(0);
+      operandMapping[forStmt] = cast<MLValue>(ivUnroll);
+    }
+
+    // Clone the original body of the loop (this doesn't include the last stmt).
+    for (auto it = forStmt->begin(); it != srcBlockEnd; it++) {
+      builder.clone(*it, operandMapping);
+    }
+    // Clone the last statement in the original body.
+    builder.clone(*srcBlockEnd, operandMapping);
+  }
+  return true;
 }
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 0f9f208..71536c1 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,5 +1,7 @@
-// RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
-// RUN: %S/../../mlir-opt %s -o - -unroll-short-loops | FileCheck %s --check-prefix SHORT
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-full | FileCheck %s
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4
+// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=3 | FileCheck %s --check-prefix UNROLL-BY-3
 
 // CHECK: #map0 = (d0) -> (d0 + 1)
 
@@ -279,3 +281,87 @@
   %ret = load %C[%zero_idx, %zero_idx] : memref<512 x 512 x i32, (d0, d1) -> (d0, d1), 2>
   return %ret : i32
 }
+
+// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_no_cleanup() {
+mlfunc @unroll_unit_stride_no_cleanup() {
+  // UNROLL-BY-4: for %i0 = 1 to 100 {
+  for %i = 1 to 100 {
+    // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
+    // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: }
+    for %j = 1 to 8 {
+      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+      %y = "addi32"(%x, %x) : (i32, i32) -> i32
+    }
+    // empty loop
+    // UNROLL-BY-4: for %i2 = 1 to 8 {
+    for %k = 1 to 8 {
+    }
+  }
+  return
+}
+
+// UNROLL-BY-4-LABEL: mlfunc @unroll_unit_stride_cleanup() {
+mlfunc @unroll_unit_stride_cleanup() {
+  // UNROLL-BY-4: for %i0 = 1 to 100 {
+  for %i = 1 to 100 {
+    // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 1 to 8 step 4 {
+    // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: }
+    // UNROLL-BY-4-NEXT: for [[L2:%i[0-9]+]] = 9 to 10 {
+    // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: }
+    for %j = 1 to 10 {
+      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+      %y = "addi32"(%x, %x) : (i32, i32) -> i32
+    }
+  }
+  return
+}
+
+// UNROLL-BY-3-LABEL: mlfunc @unroll_non_unit_stride_cleanup() {
+mlfunc @unroll_non_unit_stride_cleanup() {
+  // UNROLL-BY-3: for %i0 = 1 to 100 {
+  for %i = 1 to 100 {
+    // UNROLL-BY-3: for [[L1:%i[0-9]+]] = 2 to 12 step 15 {
+    // UNROLL-BY-3-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-3-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+    // UNROLL-BY-3-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-3-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+    // UNROLL-BY-3-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+    // UNROLL-BY-3-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-3-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+    // UNROLL-BY-3-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+    // UNROLL-BY-3-NEXT: }
+    // UNROLL-BY-3-NEXT: for [[L2:%i[0-9]+]] = 17 to 20 step 5 {
+    // UNROLL-BY-3-NEXT: %8 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-3-NEXT: %9 = "addi32"(%8, %8) : (i32, i32) -> i32
+    // UNROLL-BY-3-NEXT: }
+    for %j = 2 to 20 step 5 {
+      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+      %y = "addi32"(%x, %x) : (i32, i32) -> i32
+    }
+  }
+  return
+}
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index 2c6b76b..252e1f2 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -53,8 +53,7 @@
 
 enum Passes {
   ConvertToCFG,
-  UnrollInnermostLoops,
-  UnrollShortLoops,
+  LoopUnroll,
   TFRaiseControlFlow,
 };
 
@@ -62,10 +61,7 @@
     "", cl::desc("Compiler passes to run"),
     cl::values(clEnumValN(ConvertToCFG, "convert-to-cfg",
                           "Convert all ML functions in the module to CFG ones"),
-               clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
-                          "Unroll innermost loops"),
-               clEnumValN(UnrollShortLoops, "unroll-short-loops",
-                          "Unroll loops of trip count <= 2"),
+               clEnumValN(LoopUnroll, "loop-unroll", "Unroll loops"),
                clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
                           "Dynamic TensorFlow Switch/Match nodes to a CFG")));
 
@@ -112,11 +108,8 @@
     case ConvertToCFG:
       pass = createConvertToCFGPass();
       break;
-    case UnrollInnermostLoops:
-      pass = createLoopUnrollPass();
-      break;
-    case UnrollShortLoops:
-      pass = createLoopUnrollPass(2);
+    case LoopUnroll:
+      pass = createLoopUnrollPass(-1, -1);
       break;
     case TFRaiseControlFlow:
       pass = createRaiseTFControlFlowPass();