Add utility to promote single-iteration loops. Add methods for getting constant
loop trip counts. Improve/refactor loop unroll and loop unroll-and-jam.

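- add a utility to promote single-iteration loops (hoist the loop body up and
  drop the now-redundant loop).
- use this utility to promote loops that become single-iteration after unroll
  and unroll-and-jam (including the unroll-and-jam cleanup loop).
- implement loopUnrollFull in terms of loopUnrollByFactor, removing most of
  loopUnrollFull's own code.
- add methods for getting a loop's constant trip count.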

PiperOrigin-RevId: 212039569
diff --git a/lib/Transforms/LoopUnrollJam.cpp b/lib/Transforms/LoopUnrollJam.cpp
index eeab87c..6fb6134 100644
--- a/lib/Transforms/LoopUnrollJam.cpp
+++ b/lib/Transforms/LoopUnrollJam.cpp
@@ -23,7 +23,7 @@
 // bounds of the loops inner to the loop being unroll-jammed do not depend on
 // the latter.
 //
-// Before      After unroll-jam of i by factor 2:
+// Before      After unroll and jam of i by factor 2:
 //
 //             for i, step = 2
 // for i         S1(i);
@@ -54,7 +54,6 @@
 #include "llvm/Support/CommandLine.h"
 
 using namespace mlir;
-using namespace llvm::cl;
 
 // Loop unroll jam factor.
 static llvm::cl::opt<unsigned>
@@ -74,7 +73,7 @@
 
   void runOnMLFunction(MLFunction *f) override;
   bool runOnForStmt(ForStmt *forStmt);
-  bool loopUnrollJamByFactor(ForStmt *forStmt, unsigned unrollJamFactor);
+  bool loopUnrollJamByFactor(ForStmt *forStmt, uint64_t unrollJamFactor);
 };
 } // end anonymous namespace
 
@@ -110,15 +109,7 @@
 
 /// Unrolls and jams this loop by the specified factor.
 bool LoopUnrollAndJam::loopUnrollJamByFactor(ForStmt *forStmt,
-                                             unsigned unrollJamFactor) {
-  assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
-
-  if (unrollJamFactor == 1 || forStmt->getStatements().empty())
-    return false;
-
-  if (!forStmt->hasConstantBounds())
-    return false;
-
+                                             uint64_t unrollJamFactor) {
   // Gathers all maximal sub-blocks of statements that do not themselves include
   // a for stmt (a statement could have a descendant for stmt though in its
   // tree).
@@ -146,12 +137,17 @@
     }
   };
 
-  auto lb = forStmt->getConstantLowerBound();
-  auto ub = forStmt->getConstantUpperBound();
-  auto step = forStmt->getStep();
+  assert(unrollJamFactor >= 1 && "unroll jam factor should be >= 1");
 
-  int64_t tripCount = (ub - lb + 1) % step == 0 ? (ub - lb + 1) / step
-                                                : (ub - lb + 1) / step + 1;
+  if (unrollJamFactor == 1 || forStmt->getStatements().empty())
+    return false;
+
+  if (!forStmt->hasConstantBounds())
+    return false;
+
+  int64_t lb = forStmt->getConstantLowerBound();
+  int64_t step = forStmt->getStep();
+  uint64_t tripCount = forStmt->getConstantTripCount().getValue();
 
   // If the trip count is lower than the unroll jam factor, no unrolled body.
   // TODO(bondhugula): option to specify cleanup loop unrolling.
@@ -172,6 +168,9 @@
     auto *cleanupForStmt = cast<ForStmt>(builder.clone(*forStmt, operandMap));
     cleanupForStmt->setConstantLowerBound(
         lb + (tripCount - tripCount % unrollJamFactor) * step);
+
+    // Promote the loop body up if this has turned into a single iteration loop.
+    promoteIfSingleIteration(cleanupForStmt);
   }
 
   MLFuncBuilder b(forStmt);
@@ -210,5 +209,9 @@
       builder.clone(*subBlock.second, operandMapping);
     }
   }
+
+  // Promote the loop body up if this has turned into a single iteration loop.
+  promoteIfSingleIteration(forStmt);
+
   return true;
 }