Add utility to promote single iteration loops. Add methods for getting constant
loop counts. Improve / refactor loop unroll / loop unroll and jam.

- add utility to promote single iteration loops, i.e., move the loop body into
  the containing block and remove the loop.
- use this utility to promote single iteration loops after unroll/unroll-and-jam
- use loopUnrollByFactor for loopUnrollFull and remove most of the latter.
- add methods for getting constant loop trip count

PiperOrigin-RevId: 212039569
diff --git a/include/mlir/IR/Statements.h b/include/mlir/IR/Statements.h
index 9570568..513bc08 100644
--- a/include/mlir/IR/Statements.h
+++ b/include/mlir/IR/Statements.h
@@ -266,6 +266,9 @@
   /// Sets the upper bound to the given constant value.
   void setConstantUpperBound(int64_t value);
 
+  /// Returns the trip count if it's a constant.
+  Optional<uint64_t> getConstantTripCount() const;
+
   //===--------------------------------------------------------------------===//
   // Operands
   //===--------------------------------------------------------------------===//
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index 30ece71..6454da7 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -25,7 +25,9 @@
 
 namespace mlir {
 
+class ForStmt;
 class FunctionPass;
+class MLFunction;
 class MLFunctionPass;
 class ModulePass;
 
@@ -47,6 +49,14 @@
 /// generated CFG functions.
 ModulePass *createConvertToCFGPass();
 
+/// Promotes the loop body of a ForStmt to its containing block if the ForStmt
+/// is known to have a single iteration, returning true. Returns false otherwise.
+bool promoteIfSingleIteration(ForStmt *forStmt);
+
+/// Promotes all single iteration ForStmt's in the MLFunction, i.e., moves
+/// their body into the containing StmtBlock.
+void promoteSingleIterationLoops(MLFunction *f);
+
 } // end namespace mlir
 
 #endif // MLIR_TRANSFORMS_LOOP_H
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 65ca340..1f7604f 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -343,6 +343,24 @@
   return ubMap->getSingleConstantValue();
 }
 
+Optional<uint64_t> ForStmt::getConstantTripCount() const {
+  // None for non-constant bounds or a zero step (would divide by zero).
+  // TODO(bondhugula): handle arbitrary lower/upper bounds.
+  int64_t step = getStep();
+  if (!hasConstantBounds() || step == 0)
+    return None;
+  int64_t lb = getConstantLowerBound();
+  int64_t ub = getConstantUpperBound();
+  // Bounds are inclusive. Measure the span in the direction of travel so that
+  // negative steps are counted correctly; a non-positive span never iterates.
+  int64_t span = step > 0 ? ub - lb + 1 : lb - ub + 1;
+  if (span <= 0)
+    return 0;
+  int64_t stride = step > 0 ? step : -step;
+  // Ceiling division: a trailing partial stride still executes once.
+  return static_cast<uint64_t>((span + stride - 1) / stride);
+}
+
 void ForStmt::setConstantLowerBound(int64_t value) {
   MLIRContext *context = getContext();
   auto *expr = AffineConstantExpr::get(value, context);
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 382b830..e663ce0 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -33,22 +33,21 @@
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/CommandLine.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
-using namespace llvm;
 
 // Loop unrolling factor.
 static llvm::cl::opt<unsigned>
-    clUnrollFactor("unroll-factor", cl::Hidden,
-                   cl::desc("Use this unroll factor for all loops"));
+    clUnrollFactor("unroll-factor", llvm::cl::Hidden,
+                   llvm::cl::desc("Use this unroll factor for all loops"));
 
-static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
-                                        cl::desc("Fully unroll loops"));
+static llvm::cl::opt<bool> clUnrollFull("unroll-full", llvm::cl::Hidden,
+                                        llvm::cl::desc("Fully unroll loops"));
 
 static llvm::cl::opt<unsigned> clUnrollFullThreshold(
-    "unroll-full-threshold", cl::Hidden,
-    cl::desc("Unroll all loops with trip count less than or equal to this"));
+    "unroll-full-threshold", llvm::cl::Hidden,
+    llvm::cl::desc(
+        "Unroll all loops with trip count less than or equal to this"));
 
 namespace {
 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
@@ -67,7 +66,7 @@
   /// Unroll this for stmt. Returns false if nothing was done.
   bool runOnForStmt(ForStmt *forStmt);
   bool loopUnrollFull(ForStmt *forStmt);
-  bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
+  bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
 };
 } // end anonymous namespace
 
