Loop unrolling pass update

- fix/complete forStmt cloning for unrolling to work for outer loops
- create IV const's only when needed
- test outer loop unrolling by creating a short trip count unroll pass for
  loops with trip counts <= <parameter>
- add unrolling test cases for multiple op results, outer loop unrolling
- fix/clean up StmtWalker class while on this
- switch unroll loop iterator values from i32 to affineint

PiperOrigin-RevId: 207645967
diff --git a/include/mlir/IR/Builders.h b/include/mlir/IR/Builders.h
index b6cc934..59aca69 100644
--- a/include/mlir/IR/Builders.h
+++ b/include/mlir/IR/Builders.h
@@ -265,39 +265,16 @@
     return op;
   }
 
-  OperationStmt *clone(const OperationStmt &srcOpStmt) {
-    auto *op = srcOpStmt.clone();
-    block->getStatements().insert(insertPoint, op);
-    return op;
-  }
-
   // Create operation of specific op type at the current insertion point.
   template <typename OpTy, typename... Args>
   OpPointer<OpTy> create(Args... args) {
     return OpTy::build(this, args...);
   }
 
-  ForStmt *clone(const ForStmt &srcForStmt) {
-    auto *forStmt = srcForStmt.clone();
-    block->getStatements().insert(insertPoint, forStmt);
-    return forStmt;
-  }
-
-  IfStmt *clone(const IfStmt &srcIfStmt) {
-    auto *ifStmt = srcIfStmt.clone();
-    block->getStatements().insert(insertPoint, ifStmt);
-    return ifStmt;
-  }
-
   Statement *clone(const Statement &stmt) {
-    switch (stmt.getKind()) {
-    case Statement::Kind::Operation:
-      return clone(cast<const OperationStmt>(stmt));
-    case Statement::Kind::If:
-      return clone(cast<const IfStmt>(stmt));
-    case Statement::Kind::For:
-      return clone(cast<const ForStmt>(stmt));
-    }
+    Statement *cloneStmt = stmt.clone();
+    block->getStatements().insert(insertPoint, cloneStmt);
+    return cloneStmt;
   }
 
   // Creates for statement. When step is not specified, it is set to 1.
diff --git a/include/mlir/IR/StandardOps.h b/include/mlir/IR/StandardOps.h
index 8cfe432..f8701bf 100644
--- a/include/mlir/IR/StandardOps.h
+++ b/include/mlir/IR/StandardOps.h
@@ -158,12 +158,13 @@
 };
 
 /// This is a refinement of the "constant" op for the case where it is
-/// returning an integer value (either an IntegerType or AffineInt).
+/// returning an integer value of IntegerType.
 ///
 ///   %1 = "constant"(){value: 42}
 ///
 class ConstantIntOp : public ConstantOp {
 public:
+  /// Build a constant int op producing an integer of the specified width.
   template <class Builder>
   static OpPointer<ConstantIntOp> build(Builder *builder, int64_t value,
                                         unsigned width) {
@@ -186,6 +187,36 @@
   explicit ConstantIntOp(const Operation *state) : ConstantOp(state) {}
 };
 
+/// This is a refinement of the "constant" op for the case where it is
+/// returning an integer value of AffineInt type.
+///
+///   %1 = "constant"(){value: 99} : () -> affineint
+///
+class ConstantAffineIntOp : public ConstantOp {
+public:
+  /// Build a constant int op producing an affineint.
+  template <class Builder>
+  static OpPointer<ConstantAffineIntOp> build(Builder *builder, int64_t value) {
+    std::pair<Identifier, Attribute *> namedAttr(
+        builder->getIdentifier("value"), builder->getIntegerAttr(value));
+    auto *type = builder->getAffineIntType();
+
+    return OpPointer<ConstantAffineIntOp>(
+        ConstantAffineIntOp(builder->createOperation(
+            builder->getIdentifier("constant"), {}, type, {namedAttr})));
+  }
+
+  int64_t getValue() const {
+    return getAttrOfType<IntegerAttr>("value")->getValue();
+  }
+
+  static bool isClassFor(const Operation *op);
+
+private:
+  friend class Operation;
+  explicit ConstantAffineIntOp(const Operation *state) : ConstantOp(state) {}
+};
+
 /// The "dim" operation takes a memref or tensor operand and returns an
 /// "affineint".  It requires a single integer attribute named "index".  It
 /// returns the size of the specified dimension.  For example:
diff --git a/include/mlir/IR/Statement.h b/include/mlir/IR/Statement.h
index 2326d50..43bbd2c 100644
--- a/include/mlir/IR/Statement.h
+++ b/include/mlir/IR/Statement.h
@@ -31,6 +31,7 @@
 class StmtBlock;
 class ForStmt;
 class MLIRContext;
+class MLValue;
 
 /// Statement is a basic unit of execution within an ML function.
 /// Statements can be nested within for and if statements effectively
@@ -72,6 +73,9 @@
   void print(raw_ostream &os) const;
   void dump() const;
 
+  /// Replace all uses of 'oldVal' with 'newVal' in 'stmt'.
+  void replaceUses(MLValue *oldVal, MLValue *newVal);
+
 protected:
   Statement(Kind kind) : kind(kind) {}
   // Statements are deleted through the destroy() member because this class
diff --git a/include/mlir/IR/StmtBlock.h b/include/mlir/IR/StmtBlock.h
index a8a1f20..50463f0 100644
--- a/include/mlir/IR/StmtBlock.h
+++ b/include/mlir/IR/StmtBlock.h
@@ -24,11 +24,11 @@
 
 #include "mlir/IR/Statement.h"
 #include "mlir/Support/LLVM.h"
-#include "llvm/Support/raw_ostream.h"
 
 namespace mlir {
 class MLFunction;
 class IfStmt;
+class MLValue;
 
 /// Statement block represents an ordered list of statements, with the order
 /// being the contiguous lexical order in which the statements appear as
diff --git a/include/mlir/IR/StmtVisitor.h b/include/mlir/IR/StmtVisitor.h
index 73f17e3..3c10743 100644
--- a/include/mlir/IR/StmtVisitor.h
+++ b/include/mlir/IR/StmtVisitor.h
@@ -123,19 +123,24 @@
       walk(&(*Start++));
     }
   }
