[APInt] Add helpers for rounding u/sdivs.

Reviewers: sanjoy, craig.topper

Subscribers: jlebar, hiraditya, bixia, llvm-commits

Differential Revision: https://reviews.llvm.org/D48498

llvm-svn: 335557
diff --git a/llvm/include/llvm/ADT/APInt.h b/llvm/include/llvm/ADT/APInt.h
index 1f3f7f5..6bf6b22 100644
--- a/llvm/include/llvm/ADT/APInt.h
+++ b/llvm/include/llvm/ADT/APInt.h
@@ -78,6 +78,12 @@
     APINT_BITS_PER_WORD = APINT_WORD_SIZE * CHAR_BIT
   };
 
+  enum class Rounding {
+    DOWN,        ///< Round toward negative infinity (floor).
+    TOWARD_ZERO, ///< Round toward zero (truncate), as udiv and sdiv do.
+    UP,          ///< Round toward positive infinity (ceiling).
+  };
+
   static const WordType WORD_MAX = ~WordType(0);
 
 private:
@@ -1039,13 +1045,16 @@
   /// Perform an unsigned divide operation on this APInt by RHS. Both this and
   /// RHS are treated as unsigned quantities for purposes of this division.
   ///
-  /// \returns a new APInt value containing the division result
+  /// \returns a new APInt value containing the division result, rounded towards
+  /// zero.
   APInt udiv(const APInt &RHS) const;
   APInt udiv(uint64_t RHS) const;
 
   /// Signed division function for APInt.
   ///
   /// Signed divide this APInt by APInt RHS.
+  ///
+  /// The result is rounded towards zero.
   APInt sdiv(const APInt &RHS) const;
   APInt sdiv(int64_t RHS) const;
 
@@ -2151,6 +2160,12 @@
   return RoundDoubleToAPInt(double(Float), width);
 }
 
+/// Return the unsigned quotient A / B, rounded per the rounding mode RM.
+APInt RoundingUDiv(const APInt &A, const APInt &B, APInt::Rounding RM);
+
+/// Return the signed quotient A / B, rounded per the rounding mode RM.
+APInt RoundingSDiv(const APInt &A, const APInt &B, APInt::Rounding RM);
+
 } // End of APIntOps namespace
 
 // See friend declaration above. This additional declaration is required in
diff --git a/llvm/lib/Support/APInt.cpp b/llvm/lib/Support/APInt.cpp
index 2c2ce95..8be903f 100644
--- a/llvm/lib/Support/APInt.cpp
+++ b/llvm/lib/Support/APInt.cpp
@@ -2658,3 +2658,49 @@
   while (i < parts)
     dst[i++] = 0;
 }
+
+APInt llvm::APIntOps::RoundingUDiv(const APInt &A, const APInt &B,
+                                   APInt::Rounding RM) {
+  switch (RM) {
+  // Unsigned quotients are never negative, so DOWN equals TOWARD_ZERO.
+  case APInt::Rounding::DOWN:
+  case APInt::Rounding::TOWARD_ZERO:
+    return A.udiv(B);
+  case APInt::Rounding::UP: {
+    APInt Quo, Rem;
+    APInt::udivrem(A, B, Quo, Rem);
+    // A nonzero remainder means Quo was rounded down; step up by one.
+    return Rem == 0 ? Quo : Quo + 1;
+  }
+  }
+  llvm_unreachable("Unknown APInt::Rounding enum");
+}
+
+APInt llvm::APIntOps::RoundingSDiv(const APInt &A, const APInt &B,
+                                   APInt::Rounding RM) {
+  switch (RM) {
+  case APInt::Rounding::DOWN:
+  case APInt::Rounding::UP: {
+    APInt Quo, Rem;
+    APInt::sdivrem(A, B, Quo, Rem);
+    // An exact division needs no adjustment in any rounding mode.
+    if (Rem == 0)
+      return Quo;
+    // A == Quo * B + Rem, so the exact quotient is Quo + Rem / B. The
+    // fractional part Rem / B is negative exactly when Rem and B differ in
+    // sign: then Quo lies above the exact quotient (DOWN subtracts one),
+    // otherwise below it (UP adds one). This holds for whichever rounding
+    // mode sdivrem itself uses.
+    const bool FractionIsNegative = Rem.isNegative() != B.isNegative();
+    if (RM == APInt::Rounding::DOWN)
+      return FractionIsNegative ? Quo - 1 : Quo;
+    return FractionIsNegative ? Quo : Quo + 1;
+  }
+  // Currently sdiv rounds towards zero.
+  case APInt::Rounding::TOWARD_ZERO:
+    return A.sdiv(B);
+  }
+  // All enumerators return above. Mark the fall-through explicitly so
+  // compilers do not warn about a missing return value.
+  llvm_unreachable("Unknown APInt::Rounding enum");
+}
diff --git a/llvm/unittests/ADT/APIntTest.cpp b/llvm/unittests/ADT/APIntTest.cpp
index 06b4d77..4eb6d67 100644
--- a/llvm/unittests/ADT/APIntTest.cpp
+++ b/llvm/unittests/ADT/APIntTest.cpp
@@ -2258,4 +2258,71 @@
   EXPECT_EQ(64U, i96.countTrailingZeros());
 }
 
