Enable non-power-of-2 #pragma unroll counts.

Patch by Evgeny Stupachenko.

Differential Revision: http://reviews.llvm.org/D18202

llvm-svn: 264407
diff --git a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
index e478722..c3761dcf 100644
--- a/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopUnrollPass.cpp
@@ -684,11 +684,6 @@
   }
 
   if (HasPragma) {
-    if (PragmaCount != 0)
-      // If loop has an unroll count pragma mark loop as unrolled to prevent
-      // unrolling beyond that requested by the pragma.
-      SetLoopAlreadyUnrolled(L);
-
     // Emit optimization remarks if we are unable to unroll the loop
     // as directed by a pragma.
     DebugLoc LoopLoc = L->getStartLoc();
@@ -738,6 +733,10 @@
                   TripMultiple, LI, SE, &DT, &AC, PreserveLCSSA))
     return false;
 
+  // If loop has an unroll count pragma mark loop as unrolled to prevent
+  // unrolling beyond that requested by the pragma.
+  if (HasPragma && PragmaCount != 0)
+    SetLoopAlreadyUnrolled(L);
   return true;
 }
 
diff --git a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
index 232707d..264988c 100644
--- a/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
+++ b/llvm/lib/Transforms/Utils/LoopUnrollRuntime.cpp
@@ -117,10 +117,10 @@
 
   assert(Count != 0 && "nonsensical Count!");
 
-  // If BECount <u (Count - 1) then (BECount + 1) & (Count - 1) == (BECount + 1)
-  // (since Count is a power of 2).  This means %xtraiter is (BECount + 1) and
-  // and all of the iterations of this loop were executed by the prologue.  Note
-  // that if BECount <u (Count - 1) then (BECount + 1) cannot unsigned-overflow.
+  // If BECount <u (Count - 1) then (BECount + 1) % Count == (BECount + 1).
+  // This means %xtraiter is (BECount + 1) and all of the iterations of this
+  // loop were executed by the prologue.  Note that if BECount <u (Count - 1)
+  // then (BECount + 1) cannot unsigned-overflow.
   Value *BrLoopExit =
       B.CreateICmpULT(BECount, ConstantInt::get(BECount->getType(), Count - 1));
   BasicBlock *Exit = L->getUniqueExitBlock();
@@ -319,11 +319,6 @@
       Expander.isHighCostExpansion(TripCountSC, L, PreHeaderBR))
     return false;
 
-  // We only handle cases when the unroll factor is a power of 2.
-  // Count is the loop unroll factor, the number of extra copies added + 1.
-  if (!isPowerOf2_32(Count))
-    return false;
-
   // This constraint lets us deal with an overflowing trip count easily; see the
   // comment on ModVal below.
   if (Log2_32(Count) > BEWidth)
@@ -349,18 +344,33 @@
                                           PreHeaderBR);
 
   IRBuilder<> B(PreHeaderBR);
-  Value *ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter");
-
-  // If ModVal is zero, we know that either
-  //  1. There are no iterations to be run in the prologue loop.
-  // OR
-  //  2. The addition computing TripCount overflowed.
-  //
-  // If (2) is true, we know that TripCount really is (1 << BEWidth) and so the
-  // number of iterations that remain to be run in the original loop is a
-  // multiple Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth (we
-  // explicitly check this above).
-
+  Value *ModVal;
+  // Calculate ModVal = (BECount + 1) % Count.
+  // Note that TripCount is BECount + 1.
+  if (isPowerOf2_32(Count)) {
+    ModVal = B.CreateAnd(TripCount, Count - 1, "xtraiter");
+    // If ModVal is zero, we know that either
+    //  1. There are no iterations to be run in the prologue loop.
+    // OR
+    //  2. The addition computing TripCount overflowed.
+    // If (2) is true, we know that TripCount really is (1 << BEWidth) and so
+    // the number of iterations that remain to be run in the original loop is a
+    // multiple of Count == (1 << Log2(Count)) because Log2(Count) <= BEWidth
+    // (we explicitly check this above).
+  } else {
+    // As (BECount + 1) can potentially unsigned overflow we compute
+    // (BECount % Count) + 1 which is overflow safe as BECount % Count < Count.
+    Value *ModValTmp = B.CreateURem(BECount,
+                                    ConstantInt::get(BECount->getType(),
+                                                     Count));
+    Value *ModValAdd = B.CreateAdd(ModValTmp,
+                                   ConstantInt::get(ModValTmp->getType(), 1));
+    // At this point (BECount % Count) + 1 could be equal to Count.
+    // To handle this case we need to take mod by Count one more time.
+    ModVal = B.CreateURem(ModValAdd,
+                          ConstantInt::get(BECount->getType(), Count),
+                          "xtraiter");
+  }
   Value *BranchVal = B.CreateIsNotNull(ModVal, "lcmp.mod");
 
   // Branch to either the extra iterations or the cloned/unrolled loop.