+  template <class Iterator> void walkPostOrder(Iterator Start, Iterator End) {
+    while (Start != End) {
+      walkPostOrder(&(*Start++));
+    }
+  }
 
   // Define walkers for MLFunction and all MLFunction statement kinds.
   void walk(MLFunction *f) {
     static_cast<SubClass *>(this)->visitMLFunction(f);
-    static_cast<SubClass *>(this)->walk(f->begin(), f->end());
+    walk(f->begin(), f->end());
   }
 
   void walkPostOrder(MLFunction *f) {
-    walk(f->begin(), f->end());
-    return static_cast<SubClass *>(this)->visitMLFunction(f);
+    walkPostOrder(f->begin(), f->end());
+    static_cast<SubClass *>(this)->visitMLFunction(f);
   }
 
-  void walkOpStmt(OperationStmt *opStmt) {
+  RetTy walkOpStmt(OperationStmt *opStmt) {
     return static_cast<SubClass *>(this)->visitOperationStmt(opStmt);
   }
 
@@ -204,7 +209,8 @@
 
   // When visiting a specific stmt directly during a walk, these  methods get
   // called. These are typically O(1) complexity and shouldn't be recursively
-  // processing their descendants in some way.
+  // processing their descendants in some way. When using RetTy, all of these
+  // need to be overridden.
   void visitMLFunction(MLFunction *f) {}
   void visitForStmt(ForStmt *forStmt) {}
   void visitIfStmt(IfStmt *ifStmt) {}
diff --git a/include/mlir/Transforms/Passes.h b/include/mlir/Transforms/Passes.h
index 7352ad0..3944c13 100644
--- a/include/mlir/Transforms/Passes.h
+++ b/include/mlir/Transforms/Passes.h
@@ -28,8 +28,9 @@
 class MLFunctionPass;
 class ModulePass;
 
-/// A loop unrolling pass.
+/// Loop unrolling passes.
 MLFunctionPass *createLoopUnrollPass();
+MLFunctionPass *createLoopUnrollPass(unsigned);
 
 /// Replaces all ML functions in the module with equivalent CFG functions.
 /// Function references are appropriately patched to refer to the newly
diff --git a/lib/IR/Statement.cpp b/lib/IR/Statement.cpp
index 0a76587..068e498 100644
--- a/lib/IR/Statement.cpp
+++ b/lib/IR/Statement.cpp
@@ -85,11 +85,35 @@
     llvm_unreachable("cloning for if's not implemented yet");
     return cast<IfStmt>(this)->clone();
   case Kind::For:
-    llvm_unreachable("cloning for loops not implemented yet");
     return cast<ForStmt>(this)->clone();
   }
 }
 
+/// Replaces all uses of oldVal with newVal.
+// TODO(bondhugula,clattner): do this more efficiently by walking those uses of
+// oldVal that fall within this statement.
+void Statement::replaceUses(MLValue *oldVal, MLValue *newVal) {
+  struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
+    // Value to be replaced.
+    MLValue *oldVal;
+    // Value to be replaced with.
+    MLValue *newVal;
+
+    ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
+        : oldVal(oldVal), newVal(newVal){};
+
+    void visitOperationStmt(OperationStmt *os) {
+      for (auto &operand : os->getStmtOperands()) {
+        if (operand.get() == oldVal)
+          operand.set(newVal);
+      }
+    }
+  };
+
+  ReplaceUseWalker ri(oldVal, newVal);
+  ri.walk(this);
+}
+
 //===----------------------------------------------------------------------===//
 // ilist_traits for Statement
 //===----------------------------------------------------------------------===//
@@ -233,12 +257,34 @@
       upperBound(upperBound), step(step) {}
 
 ForStmt *ForStmt::clone() const {
-  auto *stmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
-                           Statement::findFunction()->getContext());
+  auto *forStmt = new ForStmt(getLowerBound(), getUpperBound(), getStep(),
+                              Statement::findFunction()->getContext());
+
+  // Pairs of <old op stmt result whose uses need to be replaced,
+  // new result generated by the corresponding cloned op stmt>.
+  SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
   for (auto &s : getStatements()) {
-    stmt->getStatements().push_back(s.clone());
+    auto *cloneStmt = s.clone();
+    forStmt->getStatements().push_back(cloneStmt);
+    if (auto *opStmt = dyn_cast<OperationStmt>(&s)) {
+      auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+      for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
+        oldNewResultPairs.push_back(
+            std::make_pair(const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
+                           &cloneOpStmt->getStmtResult(i)));
+      }
+    }
   }