+// Checks RoundingUDiv for 0 / A and A / B over all nonzero 8-bit A and B.
+TEST(APIntTest, RoundingUDiv) {
+  const APInt Zero(8, 0);
+  for (uint64_t Ai = 1; Ai <= 255; ++Ai) {
+    const APInt A(8, Ai);
+    EXPECT_EQ(0, APIntOps::RoundingUDiv(Zero, A, APInt::Rounding::UP));
+    EXPECT_EQ(0, APIntOps::RoundingUDiv(Zero, A, APInt::Rounding::DOWN));
+    EXPECT_EQ(0, APIntOps::RoundingUDiv(Zero, A, APInt::Rounding::TOWARD_ZERO));
+
+    for (uint64_t Bi = 1; Bi <= 255; ++Bi) {
+      const APInt B(8, Bi);
+      const APInt WideB = B.zext(16);
+      // UP yields the least Quo with Quo * B >= A (checked in 16 bits).
+      const APInt Up = APIntOps::RoundingUDiv(A, B, APInt::Rounding::UP);
+      const APInt UpProd = Up.zext(16) * WideB;
+      EXPECT_TRUE(UpProd.uge(Ai));
+      if (UpProd.ugt(Ai)) {
+        EXPECT_TRUE(((Up - 1).zext(16) * WideB).ult(Ai));
+      }
+      // DOWN and TOWARD_ZERO must both agree with plain udiv.
+      const APInt Trunc = A.udiv(B);
+      EXPECT_EQ(Trunc, APIntOps::RoundingUDiv(A, B, APInt::Rounding::TOWARD_ZERO));
+      EXPECT_EQ(Trunc, APIntOps::RoundingUDiv(A, B, APInt::Rounding::DOWN));
+    }
+  }
+}
+
+// Checks RoundingSDiv for 0 / A and A / B over all signed 8-bit A and nonzero
+// B. The only overflowing pair, -128 / -1, divides exactly, so every rounding
+// mode returns the same wrapped value as sdiv and needs no special case.
+TEST(APIntTest, RoundingSDiv) {
+  for (int64_t Ai = -128; Ai <= 127; Ai++) {
+    APInt A(8, Ai, /*isSigned=*/true);
+
+    // 0 / A is exactly 0, whatever the rounding mode.
+    if (Ai != 0) {
+      APInt Zero(8, 0);
+      EXPECT_EQ(0, APIntOps::RoundingSDiv(Zero, A, APInt::Rounding::UP));
+      EXPECT_EQ(0, APIntOps::RoundingSDiv(Zero, A, APInt::Rounding::DOWN));
+      EXPECT_EQ(0, APIntOps::RoundingSDiv(Zero, A, APInt::Rounding::TOWARD_ZERO));
+    }
+
+    // Bi must be signed: an unsigned counter initialised with -128 wraps to a
+    // huge value, so `Bi <= 127` is false at once and nothing gets tested.
+    for (int64_t Bi = -128; Bi <= 127; Bi++) {
+      if (Bi == 0)
+        continue;
+
+      APInt B(8, Bi, /*isSigned=*/true);
+      // Reference results derived from the truncating sdiv: flooring lowers an
+      // inexact negative quotient by one, ceiling raises an inexact positive
+      // quotient by one. The signs of A and B give the sign of the exact
+      // quotient.
+      const APInt Trunc = A.sdiv(B);
+      const bool Inexact = A.srem(B) != 0;
+      const bool QuoIsNegative = A.isNegative() != B.isNegative();
+      const APInt Floor = Inexact && QuoIsNegative ? Trunc - 1 : Trunc;
+      const APInt Ceil = Inexact && !QuoIsNegative ? Trunc + 1 : Trunc;
+
+      // Compare at 8 bits: APInt comparisons require operands of equal width.
+      EXPECT_EQ(Ceil, APIntOps::RoundingSDiv(A, B, APInt::Rounding::UP));
+      EXPECT_EQ(Floor, APIntOps::RoundingSDiv(A, B, APInt::Rounding::DOWN));
+      EXPECT_EQ(Trunc, APIntOps::RoundingSDiv(A, B, APInt::Rounding::TOWARD_ZERO));
+    }
+  }
+}
+
 } // end anonymous namespace