@@ -129,13 +128,8 @@
     ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
 
     void visitForStmt(ForStmt *forStmt) {
-      if (!forStmt->hasConstantBounds())
-        return;
-      auto lb = forStmt->getConstantLowerBound();
-      auto ub = forStmt->getConstantUpperBound();
-      auto step = forStmt->getStep();
-
-      if ((ub - lb) / step + 1 <= minTripCount)
+      Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+      if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
         loops.push_back(forStmt);
     }
   };
@@ -180,43 +174,14 @@
 // Unrolls this loop completely. Fails assertion if loop bounds are
 // non-constant.
 bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
-  auto lb = forStmt->getConstantLowerBound();
-  auto ub = forStmt->getConstantUpperBound();
-  auto step = forStmt->getStep();
-
-  // Builder to add constants needed for the unrolled iterator.
-  auto *mlFunc = forStmt->findFunction();
-  MLFuncBuilder funcTopBuilder(&mlFunc->front());
-
-  // Builder to insert the unrolled bodies.  We insert right after the
-  // ForStmt we're unrolling.
-  MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
-
-  // Unroll the contents of 'forStmt'.
-  for (int64_t i = lb; i <= ub; i += step) {
-    DenseMap<const MLValue *, MLValue *> operandMapping;
-
-    // If the induction variable is used, create a constant for this unrolled
-    // value and add an operand mapping for it.
-    if (!forStmt->use_empty()) {
-      auto *ivConst =
-          funcTopBuilder.create<ConstantAffineIntOp>(forStmt->getLoc(), i)
-              ->getResult();
-      operandMapping[forStmt] = cast<MLValue>(ivConst);
-    }
-
-    // Clone the body of the loop.
-    for (auto &childStmt : *forStmt) {
-      builder.clone(childStmt, operandMapping);
-    }
-  }
-  // Erase the original 'for' stmt from the block.
-  forStmt->eraseFromBlock();
-  return true;
+  Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+  if (!tripCount.hasValue() || tripCount.getValue() == 0)
+    return false; // 0 is not a valid unroll factor (asserted >= 1).
+  return loopUnrollByFactor(forStmt, tripCount.getValue());
 }
 
 /// Unrolls this loop by the specified unroll factor.
-bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
   assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
 
   if (unrollFactor == 1 || forStmt->getStatements().empty())
@@ -225,11 +190,9 @@
   if (!forStmt->hasConstantBounds())
     return false;
 
-  auto lb = forStmt->getConstantLowerBound();
-  auto ub = forStmt->getConstantUpperBound();
-  auto step = forStmt->getStep();
-
-  int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
+  int64_t lb = forStmt->getConstantLowerBound();
+  int64_t step = forStmt->getStep();
+  uint64_t tripCount = forStmt->getConstantTripCount().getValue();
 
   // If the trip count is lower than the unroll factor, no unrolled body.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -243,6 +206,8 @@
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
     cleanupForStmt->setConstantLowerBound(
         lb + (tripCount - tripCount % unrollFactor) * step);
+    // Promote the loop body up if this has turned into a single iteration loop.
+    promoteIfSingleIteration(cleanupForStmt);
   }
 
   // Builder to insert unrolled bodies right after the last statement in the
@@ -281,5 +246,9 @@
     // Clone the last statement in the original body.
     builder.clone(*srcBlockEnd, operandMapping);
   }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  promoteIfSingleIteration(forStmt);
+
   return true;
 }
diff --git a/lib/Transforms/LoopUnrollJam.cpp b/lib/Transforms/LoopUnrollJam.cpp
index eeab87c..6fb6134 100644
--- a/lib/Transforms/LoopUnrollJam.cpp
+++ b/lib/Transforms/LoopUnrollJam.cpp
@@ -23,7 +23,7 @@
 // bounds of the loops inner to the loop being unroll-jammed do not depend on
 // the latter.
 //
-// Before      After unroll-jam of i by factor 2:
+// Before      After unroll and jam of i by factor 2:
 //
 //             for i, step = 2
 // for i         S1(i);
@@ -54,7 +54,6 @@
 #include "llvm/Support/CommandLine.h"
 
 using namespace mlir;
-using namespace llvm::cl;
 
 // Loop unroll jam factor.
 static llvm::cl::opt<unsigned>
@@ -74,7 +73,7 @@
 
   void runOnMLFunction(MLFunction *f) override;
   bool runOnForStmt(ForStmt *forStmt);
-  bool loopUnrollJamByFactor(ForStmt *forStmt, unsigned unrollJamFactor);
+  bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
 };
 } // end anonymous namespace
 