-  return stmt;
+  // Replace uses of old op results' with the newly created ones.
+  for (unsigned i = 0, e = oldNewResultPairs.size(); i < e; i++) {
+    for (auto &stmt : *forStmt) {
+      stmt.replaceUses(oldNewResultPairs[i].first, oldNewResultPairs[i].second);
+    }
+  }
+
+  // Replace uses of old loop IV with the new one.
+  forStmt->Statement::replaceUses(const_cast<ForStmt *>(this), forStmt);
+  return forStmt;
 }
 
 //===----------------------------------------------------------------------===//
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 27bb43f..eea3bf7 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -39,14 +39,22 @@
   void runOnMLFunction(MLFunction *f) override;
   void runOnForStmt(ForStmt *forStmt);
 };
+struct ShortLoopUnroll : public LoopUnroll {
+  const unsigned minTripCount;
+  void runOnMLFunction(MLFunction *f) override;
+  ShortLoopUnroll(unsigned minTripCount) : minTripCount(minTripCount) {}
+};
 } // end anonymous namespace
 
 MLFunctionPass *mlir::createLoopUnrollPass() { return new LoopUnroll(); }
 
+MLFunctionPass *mlir::createLoopUnrollPass(unsigned minTripCount) {
+  return new ShortLoopUnroll(minTripCount);
+}
+
 /// Unrolls all the innermost loops of this MLFunction.
 void LoopUnroll::runOnMLFunction(MLFunction *f) {
   // Gathers all innermost loops through a post order pruned walk.
-  // TODO: figure out the right reusable template here to better refactor code.
   class InnermostLoopGatherer : public StmtWalker<InnermostLoopGatherer, bool> {
   public:
     // Store innermost loops as we walk.
@@ -65,11 +73,6 @@
       return hasInnerLoops;
     }
 
-    // FIXME: can't use base class method for this because that in turn would
-    // need to use the derived class method above. CRTP doesn't allow it, and
-    // the compiler error resulting from it is also very misleading!
-    void walkPostOrder(MLFunction *f) { walkPostOrder(f->begin(), f->end()); }
-
     bool walkForStmtPostOrder(ForStmt *forStmt) {
       bool hasInnerLoops = walkPostOrder(forStmt->begin(), forStmt->end());
       if (!hasInnerLoops)
@@ -85,8 +88,11 @@
       return hasInnerLoops;
     }
 
-    bool walkOpStmt(OperationStmt *opStmt) { return false; }
+    bool visitOperationStmt(OperationStmt *opStmt) { return false; }
 
+    // FIXME: can't use base class method for this because that in turn would
+    // need to use the derived class method above. CRTP doesn't allow it, and
+    // the compiler error resulting from it is also misleading.
     using StmtWalker<InnermostLoopGatherer, bool>::walkPostOrder;
   };
 
@@ -97,28 +103,96 @@
     runOnForStmt(forStmt);
 }
 
