Add utility to promote single-iteration loops. Add methods for getting constant
loop trip counts. Improve / refactor loop unroll / loop unroll-and-jam.

- add utility to promote single-iteration loops (move the loop body up and
  remove the loop).
- use this utility to promote single-iteration loops after unroll/unroll-and-jam
- use loopUnrollByFactor to implement loopUnrollFull and remove most of
  loopUnrollFull's old body.
- add methods for getting the constant loop trip count

PiperOrigin-RevId: 212039569
diff --git a/lib/Transforms/LoopUnroll.cpp b/lib/Transforms/LoopUnroll.cpp
index 382b830..e663ce0 100644
--- a/lib/Transforms/LoopUnroll.cpp
+++ b/lib/Transforms/LoopUnroll.cpp
@@ -33,22 +33,21 @@
 #include "mlir/Transforms/Passes.h"
 #include "llvm/ADT/DenseMap.h"
 #include "llvm/Support/CommandLine.h"
-#include "llvm/Support/raw_ostream.h"
 
 using namespace mlir;
-using namespace llvm;
 
 // Loop unrolling factor.
 static llvm::cl::opt<unsigned>
-    clUnrollFactor("unroll-factor", cl::Hidden,
-                   cl::desc("Use this unroll factor for all loops"));
+    clUnrollFactor("unroll-factor", llvm::cl::Hidden,
+                   llvm::cl::desc("Use this unroll factor for all loops"));
 
-static llvm::cl::opt<bool> clUnrollFull("unroll-full", cl::Hidden,
-                                        cl::desc("Fully unroll loops"));
+static llvm::cl::opt<bool> clUnrollFull("unroll-full", llvm::cl::Hidden,
+                                        llvm::cl::desc("Fully unroll loops"));
 
 static llvm::cl::opt<unsigned> clUnrollFullThreshold(
-    "unroll-full-threshold", cl::Hidden,
-    cl::desc("Unroll all loops with trip count less than or equal to this"));
+    "unroll-full-threshold", llvm::cl::Hidden,
+    llvm::cl::desc(
+        "Unroll all loops with trip count less than or equal to this"));
 
 namespace {
 /// Loop unrolling pass. Unrolls all innermost loops unless full unrolling and a
@@ -67,7 +66,7 @@
   /// Unroll this for stmt. Returns false if nothing was done.
   bool runOnForStmt(ForStmt *forStmt);
   bool loopUnrollFull(ForStmt *forStmt);
-  bool loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor);
+  bool loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor);
 };
 } // end anonymous namespace
 
@@ -129,13 +128,8 @@
     ShortLoopGatherer(unsigned minTripCount) : minTripCount(minTripCount) {}
 
     void visitForStmt(ForStmt *forStmt) {
-      if (!forStmt->hasConstantBounds())
-        return;
-      auto lb = forStmt->getConstantLowerBound();
-      auto ub = forStmt->getConstantUpperBound();
-      auto step = forStmt->getStep();
-
-      if ((ub - lb) / step + 1 <= minTripCount)
+      Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+      if (tripCount.hasValue() && tripCount.getValue() <= minTripCount)
         loops.push_back(forStmt);
     }
   };
@@ -180,43 +174,14 @@
 // Unrolls this loop completely. Fails assertion if loop bounds are
 // non-constant.
 bool LoopUnroll::loopUnrollFull(ForStmt *forStmt) {
-  auto lb = forStmt->getConstantLowerBound();
-  auto ub = forStmt->getConstantUpperBound();
-  auto step = forStmt->getStep();
-
-  // Builder to add constants needed for the unrolled iterator.
-  auto *mlFunc = forStmt->findFunction();
-  MLFuncBuilder funcTopBuilder(&mlFunc->front());
-
-  // Builder to insert the unrolled bodies.  We insert right after the
-  // ForStmt we're unrolling.
-  MLFuncBuilder builder(forStmt->getBlock(), ++StmtBlock::iterator(forStmt));
-
-  // Unroll the contents of 'forStmt'.
-  for (int64_t i = lb; i <= ub; i += step) {
-    DenseMap<const MLValue *, MLValue *> operandMapping;
-
-    // If the induction variable is used, create a constant for this unrolled
-    // value and add an operand mapping for it.
-    if (!forStmt->use_empty()) {
-      auto *ivConst =
-          funcTopBuilder.create<ConstantAffineIntOp>(forStmt->getLoc(), i)
-              ->getResult();
-      operandMapping[forStmt] = cast<MLValue>(ivConst);
-    }
-
-    // Clone the body of the loop.
-    for (auto &childStmt : *forStmt) {
-      builder.clone(childStmt, operandMapping);
-    }
-  }
-  // Erase the original 'for' stmt from the block.
-  forStmt->eraseFromBlock();
-  return true;
+  Optional<uint64_t> tripCount = forStmt->getConstantTripCount();
+  if (tripCount.hasValue())
+    return loopUnrollByFactor(forStmt, tripCount.getValue());
+  return false;
 }
 
 /// Unrolls this loop by the specified unroll factor.
-bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, unsigned unrollFactor) {
+bool LoopUnroll::loopUnrollByFactor(ForStmt *forStmt, uint64_t unrollFactor) {
   assert(unrollFactor >= 1 && "unroll factor shoud be >= 1");
 
   if (unrollFactor == 1 || forStmt->getStatements().empty())
@@ -225,11 +190,9 @@
   if (!forStmt->hasConstantBounds())
     return false;
 
-  auto lb = forStmt->getConstantLowerBound();
-  auto ub = forStmt->getConstantUpperBound();
-  auto step = forStmt->getStep();
-
-  int64_t tripCount = (int64_t)ceilf((ub - lb + 1) / (float)step);
+  int64_t lb = forStmt->getConstantLowerBound();
+  int64_t step = forStmt->getStep();
+  uint64_t tripCount = forStmt->getConstantTripCount().getValue();
 
   // If the trip count is lower than the unroll factor, no unrolled body.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -243,6 +206,8 @@
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
     cleanupForStmt->setConstantLowerBound(
         lb + (tripCount - tripCount % unrollFactor) * step);
+    // Promote the loop body up if this has turned into a single iteration loop.
+    promoteIfSingleIteration(cleanupForStmt);
   }
 
   // Builder to insert unrolled bodies right after the last statement in the
@@ -281,5 +246,9 @@
     // Clone the last statement in the original body.
     builder.clone(*srcBlockEnd, operandMapping);
   }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  promoteIfSingleIteration(forStmt);
+
   return true;
 }