[SCEV] Properly solve quadratic equations

Differential Revision: https://reviews.llvm.org/D48283

llvm-svn: 338758
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 48f9119..731c464 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -10,6 +10,7 @@
 #include "llvm/ADT/APInt.h"
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/SmallString.h"
+#include "llvm/ADT/Twine.h"
 #include "gtest/gtest.h"
 #include <array>
 
@@ -2357,4 +2358,89 @@
   }
 }
 
+TEST(APIntTest, SolveQuadraticEquationWrap) {
+  // Verify that "Solution" is the first non-negative integer that solves
+  // Ax^2 + Bx + C = "0 or overflow", i.e. that it is a correct solution
+  // as calculated by SolveQuadraticEquationWrap.
+  auto Validate = [] (int A, int B, int C, unsigned Width, int Solution) {
+    int Mask = (1 << Width) - 1;
+
+    // Solution should be non-negative.
+    EXPECT_GE(Solution, 0);
+
+    auto OverflowBits = [] (int64_t V, unsigned W) {
+      return V & -(1 << W);
+    };
+
+    int64_t Over0 = OverflowBits(C, Width);
+
+    // True if A*X^2+B*X+C is 0 mod 2^Width or outside C's 2^Width interval.
+    auto IsZeroOrOverflow = [&] (int X) {
+      int64_t ValueAtX = A*X*X + B*X + C;
+      int64_t OverX = OverflowBits(ValueAtX, Width);
+      return (ValueAtX & Mask) == 0 || OverX != Over0;
+    };
+
+    // Return a std::string, not a Twine: a Twine may point at temporaries
+    // that are destroyed at the end of the return statement.
+    auto EquationToString = [&] (const char *X_str) {
+      return (Twine(A) + Twine(X_str) + Twine("^2 + ") + Twine(B) +
+              Twine(X_str) + Twine(" + ") + Twine(C) + Twine(", bitwidth: ") +
+              Twine(Width)).str();
+    };
+
+    auto IsSolution = [&] (const char *X_str, int X) {
+      if (IsZeroOrOverflow(X))
+        return ::testing::AssertionSuccess()
+                  << X << " is a solution of " << EquationToString(X_str);
+      return ::testing::AssertionFailure()
+                << X << " is not an expected solution of "
+                << EquationToString(X_str);
+    };
+
+    auto IsNotSolution = [&] (const char *X_str, int X) {
+      if (!IsZeroOrOverflow(X))
+        return ::testing::AssertionSuccess()
+                  << X << " is not a solution of " << EquationToString(X_str);
+      return ::testing::AssertionFailure()
+                << X << " is an unexpected solution of "
+                << EquationToString(X_str);
+    };
+
+    // The important part: no X in [0, Solution) may be a solution.
+    for (int X = 0; X < Solution; ++X)
+      EXPECT_PRED_FORMAT1(IsNotSolution, X);
+
+    // Verify that the calculated solution is indeed a solution.
+    EXPECT_PRED_FORMAT1(IsSolution, Solution);
+  };
+
+  // Generate all possible quadratic equations with Width-bit wide integer
+  // coefficients, get the solution from SolveQuadraticEquationWrap, and
+  // verify that the solution is correct.
+  auto Iterate = [&] (unsigned Width) {
+    assert(1 < Width && Width < 32);
+    int Low = -(1 << (Width-1));
+    int High = (1 << (Width-1));
+
+    for (int A = Low; A != High; ++A) {
+      if (A == 0)
+        continue;
+      for (int B = Low; B != High; ++B) {
+        for (int C = Low; C != High; ++C) {
+          Optional<APInt> S = APIntOps::SolveQuadraticEquationWrap(
+                                APInt(Width, A), APInt(Width, B),
+                                APInt(Width, C), Width);
+          if (S.hasValue())
+            Validate(A, B, C, Width, S->getSExtValue());
+        }
+      }
+    }
+  };
+
+  // Test all widths in [2..6].
+  for (unsigned i = 2; i <= 6; ++i)
+    Iterate(i);
+}
+
 } // end anonymous namespace