@@ -110,15 +109,7 @@
 
 /// Unrolls and jams this loop by the specified factor.
 bool LoopUnrollAndJam::loopUnrollJamByFactor(ForStmt *forStmt,
-                                             unsigned unrollJamFactor) {
-  assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
-
-  if (unrollJamFactor == 1 || forStmt->getStatements().empty())
-    return false;
-
-  if (!forStmt->hasConstantBounds())
-    return false;
-
+                                             uint64_t unrollJamFactor) {
   // Gathers all maximal sub-blocks of statements that do not themselves include
   // a for stmt (a statement could have a descendant for stmt though in its
   // tree).
@@ -146,12 +137,17 @@
     }
   };
 
-  auto lb = forStmt->getConstantLowerBound();
-  auto ub = forStmt->getConstantUpperBound();
-  auto step = forStmt->getStep();
+  assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
 
-  int64_t tripCount = (ub - lb + 1) % step == 0 ? (ub - lb + 1) / step
-                                                : (ub - lb + 1) / step + 1;
+  if (unrollJamFactor == 1 || forStmt->getStatements().empty())
+    return false;
+
+  if (!forStmt->hasConstantBounds())
+    return false;
+
+  int64_t lb = forStmt->getConstantLowerBound();
+  int64_t step = forStmt->getStep();
+  uint64_t tripCount = forStmt->getConstantTripCount().getValue();
 
   // If the trip count is lower than the unroll jam factor, no unrolled body.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -172,6 +168,9 @@
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
     cleanupForStmt->setConstantLowerBound(
         lb + (tripCount - tripCount % unrollJamFactor) * step);
+
+    // Promote the loop body up if this has turned into a single iteration loop.
+    promoteIfSingleIteration(cleanupForStmt);
   }
 
   MLFuncBuilder b(forStmt);
@@ -210,5 +209,9 @@
       builder.clone(*subBlock.second, operandMapping);
     }
   }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  promoteIfSingleIteration(forStmt);
+
   return true;
 }
diff --git a/lib/Transforms/LoopUtils.cpp b/lib/Transforms/LoopUtils.cpp
new file mode 100644
index 0000000..cc2dc40
--- /dev/null
+++ b/lib/Transforms/LoopUtils.cpp
@@ -0,0 +1,63 @@
+//===- LoopUtils.cpp - Misc loop utilities for simplification -------------===//
+//
+// Copyright 2019 The MLIR Authors.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+// =============================================================================
+//
+// This file implements miscellaneous loop simplification routines.
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/StandardOps.h"
+#include "mlir/IR/Statements.h"
+#include "mlir/IR/StmtVisitor.h"
+#include "mlir/Transforms/Passes.h"
+
+/// Promotes the loop body of a forStmt to its containing block if the forStmt
+/// has a constant trip count of exactly one. Returns true if the loop was
+/// promoted and erased, and false if it was left untouched.
+bool mlir::promoteIfSingleIteration(ForStmt *forStmt) {
+  Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+  if (!tripCount.hasValue() || tripCount.getValue() != 1 ||
+      !forStmt->hasConstantLowerBound())
+    return false;
+
+  // Materialize the single IV value only if the IV is actually used, so that
+  // no dead constant is left behind at the front of the function.
+  if (!forStmt->use_empty()) {
+    MLFuncBuilder topBuilder(&forStmt->findFunction()->front());
+    auto constOp = topBuilder.create<ConstantAffineIntOp>(
+        forStmt->getLoc(), forStmt->getConstantLowerBound());
+    forStmt->replaceAllUsesWith(constOp->getResult());
+  }
+  // Splice the body in just before the loop, then erase the empty loop.
+  forStmt->getBlock()->getStatements().splice(StmtBlock::iterator(forStmt),
+                                              forStmt->getStatements());
+  forStmt->eraseFromBlock();
+  return true;
+}
+
+/// Walks the MLFunction and promotes every ForStmt that has a single
+/// iteration, i.e., moves its body into the containing StmtBlock.
+void mlir::promoteSingleIterationLoops(MLFunction *f) {
+  // Visitor that hands each ForStmt it visits to promoteIfSingleIteration().
+  struct SingleIterationPromoter : public StmtWalker<SingleIterationPromoter> {
+    void visitForStmt(ForStmt *loop) { promoteIfSingleIteration(loop); }
+  };
+
+  // Run the visitor over the whole function via walkPostOrder.
+  SingleIterationPromoter promoter;
+  promoter.walkPostOrder(f);
+}
diff --git a/test/Transforms/unroll-jam.mlir b/test/Transforms/unroll-jam.mlir
index 3a0d0e1..19cd1df 100644
--- a/test/Transforms/unroll-jam.mlir
+++ b/test/Transforms/unroll-jam.mlir
@@ -4,7 +4,8 @@
 
 // CHECK-LABEL: mlfunc @unroll_jam_imperfect_nest() {
 mlfunc @unroll_jam_imperfect_nest() {
-  // CHECK: for %i0 = 0 to 99 step 2 {
+  // CHECK: %c100 = constant 100 : affineint
+  // CHECK-NEXT: for %i0 = 0 to 99 step 2 {
   for %i = 0 to 100 {
     // CHECK: %0 = "addi32"(%i0, %i0) : (affineint, affineint) -> i32
     // CHECK-NEXT: %1 = affine_apply #map0(%i0)
@@ -24,14 +25,12 @@
     // CHECK-NEXT: %10 = "addi32"(%9, %9) : (affineint, affineint) -> i32
     %w = "addi32"(%i, %i) : (affineint, affineint) -> i32
   } // CHECK }
-  // cleanup loop.
-  // CHECK: for %i2 = 100 to 100 {
-    // CHECK-NEXT: %11 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
-    // CHECK-NEXT: for %i3 = 0 to 17 {
-      // CHECK-NEXT: %12 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
-      // CHECK-NEXT: %13 = "addi32"(%12, %12) : (i32, i32) -> i32
-    // CHECK-NEXT: }
-    // CHECK-NEXT: %14 = "addi32"(%i2, %i2) : (affineint, affineint) -> i32
+  // cleanup loop (single iteration)
+  // CHECK: %11 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
+  // CHECK-NEXT: for %i2 = 0 to 17 {
+    // CHECK-NEXT: %12 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
+    // CHECK-NEXT: %13 = "addi32"(%12, %12) : (i32, i32) -> i32
   // CHECK-NEXT: }
+  // CHECK-NEXT: %14 = "addi32"(%c100, %c100) : (affineint, affineint) -> i32
   return
 }
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index 71536c1..10ee717 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,9 +1,44 @@
 // RUN: mlir-opt %s -o - -loop-unroll -unroll-full | FileCheck %s
 // RUN: mlir-opt %s -o - -loop-unroll -unroll-full -unroll-full-threshold=2 | FileCheck %s --check-prefix SHORT
 // RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=4 | FileCheck %s --check-prefix UNROLL-BY-4
