[SCEV] Properly solve quadratic equations

Differential Revision: https://reviews.llvm.org/D48283

llvm-svn: 338758
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 1fae0e9..d994441 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -16,6 +16,7 @@
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/FoldingSet.h"
 #include "llvm/ADT/Hashing.h"
+#include "llvm/ADT/Optional.h"
 #include "llvm/ADT/SmallString.h"
 #include "llvm/ADT/StringRef.h"
 #include "llvm/Config/llvm-config.h"
@@ -2707,3 +2708,193 @@
   }
   llvm_unreachable("Unknown APInt::Rounding enum");
 }
+
+Optional<APInt>
+llvm::APIntOps::SolveQuadraticEquationWrap(APInt A, APInt B, APInt C,
+                                           unsigned RangeWidth) {
+  unsigned CoeffWidth = A.getBitWidth();
+  assert(CoeffWidth == B.getBitWidth() && CoeffWidth == C.getBitWidth());
+  assert(RangeWidth <= CoeffWidth &&
+         "Value range width should be less than coefficient width");
+  assert(RangeWidth > 1 && "Value range bit width should be > 1");
+
+  LLVM_DEBUG(dbgs() << __func__ << ": solving " << A << "x^2 + " << B
+                    << "x + " << C << ", rw:" << RangeWidth << '\n');
+
+  // If C is 0 modulo 2^RangeWidth, then x = 0 is a solution: return it now.
+  if (C.sextOrTrunc(RangeWidth).isNullValue() ) {
+    LLVM_DEBUG(dbgs() << __func__ << ": zero solution\n");
+    return APInt(CoeffWidth, 0); // Original width (not yet tripled).
+  }
+
+  // The result of APInt arithmetic has the same bit width as the operands,
+  // so it can actually lose high bits. A product of two n-bit integers needs
+  // up to 2n bits to represent the full value.
+  // The operation done below (on quadratic coefficients) that can produce
+  // the largest value is the evaluation of the equation at a candidate root,
+  // which needs 3 times the bitwidth of the coefficient, so the total number
+  // of required bits is 3n.
+  //
+  // The purpose of this extension is to simulate the set Z of all integers,
+  // where n+1 > n for all n in Z. In Z it makes sense to talk about positive
+  // and negative numbers (not so much in a modulo arithmetic). The method
+  // used to solve the equation is based on the standard formula for real
+  // numbers, and uses the concepts of "positive" and "negative" with their
+  // usual meanings.
+  CoeffWidth *= 3;
+  A = A.sext(CoeffWidth);
+  B = B.sext(CoeffWidth);
+  C = C.sext(CoeffWidth);
+
+  // Make A > 0 for simplicity. Negate cannot overflow at this point because
+  // the bit width has increased.
+  if (A.isNegative()) {
+    A.negate();
+    B.negate();
+    C.negate();
+  }
+
+  // Solving an equation q(x) = 0 with coefficients in modular arithmetic
+  // is really solving a set of equations q(x) = kR for k = 0, 1, 2, ...,
+  // and R = 2^RangeWidth.
+  // Since we're trying not only to find exact solutions, but also values
+  // that "wrap around", such a set will always have a solution, i.e. an x
+  // that satisfies at least one of the equations, or such that |q(x)|
+  // exceeds kR, while |q(x-1)| for the same k does not.
+  //
+  // We need to find a value k, such that Ax^2 + Bx + C = kR will have a
+  // positive solution n (in the above sense), and also such that the n
+  // will be the least among all solutions corresponding to k = 0, 1, ...
+  // (more precisely, the least element in the set
+  //   { n(k) | k is such that a solution n(k) exists }).
+  //
+  // Consider the parabola (over real numbers) that corresponds to the
+  // quadratic equation. Since A > 0, the arms of the parabola will point
+  // up. Picking different values of k will shift it up and down by R.
+  //
+  // We want to shift the parabola in such a way as to reduce the problem
+  // of solving q(x) = kR to solving shifted_q(x) = 0.
+  // (The interesting solutions are the ceilings of the real number
+  // solutions.)
+  APInt R = APInt::getOneBitSet(CoeffWidth, RangeWidth);
+  APInt TwoA = 2 * A;
+  APInt SqrB = B * B;
+  bool PickLow;
+  // RoundUp(V, A): round V towards +inf to the nearest multiple of A (A > 0).
+  auto RoundUp = [] (const APInt &V, const APInt &A) {
+    assert(A.isStrictlyPositive());
+    APInt T = V.abs().urem(A);
+    if (T.isNullValue())
+      return V;
+    return V.isNegative() ? V+T : V+(A-T);
+  };
+
+  // The vertex of the parabola is at -B/(2A), but since A > 0, it's negative
+  // iff B is positive.
+  if (B.isNonNegative()) {
+    // If B >= 0, the vertex is at a negative location (or at 0), so in
+    // order to have a non-negative solution we need to pick k that makes
+    // C-kR negative. To satisfy all the requirements for the solution
+    // that we are looking for, it needs to be closest to 0 of all k.
+    C = C.srem(R);
+    if (C.isStrictlyPositive())
+      C -= R;
+    // Pick the greater solution.
+    PickLow = false;
+  } else {
+    // If B < 0, the vertex is at a positive location. For any solution
+    // to exist, the discriminant must be non-negative. This means that
+    // C-kR <= B^2/(4A) is a necessary condition for k, i.e. there is a
+    // lower bound on values of k: kR >= C - B^2/(4A).
+    APInt LowkR = C - SqrB.udiv(2*TwoA); // udiv because all values > 0.
+    // Round LowkR up (towards +inf) to the nearest kR.
+    LowkR = RoundUp(LowkR, R);
+
+    // If there exists k meeting the condition above, and such that
+    // C-kR > 0, there will be two positive real number solutions of
+    // q(x) = kR. Out of all such values of k, pick the one that makes
+    // C-kR closest to 0, (i.e. pick maximum k such that C-kR > 0).
+    // In other words, find maximum k such that LowkR <= kR < C.
+    if (C.sgt(LowkR)) {
+      // If LowkR < C, then such a k is guaranteed to exist because
+      // LowkR itself is a multiple of R.
+      C -= -RoundUp(-C, R);      // C = C - RoundDown(C, R)
+      // Pick the smaller solution.
+      PickLow = true;
+    } else {
+      // If C-kR < 0 for all potential k's, it means that one solution
+      // will be negative, while the other will be positive. The positive
+      // solution will shift towards 0 if the parabola is moved up.
+      // Pick the kR closest to the lower bound (i.e. make C-kR closest
+      // to 0, or in other words, out of all parabolas that have solutions,
+      // pick the one that is the farthest "up").
+      // Since LowkR is itself a multiple of R, simply take C-LowkR.
+      C -= LowkR;
+      // Pick the greater solution.
+      PickLow = false;
+    }
+  }
+
+  LLVM_DEBUG(dbgs() << __func__ << ": updated coefficients " << A << "x^2 + "
+                    << B << "x + " << C << ", rw:" << RangeWidth << '\n');
+
+  APInt D = SqrB - 4*A*C;
+  assert(D.isNonNegative() && "Negative discriminant");
+  APInt SQ = D.sqrt();
+
+  APInt Q = SQ * SQ;
+  bool InexactSQ = Q != D;
+  // The calculated SQ may actually be greater than the exact (non-integer)
+  // value. If that's the case, decrement SQ to get a value that is lower.
+  if (Q.sgt(D))
+    SQ -= 1;
+
+  APInt X;
+  APInt Rem;
+
+  // SQ is rounded down (i.e. SQ * SQ <= D), so the roots may be inexact.
+  // When using the quadratic formula directly, the calculated low root
+  // may be greater than the exact one, since we would be subtracting SQ.
+  // To make sure that the calculated root is not greater than the exact
+  // one, subtract SQ+1 when calculating the low root (for inexact value
+  // of SQ).
+  if (PickLow)
+    APInt::sdivrem(-B - (SQ+InexactSQ), TwoA, X, Rem);
+  else
+    APInt::sdivrem(-B + SQ, TwoA, X, Rem);
+
+  // The updated coefficients should be such that the (exact) solution is
+  // positive. Since APInt division rounds towards 0, the calculated one
+  // can be 0, but cannot be negative.
+  assert(X.isNonNegative() && "Solution should be non-negative");
+
+  if (!InexactSQ && Rem.isNullValue()) {
+    LLVM_DEBUG(dbgs() << __func__ << ": solution (root): " << X << '\n');
+    return X;
+  }
+
+  assert((SQ*SQ).sle(D) && "SQ = |_sqrt(D)_|, so SQ*SQ <= D");
+  // The exact value of the square root of D should be between SQ and SQ+1.
+  // This implies that the solution should be between that corresponding to
+  // SQ (i.e. X) and that corresponding to SQ+1.
+  //
+  // The calculated X cannot be greater than the exact (real) solution.
+  // Actually it must be strictly less than the exact solution, while
+  // X+1 will be greater than or equal to it.
+
+  APInt VX = (A*X + B)*X + C;
+  APInt VY = VX + TwoA*X + A + B;
+  bool SignChange = VX.isNegative() != VY.isNegative() ||
+                    VX.isNullValue() != VY.isNullValue();
+  // If the sign did not change between X and X+1, X is not a valid solution.
+  // This could happen when the actual (exact) roots don't have an integer
+  // between them, so they would both be contained between X and X+1.
+  if (!SignChange) {
+    LLVM_DEBUG(dbgs() << __func__ << ": no valid solution\n");
+    return None;
+  }
+
+  X += 1;
+  LLVM_DEBUG(dbgs() << __func__ << ": solution (wrap): " << X << '\n');
+  return X;
+}