-/// Replace all uses of 'oldVal' with 'newVal' in 'stmt'
-static void replaceAllStmtUses(Statement *stmt, MLValue *oldVal,
-                               MLValue *newVal) {
-  struct ReplaceUseWalker : public StmtWalker<ReplaceUseWalker> {
-    // Value to be replaced.
-    MLValue *oldVal;
-    // Value to be replaced with.
-    MLValue *newVal;
+/// Unrolls all loops with trip count <= minTripCount.
+void ShortLoopUnroll::runOnMLFunction(MLFunction *f) {
+  // Gathers all loops with trip count <= minTripCount.
+  class ShortLoopGatherer : public StmtWalker<ShortLoopGatherer> {
+  public:
+    // Store short loops as we walk.
+    std::vector<ForStmt *> loops;
+    const unsigned minTripCount;
+    ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
 
-    ReplaceUseWalker(MLValue *oldVal, MLValue *newVal)
-        : oldVal(oldVal), newVal(newVal){};
+    void visitForStmt(ForStmt *forStmt) {
+      auto lb = forStmt->getLowerBound()->getValue();
+      auto ub = forStmt->getUpperBound()->getValue();
+      auto step = forStmt->getStep()->getValue();
 
-    void visitOperationStmt(OperationStmt *os) {
-      for (auto &operand : os->getStmtOperands()) {
-        if (operand.get() == oldVal)
-          operand.set(newVal);
-      }
+      if ((ub - lb) / step + 1 <= minTripCount)
+        loops.push_back(forStmt);
     }
   };
 
-  ReplaceUseWalker ri(oldVal, newVal);
-  ri.walk(stmt);
+  ShortLoopGatherer slg(minTripCount);
+  slg.walk(f);
+  auto &loops = slg.loops;
+  for (auto *forStmt : loops)
+    runOnForStmt(forStmt);
+}
+
+/// Replace all uses of oldVal with newVal from begin to end.
+static void replaceUses(StmtBlock::iterator begin, StmtBlock::iterator end,
+                        MLValue *oldVal, MLValue *newVal) {
+  // TODO(bondhugula,clattner): do this more efficiently by walking those uses
+  // of oldVal that fall within this list of statements (instead of iterating
+  // through all statements / through all operands of operations found).
+  for (auto it = begin; it != end; it++) {
+    it->replaceUses(oldVal, newVal);
+  }
+}
+
+/// Replace all uses of oldVal with newVal.
+void replaceUses(StmtBlock *block, MLValue *oldVal, MLValue *newVal) {
+  // TODO(bondhugula,clattner): do this more efficiently by walking those uses
+  // of oldVal that fall within this StmtBlock (instead of iterating through
+  // all statements / through all operands of operations found).
+  for (auto it = block->begin(); it != block->end(); it++) {
+    it->replaceUses(oldVal, newVal);
+  }
+}
+
+/// Clone the list of stmt's from 'block' and insert into the current
+/// position of the builder.
+// TODO(bondhugula,clattner): replace this with a parameterizable clone.
+void cloneStmtListFromBlock(MLFuncBuilder *builder, const StmtBlock &block) {
+  // Pairs of <old op stmt result whose uses need to be replaced,
+  // new result generated by the corresponding cloned op stmt>.
+  SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+
+  // Iterator pointing to just before 'this' (i^th) unrolled iteration.
+  StmtBlock::iterator beforeUnrolledBody = --builder->getInsertionPoint();
+
+  for (auto &stmt : block.getStatements()) {
+    auto *cloneStmt = builder->clone(stmt);
+    // Whenever we have an op stmt, we'll have a new ML Value defined: replace
+    // uses of the old result with this one.
+    if (auto *opStmt = dyn_cast<OperationStmt>(&stmt)) {
+      if (opStmt->getNumResults()) {
+        auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
+        for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
+          // Store old/new result pairs.
+          // TODO(bondhugula) *only* if needed later: storing of old/new
+          // results can be avoided by cloning the statement list in the
+          // reverse direction (and running the IR builder in the reverse
+          // (iplist.insertAfter()). That way, a newly created result can be
+          // immediately propagated to all its uses.
+          oldNewResultPairs.push_back(std::make_pair(
+              const_cast<StmtResult *>(&opStmt->getStmtResult(i)),
+              &cloneOpStmt->getStmtResult(i)));
+        }
+      }
+    }
+  }
+
+  // Replace uses of old op results' with the new results.
+  StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+  StmtBlock::iterator endOfUnrolledBody = builder->getInsertionPoint();
+
+  // Replace uses of old op results' with the newly created ones.
+  for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
+    replaceUses(startOfUnrolledBody, endOfUnrolledBody,
+                oldNewResultPairs[i].first, oldNewResultPairs[i].second);
+  }
 }
 
 /// Unroll this 'for stmt' / loop completely.
@@ -139,52 +213,25 @@
                             ++StmtBlock::iterator(forStmt));
 
   // Unroll the contents of 'forStmt'.
