[LoopReroll] Fix reroll root legality checking.

The code checked that the first root was an appropriate distance from
the base value, but skipped checking the other roots. This could lead to
rerolling a loop that can't be legally rerolled (at least, not without
rewriting the loop in a non-trivial way).

Differential Revision: https://reviews.llvm.org/D56812

llvm-svn: 353779
diff --git a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
index 17aa442..166b57f 100644
--- a/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
+++ b/llvm/lib/Transforms/Scalar/LoopRerollPass.cpp
@@ -891,12 +891,22 @@
   const auto *ADR = dyn_cast<SCEVAddRecExpr>(SE->getSCEV(DRS.BaseInst));
   if (!ADR)
     return false;
+
+  // Check that the first root is evenly spaced.
   unsigned N = DRS.Roots.size() + 1;
   const SCEV *StepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[0]), ADR);
   const SCEV *ScaleSCEV = SE->getConstant(StepSCEV->getType(), N);
   if (ADR->getStepRecurrence(*SE) != SE->getMulExpr(StepSCEV, ScaleSCEV))
     return false;
 
+  // Check that the remaining roots are evenly spaced.
+  for (unsigned i = 1; i < N - 1; ++i) {
+    const SCEV *NewStepSCEV = SE->getMinusSCEV(SE->getSCEV(DRS.Roots[i]),
+                                               SE->getSCEV(DRS.Roots[i-1]));
+    if (NewStepSCEV != StepSCEV)
+      return false;
+  }
+
   return true;
 }
 
diff --git a/llvm/test/Transforms/LoopReroll/basic.ll b/llvm/test/Transforms/LoopReroll/basic.ll
index 6e2f2fc..b415b26 100644
--- a/llvm/test/Transforms/LoopReroll/basic.ll
+++ b/llvm/test/Transforms/LoopReroll/basic.ll
@@ -785,6 +785,30 @@
   ret void
 }
 
+define void @bad_step(i32* nocapture readnone %x) #0 {
+entry:
+  br label %for.body
+
+for.body:                                         ; preds = %for.body, %entry
+  %i.08 = phi i32 [ 0, %entry ], [ %add3, %for.body ]
+  %call = tail call i32 @foo(i32 %i.08) #1
+  %add = add nsw i32 %i.08, 2
+  %call1 = tail call i32 @foo(i32 %add) #1
+  %add2 = add nsw i32 %i.08, 3
+  %call3 = tail call i32 @foo(i32 %add2) #1
+  %add3 = add nsw i32 %i.08, 6
+  %exitcond = icmp sge i32 %add3, 500
+  br i1 %exitcond, label %for.end, label %for.body
+; Offsets 0, 2, 3 with IV step 6: 2 * 3 == 6 satisfies a first-root-only spacing check, but 2 -> 3 breaks even spacing, so the loop must not be rerolled and the adds must survive.
+; CHECK-LABEL: @bad_step
+; CHECK: %add = add nsw i32 %i.08, 2
+; CHECK: %add2 = add nsw i32 %i.08, 3
+; CHECK: %add3 = add nsw i32 %i.08, 6
+
+for.end:                                          ; preds = %for.body
+  ret void
+}
+
 attributes #0 = { nounwind uwtable }
 attributes #1 = { nounwind }