-// RUN: mlir-opt %s -o - -loop-unroll -unroll-factor=3 | FileCheck %s --check-prefix UNROLL-BY-3
 
 // CHECK: #map0 = (d0) -> (d0 + 1)
+// CHECK: #map1 = (d0) -> (d0 + 2)
+// CHECK: #map2 = (d0) -> (d0 + 3)
+// CHECK: #map3 = (d0) -> (d0 + 4)
+// CHECK: #map4 = (d0, d1) -> (d0 + 1, d1 + 2)
+// CHECK: #map5 = (d0, d1) -> (d0 + 3, d1 + 4)
+// CHECK: #map6 = (d0)[s0] -> (d0 + s0 + 1)
+// CHECK: #map7 = (d0) -> (d0 + 5)
+// CHECK: #map8 = (d0) -> (d0 + 6)
+// CHECK: #map9 = (d0) -> (d0 + 7)
+// CHECK: #map10 = (d0, d1) -> (d0 * 16 + d1)
+// CHECK: #map11 = (d0) -> (d0 + 8)
+// CHECK: #map12 = (d0) -> (d0 + 9)
+// CHECK: #map13 = (d0) -> (d0 + 10)
+// CHECK: #map14 = (d0) -> (d0 + 15)
+// CHECK: #map15 = (d0) -> (d0 + 20)
+// CHECK: #map16 = (d0) -> (d0 + 25)
+// CHECK: #map17 = (d0) -> (d0 + 30)
+// CHECK: #map18 = (d0) -> (d0 + 35)
+
+// SHORT: #map0 = (d0) -> (d0 + 1)
+// SHORT: #map1 = (d0) -> (d0 + 2)
+// SHORT: #map2 = (d0, d1) -> (d0 + 1, d1 + 2)
+// SHORT: #map3 = (d0, d1) -> (d0 + 3, d1 + 4)
+// SHORT: #map4 = (d0)[s0] -> (d0 + s0 + 1)
+// SHORT: #map5 = (d0, d1) -> (d0 * 16 + d1)
+
+// UNROLL-BY-4: #map0 = (d0) -> (d0 + 1)
+// UNROLL-BY-4: #map1 = (d0) -> (d0 + 2)
+// UNROLL-BY-4: #map2 = (d0) -> (d0 + 3)
+// UNROLL-BY-4: #map3 = (d0, d1) -> (d0 + 1, d1 + 2)
+// UNROLL-BY-4: #map4 = (d0, d1) -> (d0 + 3, d1 + 4)
+// UNROLL-BY-4: #map5 = (d0)[s0] -> (d0 + s0 + 1)
+// UNROLL-BY-4: #map6 = (d0, d1) -> (d0 * 16 + d1)
+// UNROLL-BY-4: #map7 = (d0) -> (d0 + 5)
+// UNROLL-BY-4: #map8 = (d0) -> (d0 + 10)
+// UNROLL-BY-4: #map9 = (d0) -> (d0 + 15)
 
 // CHECK-LABEL: mlfunc @loop_nest_simplest() {
 mlfunc @loop_nest_simplest() {
@@ -23,15 +58,15 @@
 // CHECK-LABEL: mlfunc @loop_nest_simple_iv_use() {
 mlfunc @loop_nest_simple_iv_use() {
   // CHECK: %c1 = constant 1 : affineint
-  // CHECK-NEXT: %c2 = constant 2 : affineint
-  // CHECK-NEXT: %c3 = constant 3 : affineint
-  // CHECK-NEXT: %c4 = constant 4 : affineint
   // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
-    // CHECK:       %0 = "addi32"(%c1, %c1) : (affineint, affineint) -> i32
-    // CHECK-NEXT:  %1 = "addi32"(%c2, %c2) : (affineint, affineint) -> i32
-    // CHECK-NEXT:  %2 = "addi32"(%c3, %c3) : (affineint, affineint) -> i32
-    // CHECK-NEXT:  %3 = "addi32"(%c4, %c4) : (affineint, affineint) -> i32
+    // CHECK: %0 = "addi32"(%c1, %c1) : (affineint, affineint) -> i32
+    // CHECK: %1 = affine_apply #map0(%c1)
+    // CHECK-NEXT:  %2 = "addi32"(%1, %1) : (affineint, affineint) -> i32
+    // CHECK: %3 = affine_apply #map1(%c1)
+    // CHECK-NEXT:  %4 = "addi32"(%3, %3) : (affineint, affineint) -> i32
+    // CHECK: %5 = affine_apply #map2(%c1)
+    // CHECK-NEXT:  %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
     for %j = 1 to 4 {
       %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
     }
@@ -43,21 +78,21 @@
 // CHECK-LABEL: mlfunc @loop_nest_body_def_use() {
 mlfunc @loop_nest_body_def_use() {
   // CHECK: %c0 = constant 0 : affineint
-  // CHECK-NEXT: %c1 = constant 1 : affineint
-  // CHECK-NEXT: %c2 = constant 2 : affineint
-  // CHECK-NEXT: %c3 = constant 3 : affineint
   // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
     // CHECK: %c0_0 = constant 0 : affineint
     %c0 = constant 0 : affineint
     // CHECK:      %0 = affine_apply #map0(%c0)
     // CHECK-NEXT: %1 = "addi32"(%0, %c0_0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %2 = affine_apply #map0(%c1)
-    // CHECK-NEXT: %3 = "addi32"(%2, %c0_0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %4 = affine_apply #map0(%c2)
-    // CHECK-NEXT: %5 = "addi32"(%4, %c0_0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %6 = affine_apply #map0(%c3)
+    // CHECK-NEXT: %2 = affine_apply #map0(%c0)
+    // CHECK-NEXT: %3 = affine_apply #map0(%2)
+    // CHECK-NEXT: %4 = "addi32"(%3, %c0_0) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %5 = affine_apply #map1(%c0)
+    // CHECK-NEXT: %6 = affine_apply #map0(%5)
     // CHECK-NEXT: %7 = "addi32"(%6, %c0_0) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %8 = affine_apply #map2(%c0)
+    // CHECK-NEXT: %9 = affine_apply #map0(%8)
+    // CHECK-NEXT: %10 = "addi32"(%9, %c0_0) : (affineint, affineint) -> affineint
     for %j = 0 to 3 {
       %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
@@ -70,27 +105,27 @@
 // CHECK-LABEL: mlfunc @loop_nest_strided() {
 mlfunc @loop_nest_strided() {
   // CHECK: %c3 = constant 3 : affineint
-  // CHECK-NEXT: %c5 = constant 5 : affineint
-  // CHECK-NEXT: %c7 = constant 7 : affineint
   // CHECK-NEXT: %c3_0 = constant 3 : affineint
-  // CHECK-NEXT: %c5_1 = constant 5 : affineint
   // CHECK-NEXT: for %i0 = 1 to 100 {
   for %i = 1 to 100 {
     // CHECK:      %0 = affine_apply #map0(%c3_0)
     // CHECK-NEXT: %1 = "addi32"(%0, %0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %2 = affine_apply #map0(%c5_1)
-    // CHECK-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %2 = affine_apply #map1(%c3_0)
+    // CHECK-NEXT: %3 = affine_apply #map0(%2)
+    // CHECK-NEXT: %4 = "addi32"(%3, %3) : (affineint, affineint) -> affineint
     for %j = 3 to 6 step 2 {
       %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
       %y = "addi32"(%x, %x) : (affineint, affineint) -> affineint
     }
-    // CHECK:      %4 = affine_apply #map0(%c3)
-    // CHECK-NEXT: %5 = "addi32"(%4, %4) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %6 = affine_apply #map0(%c5)
-    // CHECK-NEXT: %7 = "addi32"(%6, %6) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %8 = affine_apply #map0(%c7)
+    // CHECK:      %5 = affine_apply #map0(%c3)
+    // CHECK-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %7 = affine_apply #map1(%c3)
+    // CHECK-NEXT: %8 = affine_apply #map0(%7)
     // CHECK-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %10 = affine_apply #map3(%c3)
+    // CHECK-NEXT: %11 = affine_apply #map0(%10)
+    // CHECK-NEXT: %12 = "addi32"(%11, %11) : (affineint, affineint) -> affineint
     for %k = 3 to 7 step 2 {
       %z = "affine_apply" (%k) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
@@ -103,16 +138,17 @@
 // CHECK-LABEL: mlfunc @loop_nest_multiple_results() {
 mlfunc @loop_nest_multiple_results() {
   // CHECK: %c0 = constant 0 : affineint
-  // CHECK-NEXT: %c1 = constant 1 : affineint
+  // CHECK-NEXT: for %i0 = 1 to 100 {
   for %i = 1 to 100 {
-    // CHECK: %0 = affine_apply #map1(%i0, %c0)
+    // CHECK: %0 = affine_apply #map4(%i0, %c0)
     // CHECK-NEXT: %1 = "addi32"(%0#0, %0#1) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %2 = affine_apply #map2(%i0, %c0)
+    // CHECK-NEXT: %2 = affine_apply #map5(%i0, %c0)
     // CHECK-NEXT: %3 = "fma"(%2#0, %2#1, %0#0) : (affineint, affineint, affineint) -> (affineint, affineint)
-    // CHECK-NEXT: %4 = affine_apply #map1(%i0, %c1)
-    // CHECK-NEXT: %5 = "addi32"(%4#0, %4#1) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %6 = affine_apply #map2(%i0, %c1)
-    // CHECK-NEXT: %7 = "fma"(%6#0, %6#1, %4#0) : (affineint, affineint, affineint) -> (affineint, affineint)
+    // CHECK-NEXT: %4 = affine_apply #map0(%c0)
+    // CHECK-NEXT: %5 = affine_apply #map4(%i0, %4)
+    // CHECK-NEXT: %6 = "addi32"(%5#0, %5#1) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %7 = affine_apply #map5(%i0, %4)
+    // CHECK-NEXT: %8 = "fma"(%7#0, %7#1, %5#0) : (affineint, affineint, affineint) -> (affineint, affineint)
     for %j = 0 to 1 step 1 {
       %x = "affine_apply" (%i, %j) { map: (d0, d1) -> (d0 + 1, d1 + 2) } :
         (affineint, affineint) -> (affineint, affineint)
@@ -130,9 +166,6 @@
 // CHECK-LABEL: mlfunc @loop_nest_seq_imperfect(%arg0 : memref<128x128xf32>) {
 mlfunc @loop_nest_seq_imperfect(%a : memref<128x128xf32>) {
   // CHECK: %c1 = constant 1 : affineint
-  // CHECK-NEXT: %c2 = constant 2 : affineint
-  // CHECK-NEXT: %c3 = constant 3 : affineint
-  // CHECK-NEXT: %c4 = constant 4 : affineint
   // CHECK-NEXT: %c128 = constant 128 : affineint
   %c128 = constant 128 : affineint
   // CHECK: for %i0 = 1 to 100 {
@@ -142,24 +175,27 @@
     // CHECK: %1 = affine_apply #map0(%c1)
     // CHECK-NEXT: %2 = "vmulf"(%c1, %1) : (affineint, affineint) -> affineint
     // CHECK-NEXT: %3 = "vaddf"(%2, %2) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %4 = affine_apply #map0(%c2)
-    // CHECK-NEXT: %5 = "vmulf"(%c2, %4) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %6 = "vaddf"(%5, %5) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %7 = affine_apply #map0(%c3)
-    // CHECK-NEXT: %8 = "vmulf"(%c3, %7) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %9 = "vaddf"(%8, %8) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %10 = affine_apply #map0(%c4)
-    // CHECK-NEXT: %11 = "vmulf"(%c4, %10) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %12 = "vaddf"(%11, %11) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %4 = affine_apply #map0(%c1)
+    // CHECK-NEXT: %5 = affine_apply #map0(%4)
+    // CHECK-NEXT: %6 = "vmulf"(%4, %5) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %7 = "vaddf"(%6, %6) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %8 = affine_apply #map1(%c1)
+    // CHECK-NEXT: %9 = affine_apply #map0(%8)
+    // CHECK-NEXT: %10 = "vmulf"(%8, %9) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %11 = "vaddf"(%10, %10) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %12 = affine_apply #map2(%c1)
+    // CHECK-NEXT: %13 = affine_apply #map0(%12)
+    // CHECK-NEXT: %14 = "vmulf"(%12, %13) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %15 = "vaddf"(%14, %14) : (affineint, affineint) -> affineint
     for %j = 1 to 4 {
       %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
        %y = "vmulf"(%j, %x) : (affineint, affineint) -> affineint
        %z = "vaddf"(%y, %y) : (affineint, affineint) -> affineint
     }
-    // CHECK: %13 = "scale"(%c128, %i0) : (affineint, affineint) -> affineint
+    // CHECK: %16 = "scale"(%c128, %i0) : (affineint, affineint) -> affineint
     %addr = "scale"(%c128, %i) : (affineint, affineint) -> affineint
-    // CHECK: "vst"(%13, %i0) : (affineint, affineint) -> ()
+    // CHECK: "vst"(%16, %i0) : (affineint, affineint) -> ()
     "vst"(%addr, %i) : (affineint, affineint) -> ()
   }       // CHECK }
   return  // CHECK:  return
@@ -168,21 +204,18 @@
 // CHECK-LABEL: mlfunc @loop_nest_seq_multiple() {
 mlfunc @loop_nest_seq_multiple() {
   // CHECK: %c1 = constant 1 : affineint
-  // CHECK-NEXT: %c2 = constant 2 : affineint
-  // CHECK-NEXT: %c3 = constant 3 : affineint
-  // CHECK-NEXT: %c4 = constant 4 : affineint
   // CHECK-NEXT: %c0 = constant 0 : affineint
-  // CHECK-NEXT: %c1_0 = constant 1 : affineint
-  // CHECK-NEXT: %c2_1 = constant 2 : affineint
-  // CHECK-NEXT: %c3_2 = constant 3 : affineint
   // CHECK-NEXT: %0 = affine_apply #map0(%c0)
   // CHECK-NEXT: "mul"(%0, %0) : (affineint, affineint) -> ()
-  // CHECK-NEXT: %1 = affine_apply #map0(%c1_0)
-  // CHECK-NEXT: "mul"(%1, %1) : (affineint, affineint) -> ()
-  // CHECK-NEXT: %2 = affine_apply #map0(%c2_1)
+  // CHECK-NEXT: %1 = affine_apply #map0(%c0)
+  // CHECK-NEXT: %2 = affine_apply #map0(%1)
   // CHECK-NEXT: "mul"(%2, %2) : (affineint, affineint) -> ()
-  // CHECK-NEXT: %3 = affine_apply #map0(%c3_2)
-  // CHECK-NEXT: "mul"(%3, %3) : (affineint, affineint) -> ()
+  // CHECK-NEXT: %3 = affine_apply #map1(%c0)
+  // CHECK-NEXT: %4 = affine_apply #map0(%3)
+  // CHECK-NEXT: "mul"(%4, %4) : (affineint, affineint) -> ()
+  // CHECK-NEXT: %5 = affine_apply #map2(%c0)
+  // CHECK-NEXT: %6 = affine_apply #map0(%5)
+  // CHECK-NEXT: "mul"(%6, %6) : (affineint, affineint) -> ()
   for %j = 0 to 3 {
     %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
       (affineint) -> (affineint)
@@ -193,14 +226,17 @@
   %k = "constant"(){value: 99} : () -> affineint
   // CHECK: for %i0 = 1 to 100 step 2 {
   for %m = 1 to 100 step 2 {
-    // CHECK: %4 = affine_apply #map0(%c1)
-    // CHECK-NEXT: %5 = affine_apply #map3(%c1)[%c99]
-    // CHECK-NEXT: %6 = affine_apply #map0(%c2)
-    // CHECK-NEXT: %7 = affine_apply #map3(%c2)[%c99]
-    // CHECK-NEXT: %8 = affine_apply #map0(%c3)
-    // CHECK-NEXT: %9 = affine_apply #map3(%c3)[%c99]
-    // CHECK-NEXT: %10 = affine_apply #map0(%c4)
-    // CHECK-NEXT: %11 = affine_apply #map3(%c4)[%c99]
+    // CHECK: %7 = affine_apply #map0(%c1)
+    // CHECK-NEXT: %8 = affine_apply #map6(%c1)[%c99]
+    // CHECK-NEXT: %9 = affine_apply #map0(%c1)
+    // CHECK-NEXT: %10 = affine_apply #map0(%9)
+    // CHECK-NEXT: %11 = affine_apply #map6(%9)[%c99]
+    // CHECK-NEXT: %12 = affine_apply #map1(%c1)
+    // CHECK-NEXT: %13 = affine_apply #map0(%12)
+    // CHECK-NEXT: %14 = affine_apply #map6(%12)[%c99]
+    // CHECK-NEXT: %15 = affine_apply #map2(%c1)
+    // CHECK-NEXT: %16 = affine_apply #map0(%15)
+    // CHECK-NEXT: %17 = affine_apply #map6(%15)[%c99]
     for %n = 1 to 4 {
       %y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
@@ -340,28 +376,53 @@
   return
 }
 
-// UNROLL-BY-3-LABEL: mlfunc @unroll_non_unit_stride_cleanup() {
+// UNROLL-BY-4-LABEL: mlfunc @unroll_non_unit_stride_cleanup() {
 mlfunc @unroll_non_unit_stride_cleanup() {
-  // UNROLL-BY-3: for %i0 = 1 to 100 {
+  // UNROLL-BY-4: for %i0 = 1 to 100 {
   for %i = 1 to 100 {
-    // UNROLL-BY-3: for [[L1:%i[0-9]+]] = 2 to 12 step 15 {
-    // UNROLL-BY-3-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
-    // UNROLL-BY-3-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
-    // UNROLL-BY-3-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
-    // UNROLL-BY-3-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
-    // UNROLL-BY-3-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
-    // UNROLL-BY-3-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
-    // UNROLL-BY-3-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
-    // UNROLL-BY-3-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
-    // UNROLL-BY-3-NEXT: }
-    // UNROLL-BY-3-NEXT: for [[L2:%i[0-9]+]] = 17 to 20 step 5 {
-    // UNROLL-BY-3-NEXT: %8 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
-    // UNROLL-BY-3-NEXT: %9 = "addi32"(%8, %8) : (i32, i32) -> i32
-    // UNROLL-BY-3-NEXT: }
-    for %j = 2 to 20 step 5 {
+    // UNROLL-BY-4: for [[L1:%i[0-9]+]] = 2 to 37 step 20 {
+    // UNROLL-BY-4-NEXT: %0 = "addi32"([[L1]], [[L1]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %1 = "addi32"(%0, %0) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %2 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %5 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %7 = "addi32"(%6, %6) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: %8 = affine_apply #map{{[0-9]+}}([[L1]])
+    // UNROLL-BY-4-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %10 = "addi32"(%9, %9) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: }
+    // UNROLL-BY-4-NEXT: for [[L2:%i[0-9]+]] = 42 to 48 step 5 {
+    // UNROLL-BY-4-NEXT: %11 = "addi32"([[L2]], [[L2]]) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %12 = "addi32"(%11, %11) : (i32, i32) -> i32
+    // UNROLL-BY-4-NEXT: }
+    for %j = 2 to 48 step 5 {
       %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
       %y = "addi32"(%x, %x) : (i32, i32) -> i32
     }
   }
   return
 }
+
+// Both the unrolled loop and the cleanup loop are single iteration loops.
+mlfunc @loop_nest_single_iteration_after_unroll(%N: affineint) {
+  // UNROLL-BY-4: %c0 = constant 0 : affineint
+  // UNROLL-BY-4: %c4 = constant 4 : affineint
+  // UNROLL-BY-4: for %i0 = 1 to %arg0 {
+  for %i = 1 to %N {
+    // UNROLL-BY-4: %0 = "addi32"(%c0, %c0) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %1 = affine_apply #map0(%c0)
+    // UNROLL-BY-4-NEXT: %2 = "addi32"(%1, %1) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %3 = affine_apply #map1(%c0)
+    // UNROLL-BY-4-NEXT: %4 = "addi32"(%3, %3) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %5 = affine_apply #map2(%c0)
+    // UNROLL-BY-4-NEXT: %6 = "addi32"(%5, %5) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NEXT: %7 = "addi32"(%c4, %c4) : (affineint, affineint) -> i32
+    // UNROLL-BY-4-NOT: for
+    for %j = 0 to 4 {
+      %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
+    } // UNROLL-BY-4-NOT: }
+  } // UNROLL-BY-4:  }
+  return
+}