-  for (int i = lb; i <= ub; i += step) {
-    // TODO(bondhugula): generate constants only when IV actually appears.
-    auto constOp = funcTopBuilder.create<ConstantIntOp>(i, 32);
-    auto *ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
-
-    // Iterator pointing to just before 'this' (i^th) unrolled iteration.
+  for (int64_t i = lb; i <= ub; i += step) {
+    MLValue *ivConst = nullptr;
+    if (!forStmt->use_empty()) {
+      auto constOp = funcTopBuilder.create<ConstantAffineIntOp>(i);
+      ivConst = cast<OperationStmt>(constOp->getOperation())->getResult(0);
+    }
     StmtBlock::iterator beforeUnrolledBody = --builder.getInsertionPoint();
 
-    // Pairs of <old op stmt result whose uses need to be replaced,
-    // new result generated by the corresponding cloned op stmt>.
-    SmallVector<std::pair<MLValue *, MLValue *>, 8> oldNewResultPairs;
+    // Clone the loop body and insert it right after the loop - the latter will
+    // be erased after all unrolling has been done.
+    cloneStmtListFromBlock(&builder, *forStmt);
 
-    for (auto &loopBodyStmt : forStmt->getStatements()) {
-      auto *cloneStmt = builder.clone(loopBodyStmt);
-      // Replace all uses of the IV in the clone with constant iteration value.
-      replaceAllStmtUses(cloneStmt, forStmt, ivConst);
-
-      // Whenever we have an op stmt, we'll have a new ML Value defined: replace
-      // uses of the old result with this one.
-      if (auto *opStmt = dyn_cast<OperationStmt>(&loopBodyStmt)) {
-        if (opStmt->getNumResults()) {
-          auto *cloneOpStmt = cast<OperationStmt>(cloneStmt);
-          for (unsigned i = 0, e = opStmt->getNumResults(); i < e; i++) {
-            // Store old/new result pairs.
-            // TODO *only* if needed later: storing of old/new results can be
-            // avoided, by cloning the statement list in the reverse direction
-            // (and running the IR builder in the reverse
-            // (iplist.insertAfter()). That way, a newly created result can be
-            // immediately propagated to all its uses, which would already  been
-            // cloned/inserted.
-            oldNewResultPairs.push_back(std::make_pair(
-                &opStmt->getStmtResult(i), &cloneOpStmt->getStmtResult(i)));
-          }
-        }
-      }
-    }
-    // Replace uses of old op results' with the results in the just
-    // unrolled body.
-    StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
-    for (auto it = ++beforeUnrolledBody; it != endOfUnrolledBody; it++) {
-      for (unsigned i = 0; i < oldNewResultPairs.size(); i++) {
-        replaceAllStmtUses(&(*it), oldNewResultPairs[i].first,
-                           oldNewResultPairs[i].second);
-      }
+    // Replace unrolled loop IV with the unrolled constant.
+    if (ivConst) {
+      StmtBlock::iterator startOfUnrolledBody = ++beforeUnrolledBody;
+      StmtBlock::iterator endOfUnrolledBody = builder.getInsertionPoint();
+      replaceUses(startOfUnrolledBody, endOfUnrolledBody, forStmt, ivConst);
     }
   }
-  // Erase the original for stmt from the block.
+  // Erase the original 'for' stmt from the block.
   forStmt->eraseFromBlock();
 }
diff --git a/test/Transforms/unroll.mlir b/test/Transforms/unroll.mlir
index f091ba1..42f1c2a 100644
--- a/test/Transforms/unroll.mlir
+++ b/test/Transforms/unroll.mlir
@@ -1,17 +1,14 @@
 // RUN: %S/../../mlir-opt %s -o - -unroll-innermost-loops | FileCheck %s
+// RUN: %S/../../mlir-opt %s -o - -unroll-short-loops | FileCheck %s --check-prefix SHORT
 
 // CHECK-LABEL: mlfunc @loop_nest_simplest() {
 mlfunc @loop_nest_simplest() {
-  // CHECK: %c1_i32 = constant 1 : i32
-  // CHECK-NEXT: %c2_i32 = constant 2 : i32
-  // CHECK-NEXT: %c3_i32 = constant 3 : i32
-  // CHECK-NEXT: %c4_i32 = constant 4 : i32
-  // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
+  // CHECK: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
-    // CHECK: %c1_i32_0 = constant 1 : i32
+    // CHECK: %c1_i32 = constant 1 : i32
+    // CHECK-NEXT: %c1_i32_0 = constant 1 : i32
     // CHECK-NEXT: %c1_i32_1 = constant 1 : i32
     // CHECK-NEXT: %c1_i32_2 = constant 1 : i32
-    // CHECK-NEXT: %c1_i32_3 = constant 1 : i32
     for %j = 1 to 4 {
       %x = constant 1 : i32
     }
@@ -21,16 +18,16 @@
 
 // CHECK-LABEL: mlfunc @loop_nest_simple_iv_use() {
 mlfunc @loop_nest_simple_iv_use() {
-  // CHECK: %c1_i32 = constant 1 : i32
-  // CHECK-NEXT: %c2_i32 = constant 2 : i32
-  // CHECK-NEXT: %c3_i32 = constant 3 : i32
-  // CHECK-NEXT: %c4_i32 = constant 4 : i32
+  // CHECK: %c1 = constant 1 : affineint
+  // CHECK-NEXT: %c2 = constant 2 : affineint
+  // CHECK-NEXT: %c3 = constant 3 : affineint
+  // CHECK-NEXT: %c4 = constant 4 : affineint
   // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
   for %i = 1 to 100 step 2 {
-    // CHECK:       %0 = "addi32"(%c1_i32, %c1_i32) : (i32, i32) -> i32
-    // CHECK-NEXT:  %1 = "addi32"(%c2_i32, %c2_i32) : (i32, i32) -> i32
-    // CHECK-NEXT:  %2 = "addi32"(%c3_i32, %c3_i32) : (i32, i32) -> i32
-    // CHECK-NEXT:  %3 = "addi32"(%c4_i32, %c4_i32) : (i32, i32) -> i32
+    // CHECK:       %0 = "addi32"(%c1, %c1) : (affineint, affineint) -> i32
+    // CHECK-NEXT:  %1 = "addi32"(%c2, %c2) : (affineint, affineint) -> i32
+    // CHECK-NEXT:  %2 = "addi32"(%c3, %c3) : (affineint, affineint) -> i32
+    // CHECK-NEXT:  %3 = "addi32"(%c4, %c4) : (affineint, affineint) -> i32
     for %j = 1 to 4 {
       %x = "addi32"(%j, %j) : (affineint, affineint) -> i32
     }
@@ -38,29 +35,57 @@
   return  // CHECK:  return
 }         // CHECK }
 
+// Operations in the loop body have results that are used therein.
+// CHECK-LABEL: mlfunc @loop_nest_body_def_use() {
+mlfunc @loop_nest_body_def_use() {
+  // CHECK: %c0 = constant 0 : affineint
+  // CHECK-NEXT: %c1 = constant 1 : affineint
+  // CHECK-NEXT: %c2 = constant 2 : affineint
+  // CHECK-NEXT: %c3 = constant 3 : affineint
+  // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
+  for %i = 1 to 100 step 2 {
+    // CHECK: %c0_0 = constant 0 : affineint
+    %c0 = constant 0 : affineint
+    // CHECK:      %0 = affine_apply (d0) -> (d0 + 1)(%c0)
+    // CHECK-NEXT: %1 = "addi32"(%0, %c0_0) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c1)
+    // CHECK-NEXT: %3 = "addi32"(%2, %c0_0) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2)
+    // CHECK-NEXT: %5 = "addi32"(%4, %c0_0) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c3)
+    // CHECK-NEXT: %7 = "addi32"(%6, %c0_0) : (affineint, affineint) -> affineint
+    for %j = 0 to 3 {
+      %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
+        (affineint) -> (affineint)
+      %y = "addi32"(%x, %c0) : (affineint, affineint) -> affineint
+    }
+  }       // CHECK:  }
+  return  // CHECK:  return
+}         // CHECK }
+
 // CHECK-LABEL: mlfunc @loop_nest_strided() {
 mlfunc @loop_nest_strided() {
-  // CHECK: %c3_i32 = constant 3 : i32
-  // CHECK-NEXT: %c5_i32 = constant 5 : i32
-  // CHECK-NEXT: %c7_i32 = constant 7 : i32
-  // CHECK-NEXT: %c3_i32_0 = constant 3 : i32
-  // CHECK-NEXT: %c5_i32_1 = constant 5 : i32
+  // CHECK: %c3 = constant 3 : affineint
+  // CHECK-NEXT: %c5 = constant 5 : affineint
+  // CHECK-NEXT: %c7 = constant 7 : affineint
+  // CHECK-NEXT: %c3_0 = constant 3 : affineint
+  // CHECK-NEXT: %c5_1 = constant 5 : affineint
   // CHECK-NEXT: for %i0 = 1 to 100 {
   for %i = 1 to 100 {
-    // CHECK:      %0 = affine_apply (d0) -> (d0 + 1)(%c3_i32_0)
+    // CHECK:      %0 = affine_apply (d0) -> (d0 + 1)(%c3_0)
     // CHECK-NEXT: %1 = "addi32"(%0, %0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c5_i32_1)
+    // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c5_1)
     // CHECK-NEXT: %3 = "addi32"(%2, %2) : (affineint, affineint) -> affineint
     for %j = 3 to 6 step 2 {
       %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
       %y = "addi32"(%x, %x) : (affineint, affineint) -> affineint
     }
-    // CHECK:      %4 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
+    // CHECK:      %4 = affine_apply (d0) -> (d0 + 1)(%c3)
     // CHECK-NEXT: %5 = "addi32"(%4, %4) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c5_i32)
+    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c5)
     // CHECK-NEXT: %7 = "addi32"(%6, %6) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c7_i32)
+    // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c7)
     // CHECK-NEXT: %9 = "addi32"(%8, %8) : (affineint, affineint) -> affineint
     for %k = 3 to 7 step 2 {
       %z = "affine_apply" (%k) { map: (d0) -> (d0 + 1) } :
@@ -71,29 +96,26 @@
   return  // CHECK:  return
 }         // CHECK }
 
-// Operations in the loop body have results that are used therein.
-// CHECK-LABEL: mlfunc @loop_nest_body_def_use() {
-mlfunc @loop_nest_body_def_use() {
-  // CHECK: %c0_i32 = constant 0 : i32
-  // CHECK-NEXT: %c1_i32 = constant 1 : i32
-  // CHECK-NEXT: %c2_i32 = constant 2 : i32
-  // CHECK-NEXT: %c3_i32 = constant 3 : i32
-  // CHECK-NEXT: for %i0 = 1 to 100 step 2 {
-  for %i = 1 to 100 step 2 {
-    // CHECK: %c0 = constant 0 : affineint
-    %c0 = constant 0 : affineint
-    // CHECK:      %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
-    // CHECK-NEXT: %1 = "addi32"(%0, %c0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
-    // CHECK-NEXT: %3 = "addi32"(%2, %c0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
-    // CHECK-NEXT: %5 = "addi32"(%4, %c0) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
-    // CHECK-NEXT: %7 = "addi32"(%6, %c0) : (affineint, affineint) -> affineint
-    for %j = 0 to 3 {
-      %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
-        (affineint) -> (affineint)
-      %y = "addi32"(%x, %c0) : (affineint, affineint) -> affineint
+// CHECK-LABEL: mlfunc @loop_nest_multiple_results() {
+mlfunc @loop_nest_multiple_results() {
+  // CHECK: %c0 = constant 0 : affineint
+  // CHECK-NEXT: %c1 = constant 1 : affineint
+  for %i = 1 to 100 {
+    // CHECK: %0 = affine_apply (d0, d1) -> (d0 + 1, d1 + 2)(%i0, %c0)
+    // CHECK-NEXT: %1 = "addi32"(%0#0, %0#1) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %2 = affine_apply (d0, d1) -> (d0 + 3, d1 + 4)(%i0, %c0)
+    // CHECK-NEXT: %3 = "fma"(%2#0, %2#1, %0#0) : (affineint, affineint, affineint) -> (affineint, affineint)
+    // CHECK-NEXT: %4 = affine_apply (d0, d1) -> (d0 + 1, d1 + 2)(%i0, %c1)
+    // CHECK-NEXT: %5 = "addi32"(%4#0, %4#1) : (affineint, affineint) -> affineint
+    // CHECK-NEXT: %6 = affine_apply (d0, d1) -> (d0 + 3, d1 + 4)(%i0, %c1)
+    // CHECK-NEXT: %7 = "fma"(%6#0, %6#1, %4#0) : (affineint, affineint, affineint) -> (affineint, affineint)
+    for %j = 0 to 1 step 1 {
+      %x = "affine_apply" (%i, %j) { map: (d0, d1) -> (d0 + 1, d1 + 2) } :
+        (affineint, affineint) -> (affineint, affineint)
+      %y = "addi32"(%x#0, %x#1) : (affineint, affineint) -> affineint
+      %z = "affine_apply" (%i, %j) { map: (d0, d1) -> (d0 + 3, d1 + 4) } :
+        (affineint, affineint) -> (affineint, affineint)
+      %w = "fma"(%z#0, %z#1, %x#0) : (affineint, affineint, affineint) -> (affineint, affineint)
     }
   }       // CHECK:  }
   return  // CHECK:  return
@@ -103,27 +125,27 @@
 // Imperfect loop nest. Unrolling innermost here yields a perfect nest.
 // CHECK-LABEL: mlfunc @loop_nest_seq_imperfect(%arg0 : memref<128x128xf32>) {
 mlfunc @loop_nest_seq_imperfect(%a : memref<128x128xf32>) {
-  // CHECK: %c1_i32 = constant 1 : i32
-  // CHECK-NEXT: %c2_i32 = constant 2 : i32
-  // CHECK-NEXT: %c3_i32 = constant 3 : i32
-  // CHECK-NEXT: %c4_i32 = constant 4 : i32
+  // CHECK: %c1 = constant 1 : affineint
+  // CHECK-NEXT: %c2 = constant 2 : affineint
+  // CHECK-NEXT: %c3 = constant 3 : affineint
+  // CHECK-NEXT: %c4 = constant 4 : affineint
   // CHECK-NEXT: %c128 = constant 128 : affineint
   %c128 = constant 128 : affineint
   // CHECK: for %i0 = 1 to 100 {
   for %i = 1 to 100 {
     // CHECK: %0 = "vld"(%i0) : (affineint) -> i32
     %ld = "vld"(%i) : (affineint) -> i32
-    // CHECK: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
-    // CHECK-NEXT: %2 = "vmulf"(%c1_i32, %1) : (i32, affineint) -> affineint
+    // CHECK: %1 = affine_apply (d0) -> (d0 + 1)(%c1)
+    // CHECK-NEXT: %2 = "vmulf"(%c1, %1) : (affineint, affineint) -> affineint
     // CHECK-NEXT: %3 = "vaddf"(%2, %2) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
-    // CHECK-NEXT: %5 = "vmulf"(%c2_i32, %4) : (i32, affineint) -> affineint
+    // CHECK-NEXT: %4 = affine_apply (d0) -> (d0 + 1)(%c2)
+    // CHECK-NEXT: %5 = "vmulf"(%c2, %4) : (affineint, affineint) -> affineint
     // CHECK-NEXT: %6 = "vaddf"(%5, %5) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %7 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
-    // CHECK-NEXT: %8 = "vmulf"(%c3_i32, %7) : (i32, affineint) -> affineint
+    // CHECK-NEXT: %7 = affine_apply (d0) -> (d0 + 1)(%c3)
+    // CHECK-NEXT: %8 = "vmulf"(%c3, %7) : (affineint, affineint) -> affineint
     // CHECK-NEXT: %9 = "vaddf"(%8, %8) : (affineint, affineint) -> affineint
-    // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4_i32)
-    // CHECK-NEXT: %11 = "vmulf"(%c4_i32, %10) : (i32, affineint) -> affineint
+    // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4)
+    // CHECK-NEXT: %11 = "vmulf"(%c4, %10) : (affineint, affineint) -> affineint
     // CHECK-NEXT: %12 = "vaddf"(%11, %11) : (affineint, affineint) -> affineint
     for %j = 1 to 4 {
       %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
@@ -141,21 +163,21 @@
 
 // CHECK-LABEL: mlfunc @loop_nest_seq_multiple() {
 mlfunc @loop_nest_seq_multiple() {
-  // CHECK: %c1_i32 = constant 1 : i32
-  // CHECK-NEXT: %c2_i32 = constant 2 : i32
-  // CHECK-NEXT: %c3_i32 = constant 3 : i32
-  // CHECK-NEXT: %c4_i32 = constant 4 : i32
-  // CHECK-NEXT: %c0_i32 = constant 0 : i32
-  // CHECK-NEXT: %c1_i32_0 = constant 1 : i32
-  // CHECK-NEXT: %c2_i32_1 = constant 2 : i32
-  // CHECK-NEXT: %c3_i32_2 = constant 3 : i32
-  // CHECK-NEXT: %0 = affine_apply (d0) -> (d0 + 1)(%c0_i32)
+  // CHECK: %c1 = constant 1 : affineint
+  // CHECK-NEXT: %c2 = constant 2 : affineint
+  // CHECK-NEXT: %c3 = constant 3 : affineint
+  // CHECK-NEXT: %c4 = constant 4 : affineint
+  // CHECK-NEXT: %c0 = constant 0 : affineint
+  // CHECK-NEXT: %c1_0 = constant 1 : affineint
+  // CHECK-NEXT: %c2_1 = constant 2 : affineint
+  // CHECK-NEXT: %c3_2 = constant 3 : affineint
+  // CHECK-NEXT: %0 = affine_apply (d0) -> (d0 + 1)(%c0)
   // CHECK-NEXT: "mul"(%0, %0) : (affineint, affineint) -> ()
-  // CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_i32_0)
+  // CHECK-NEXT: %1 = affine_apply (d0) -> (d0 + 1)(%c1_0)
   // CHECK-NEXT: "mul"(%1, %1) : (affineint, affineint) -> ()
-  // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_i32_1)
+  // CHECK-NEXT: %2 = affine_apply (d0) -> (d0 + 1)(%c2_1)
   // CHECK-NEXT: "mul"(%2, %2) : (affineint, affineint) -> ()
-  // CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_i32_2)
+  // CHECK-NEXT: %3 = affine_apply (d0) -> (d0 + 1)(%c3_2)
   // CHECK-NEXT: "mul"(%3, %3) : (affineint, affineint) -> ()
   for %j = 0 to 3 {
     %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
@@ -167,14 +189,14 @@
   %k = "constant"(){value: 99} : () -> affineint
   // CHECK: for %i0 = 1 to 100 step 2 {
   for %m = 1 to 100 step 2 {
-    // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c1_i32)
-    // CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1_i32)[%c99]
-    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c2_i32)
-    // CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2_i32)[%c99]
-    // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c3_i32)
-    // CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3_i32)[%c99]
-    // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4_i32)
-    // CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c4_i32)[%c99]
+    // CHECK: %4 = affine_apply (d0) -> (d0 + 1)(%c1)
+    // CHECK-NEXT: %5 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c1)[%c99]
+    // CHECK-NEXT: %6 = affine_apply (d0) -> (d0 + 1)(%c2)
+    // CHECK-NEXT: %7 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c2)[%c99]
+    // CHECK-NEXT: %8 = affine_apply (d0) -> (d0 + 1)(%c3)
+    // CHECK-NEXT: %9 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c3)[%c99]
+    // CHECK-NEXT: %10 = affine_apply (d0) -> (d0 + 1)(%c4)
+    // CHECK-NEXT: %11 = affine_apply (d0)[s0] -> (d0 + s0 + 1)(%c4)[%c99]
     for %n = 1 to 4 {
       %y = "affine_apply" (%n) { map: (d0) -> (d0 + 1) } :
         (affineint) -> (affineint)
@@ -184,3 +206,23 @@
   }       // CHECK }
   return  // CHECK:  return
 }         // CHECK }
+
+// SHORT-LABEL: mlfunc @loop_nest_outer_unroll() {
+mlfunc @loop_nest_outer_unroll() {
+  // SHORT:      for %i0 = 1 to 4 {
+  // SHORT-NEXT:   %0 = affine_apply (d0) -> (d0 + 1)(%i0)
+  // SHORT-NEXT:   %1 = "addi32"(%0, %0) : (affineint, affineint) -> affineint
+  // SHORT-NEXT: }
+  // SHORT-NEXT: for %i1 = 1 to 4 {
+  // SHORT-NEXT:   %2 = affine_apply (d0) -> (d0 + 1)(%i1)
+  // SHORT-NEXT:   %3 = "addi32"(%2, %2) : (affineint, affineint) -> affineint
+  // SHORT-NEXT: }
+  for %i = 1 to 2 {
+    for %j = 1 to 4 {
+      %x = "affine_apply" (%j) { map: (d0) -> (d0 + 1) } :
+        (affineint) -> (affineint)
+      %y = "addi32"(%x, %x) : (affineint, affineint) -> affineint
+    }
+  }
+  return  // SHORT:  return
+}         // SHORT }
diff --git a/tools/mlir-opt/mlir-opt.cpp b/tools/mlir-opt/mlir-opt.cpp
index c95b838..26417e2 100644
--- a/tools/mlir-opt/mlir-opt.cpp
+++ b/tools/mlir-opt/mlir-opt.cpp
@@ -54,6 +54,7 @@
 enum Passes {
   ConvertToCFG,
   UnrollInnermostLoops,
+  UnrollShortLoops,
   TFRaiseControlFlow,
 };
 
@@ -63,6 +64,8 @@
                           "Convert all ML functions in the module to CFG ones"),
                clEnumValN(UnrollInnermostLoops, "unroll-innermost-loops",
                           "Unroll innermost loops"),
+               clEnumValN(UnrollShortLoops, "unroll-short-loops",
+                          "Unroll loops of trip count <= 2"),
                clEnumValN(TFRaiseControlFlow, "tf-raise-control-flow",
                           "Dynamic TensorFlow Switch/Match nodes to a CFG")));
 
@@ -111,6 +114,9 @@
     case UnrollInnermostLoops:
       pass = createLoopUnrollPass();
       break;
+    case UnrollShortLoops:
+      pass = createLoopUnrollPass(2);
+      break;
     case TFRaiseControlFlow:
       pass = createRaiseTFControlFlowPass();
       break;