Make NativeBN's error-handling more robust.

Functions in OpenSSL/BoringSSL return some result that signals success
or failure. NativeBN was using throwExceptionIfNecessary which is almost
the same, except:

- If BoringSSL has a failure path that forgets to push an exception, it
  would break. Ideally this wouldn't happen, but there may be cases.

- If some other BoringSSL consumer forgets to clear the error queue on
  failure, it would appear as if a NativeBN operation failed.

Instead, consider the result code as the source of truth and change
throwExceptionIfNecessary to throwException. Also remove some unused
BN_GENCB parameters from JNI functions and tidy error paths up with
early returns.

Bug: 30917411
Change-Id: I6f73f67d559e5f02f62ed0d7e63b8ae3bf1e56be
Test: run cts -c org.apache.harmony.tests.java.math.BigIntegerTest
diff --git a/luni/src/main/java/java/math/BigInt.java b/luni/src/main/java/java/math/BigInt.java
index 2cffee6..5e28a73 100644
--- a/luni/src/main/java/java/math/BigInt.java
+++ b/luni/src/main/java/java/math/BigInt.java
@@ -334,11 +334,11 @@
 
     static BigInt generatePrimeDefault(int bitLength) {
         BigInt r = newBigInt();
-        NativeBN.BN_generate_prime_ex(r.bignum, bitLength, false, 0, 0, 0);
+        NativeBN.BN_generate_prime_ex(r.bignum, bitLength, false, 0, 0);
         return r;
     }
 
     boolean isPrime(int certainty) {
-        return NativeBN.BN_is_prime_ex(bignum, certainty, 0);
+        return NativeBN.BN_primality_test(bignum, certainty, false);
     }
 }
diff --git a/luni/src/main/java/java/math/NativeBN.java b/luni/src/main/java/java/math/NativeBN.java
index 64b4468..d269f2e 100644
--- a/luni/src/main/java/java/math/NativeBN.java
+++ b/luni/src/main/java/java/math/NativeBN.java
@@ -120,12 +120,15 @@
 
 
     public static native void BN_generate_prime_ex(long ret, int bits, boolean safe,
-                                                   long add, long rem, long cb);
+                                                   long add, long rem);
     // int BN_generate_prime_ex(BIGNUM *ret, int bits, int safe,
     //         const BIGNUM *add, const BIGNUM *rem, BN_GENCB *cb);
 
-    public static native boolean BN_is_prime_ex(long p, int nchecks, long cb);
-    // int BN_is_prime_ex(const BIGNUM *p, int nchecks, BN_CTX *ctx, BN_GENCB *cb);
+    public static native boolean BN_primality_test(long candidate, int checks,
+                                                   boolean do_trial_division);
+    // int BN_primality_test(int *is_probably_prime, const BIGNUM *candidate, int checks,
+    //                       BN_CTX *ctx, int do_trial_division, BN_GENCB *cb);
+    // Returns *is_probably_prime on success and throws an exception on error.
 
     public static native long getNativeFinalizer();
     // &BN_free
diff --git a/luni/src/main/native/java_math_NativeBN.cpp b/luni/src/main/native/java_math_NativeBN.cpp
index e540942..45df4c5 100644
--- a/luni/src/main/native/java_math_NativeBN.cpp
+++ b/luni/src/main/native/java_math_NativeBN.cpp
@@ -51,11 +51,18 @@
   return reinterpret_cast<BIGNUM*>(static_cast<uintptr_t>(address));
 }
 
-static bool throwExceptionIfNecessary(JNIEnv* env) {
+static void throwException(JNIEnv* env) {
   long error = ERR_get_error();
+  // OpenSSL's error queue may contain multiple errors. Clean up after them.
+  ERR_clear_error();
+
   if (error == 0) {
-    return false;
+    // An operation failed but did not push to the error queue. Throw a default
+    // exception.
+    jniThrowException(env, "java/lang/ArithmeticException", "Operation failed");
+    return;
   }
+
   char message[256];
   ERR_error_string_n(error, message, sizeof(message));
   int reason = ERR_GET_REASON(error);
@@ -68,9 +75,6 @@
   } else {
     jniThrowException(env, "java/lang/ArithmeticException", message);
   }
-  // OpenSSL's error queue may contain multiple errors. Clean up after them.
-  ERR_clear_error();
-  return true;
 }
 
 static int isValidHandle(JNIEnv* env, jlong handle, const char* message) {
@@ -102,7 +106,9 @@
 
 static jlong NativeBN_BN_new(JNIEnv* env, jclass) {
   jlong result = static_cast<jlong>(reinterpret_cast<uintptr_t>(BN_new()));
-  throwExceptionIfNecessary(env);
+  if (!result) {
+    throwException(env);
+  }
   return result;
 }
 
@@ -122,8 +128,9 @@
 
 static void NativeBN_BN_copy(JNIEnv* env, jclass, jlong to, jlong from) {
   if (!twoValidHandles(env, to, from)) return;
-  BN_copy(toBigNum(to), toBigNum(from));
-  throwExceptionIfNecessary(env);
+  if (!BN_copy(toBigNum(to), toBigNum(from))) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_putULongInt(JNIEnv* env, jclass, jlong a0, jlong java_dw, jboolean neg) {
@@ -131,28 +138,27 @@
 
   uint64_t dw = java_dw;
   BIGNUM* a = toBigNum(a0);
-  int ok;
 
   static_assert(sizeof(dw) == sizeof(BN_ULONG) ||
                 sizeof(dw) == 2*sizeof(BN_ULONG), "Unknown BN configuration");
 
   if (sizeof(dw) == sizeof(BN_ULONG)) {
-    ok = BN_set_word(a, dw);
-  } else if (sizeof(dw) == 2 * sizeof(BN_ULONG)) {
-    ok = (bn_wexpand(a, 2) != NULL);
-    if (ok) {
-      a->d[0] = dw;
-      a->d[1] = dw >> 32;
-      a->top = 2;
-      bn_correct_top(a);
+    if (!BN_set_word(a, dw)) {
+      throwException(env);
+      return;
     }
+  } else if (sizeof(dw) == 2 * sizeof(BN_ULONG)) {
+    if (!bn_wexpand(a, 2)) {
+      throwException(env);
+      return;
+    }
+    a->d[0] = dw;
+    a->d[1] = dw >> 32;
+    a->top = 2;
+    bn_correct_top(a);
   }
 
   BN_set_negative(a, neg);
-
-  if (!ok) {
-    throwExceptionIfNecessary(env);
-  }
 }
 
 static void NativeBN_putLongInt(JNIEnv* env, jclass cls, jlong a, jlong dw) {
@@ -171,7 +177,9 @@
   }
   BIGNUM* a = toBigNum(a0);
   int result = BN_dec2bn(&a, chars.c_str());
-  throwExceptionIfNecessary(env);
+  if (result == 0) {
+    throwException(env);
+  }
   return result;
 }
 
@@ -183,7 +191,9 @@
   }
   BIGNUM* a = toBigNum(a0);
   int result = BN_hex2bn(&a, chars.c_str());
-  throwExceptionIfNecessary(env);
+  if (result == 0) {
+    throwException(env);
+  }
   return result;
 }
 
@@ -193,10 +203,12 @@
   if (bytes.get() == NULL) {
     return;
   }
-  BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.get()), len, toBigNum(ret));
-  if (!throwExceptionIfNecessary(env) && neg) {
-    BN_set_negative(toBigNum(ret), true);
+  if (!BN_bin2bn(reinterpret_cast<const unsigned char*>(bytes.get()), len, toBigNum(ret))) {
+    throwException(env);
+    return;
   }
+
+  BN_set_negative(toBigNum(ret), neg);
 }
 
 /**
@@ -221,28 +233,29 @@
     const int wlen = len;
 #endif
     const unsigned int* tmpInts = reinterpret_cast<const unsigned int*>(scopedArray.get());
-    if ((tmpInts != NULL) && (bn_wexpand(ret, wlen) != NULL)) {
-#ifdef __LP64__
-      if (len % 2) {
-        ret->d[wlen - 1] = tmpInts[--len];
-      }
-      if (len > 0) {
-        for (int i = len - 2; i >= 0; i -= 2) {
-          ret->d[i/2] = ((unsigned long long)tmpInts[i+1] << 32) | tmpInts[i];
-        }
-      }
-#else
-      int i = len; do { i--; ret->d[i] = tmpInts[i]; } while (i > 0);
-#endif
-      ret->top = wlen;
-      ret->neg = neg;
-      // need to call this due to clear byte at top if avoiding
-      // having the top bit set (-ve number)
-      // Basically get rid of top zero ints:
-      bn_correct_top(ret);
-    } else {
-      throwExceptionIfNecessary(env);
+    if (!bn_wexpand(ret, wlen)) {
+      throwException(env);
+      return;
     }
+
+#ifdef __LP64__
+    if (len % 2) {
+      ret->d[wlen - 1] = tmpInts[--len];
+    }
+    if (len > 0) {
+      for (int i = len - 2; i >= 0; i -= 2) {
+        ret->d[i/2] = ((unsigned long long)tmpInts[i+1] << 32) | tmpInts[i];
+      }
+    }
+#else
+    int i = len; do { i--; ret->d[i] = tmpInts[i]; } while (i > 0);
+#endif
+    ret->top = wlen;
+    ret->neg = neg;
+    // need to call this due to clear byte at top if avoiding
+    // having the top bit set (-ve number)
+    // Basically get rid of top zero ints:
+    bn_correct_top(ret);
   } else { // (len = 0) means value = 0 and sign will be 0, too.
     ret->top = 0;
   }
@@ -257,57 +270,63 @@
 #define BYTES2ULONG(bytes, k) \
     (((bytes)[(k) + 3] & 0xff) | ((bytes)[(k) + 2] & 0xff) << 8 | ((bytes)[(k) + 1] & 0xff) << 16 | ((bytes)[(k) + 0] & 0xff) << 24)
 #endif
-static void negBigEndianBytes2bn(JNIEnv*, jclass, const unsigned char* bytes, int bytesLen, jlong ret0) {
+
+// negBigEndianBytes2bn interprets |bytes| as a little-endian two's complement negative integer and
+// sets |ret0| to the result. It returns true on success and false on allocation failure.
+static bool negBigEndianBytes2bn(JNIEnv*, jclass, const unsigned char* bytes, int bytesLen, jlong ret0) {
   BIGNUM* ret = toBigNum(ret0);
 
   bn_check_top(ret);
   // FIXME: assert bytesLen > 0
   int wLen = (bytesLen + sizeof(BN_ULONG) - 1) / sizeof(BN_ULONG);
   int firstNonzeroDigit = -2;
-  if (bn_wexpand(ret, wLen) != NULL) {
-    BN_ULONG* d = ret->d;
-    BN_ULONG di;
-    ret->top = wLen;
-    int highBytes = bytesLen % sizeof(BN_ULONG);
-    int k = bytesLen;
-    // Put bytes to the int array starting from the end of the byte array
-    int i = 0;
-    while (k > highBytes) {
-      k -= sizeof(BN_ULONG);
-      di = BYTES2ULONG(bytes, k);
-      if (di != 0) {
-        d[i] = -di;
-        firstNonzeroDigit = i;
-        i++;
-        while (k > highBytes) {
-          k -= sizeof(BN_ULONG);
-          d[i] = ~BYTES2ULONG(bytes, k);
-          i++;
-        }
-        break;
-      } else {
-        d[i] = 0;
-        i++;
-      }
-    }
-    if (highBytes != 0) {
-      di = -1;
-      // Put the first bytes in the highest element of the int array
-      if (firstNonzeroDigit != -2) {
-        for (k = 0; k < highBytes; k++) {
-          di = (di << 8) | (bytes[k] & 0xFF);
-        }
-        d[i] = ~di;
-      } else {
-        for (k = 0; k < highBytes; k++) {
-          di = (di << 8) | (bytes[k] & 0xFF);
-        }
-        d[i] = -di;
-      }
-    }
-    // The top may have superfluous zeros, so fix it.
-    bn_correct_top(ret);
+  if (!bn_wexpand(ret, wLen)) {
+    return false;
   }
+
+  BN_ULONG* d = ret->d;
+  BN_ULONG di;
+  ret->top = wLen;
+  int highBytes = bytesLen % sizeof(BN_ULONG);
+  int k = bytesLen;
+  // Put bytes to the int array starting from the end of the byte array
+  int i = 0;
+  while (k > highBytes) {
+    k -= sizeof(BN_ULONG);
+    di = BYTES2ULONG(bytes, k);
+    if (di != 0) {
+      d[i] = -di;
+      firstNonzeroDigit = i;
+      i++;
+      while (k > highBytes) {
+        k -= sizeof(BN_ULONG);
+        d[i] = ~BYTES2ULONG(bytes, k);
+        i++;
+      }
+      break;
+    } else {
+      d[i] = 0;
+      i++;
+    }
+  }
+  if (highBytes != 0) {
+    di = -1;
+    // Put the first bytes in the highest element of the int array
+    if (firstNonzeroDigit != -2) {
+      for (k = 0; k < highBytes; k++) {
+        di = (di << 8) | (bytes[k] & 0xFF);
+      }
+      d[i] = ~di;
+    } else {
+      for (k = 0; k < highBytes; k++) {
+        di = (di << 8) | (bytes[k] & 0xFF);
+      }
+      d[i] = -di;
+    }
+  }
+  // The top may have superfluous zeros, so fix it.
+  bn_correct_top(ret);
+  return true;
 }
 
 static void NativeBN_twosComp2bn(JNIEnv* env, jclass cls, jbyteArray arr, int bytesLen, jlong ret0) {
@@ -323,16 +342,21 @@
     //
     // We can use the existing BN implementation for unsigned big endian bytes:
     //
-    BN_bin2bn(s, bytesLen, ret);
+    if (!BN_bin2bn(s, bytesLen, ret)) {
+      throwException(env);
+      return;
+    }
     BN_set_negative(ret, false);
   } else { // Negative value!
     //
     // We need to apply two's complement:
     //
-    negBigEndianBytes2bn(env, cls, s, bytesLen, ret0);
+    if (!negBigEndianBytes2bn(env, cls, s, bytesLen, ret0)) {
+      throwException(env);
+      return;
+    }
     BN_set_negative(ret, true);
   }
-  throwExceptionIfNecessary(env);
 }
 
 static jlong NativeBN_longInt(JNIEnv* env, jclass, jlong a0) {
@@ -373,6 +397,7 @@
   if (!oneValidHandle(env, a)) return NULL;
   char* tmpStr = BN_bn2dec(toBigNum(a));
   if (tmpStr == NULL) {
+    throwException(env);
     return NULL;
   }
   char* retStr = leadingZerosTrimmed(tmpStr);
@@ -385,6 +410,7 @@
   if (!oneValidHandle(env, a)) return NULL;
   char* tmpStr = BN_bn2hex(toBigNum(a));
   if (tmpStr == NULL) {
+    throwException(env);
     return NULL;
   }
   char* retStr = leadingZerosTrimmed(tmpStr);
@@ -472,111 +498,135 @@
 
 static jboolean NativeBN_BN_is_bit_set(JNIEnv* env, jclass, jlong a, int n) {
   if (!oneValidHandle(env, a)) return JNI_FALSE;
-  return BN_is_bit_set(toBigNum(a), n);
+  return BN_is_bit_set(toBigNum(a), n) ? JNI_TRUE : JNI_FALSE;
 }
 
 static void NativeBN_BN_shift(JNIEnv* env, jclass, jlong r, jlong a, int n) {
   if (!twoValidHandles(env, r, a)) return;
+  int ok;
   if (n >= 0) {
-    BN_lshift(toBigNum(r), toBigNum(a), n);
+    ok = BN_lshift(toBigNum(r), toBigNum(a), n);
   } else {
-    BN_rshift(toBigNum(r), toBigNum(a), -n);
+    ok = BN_rshift(toBigNum(r), toBigNum(a), -n);
   }
-  throwExceptionIfNecessary(env);
+  if (!ok) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_add_word(JNIEnv* env, jclass, jlong a, BN_ULONG w) {
   if (!oneValidHandle(env, a)) return;
-  BN_add_word(toBigNum(a), w);
-  throwExceptionIfNecessary(env);
+  if (!BN_add_word(toBigNum(a), w)) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_mul_word(JNIEnv* env, jclass, jlong a, BN_ULONG w) {
   if (!oneValidHandle(env, a)) return;
-  BN_mul_word(toBigNum(a), w);
-  throwExceptionIfNecessary(env);
+  if (!BN_mul_word(toBigNum(a), w)) {
+    throwException(env);
+  }
 }
 
 static BN_ULONG NativeBN_BN_mod_word(JNIEnv* env, jclass, jlong a, BN_ULONG w) {
   if (!oneValidHandle(env, a)) return 0;
-  int result = BN_mod_word(toBigNum(a), w);
-  throwExceptionIfNecessary(env);
+  BN_ULONG result = BN_mod_word(toBigNum(a), w);
+  if (result == (BN_ULONG)-1) {
+    throwException(env);
+  }
   return result;
 }
 
 static void NativeBN_BN_add(JNIEnv* env, jclass, jlong r, jlong a, jlong b) {
   if (!threeValidHandles(env, r, a, b)) return;
-  BN_add(toBigNum(r), toBigNum(a), toBigNum(b));
-  throwExceptionIfNecessary(env);
+  if (!BN_add(toBigNum(r), toBigNum(a), toBigNum(b))) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_sub(JNIEnv* env, jclass, jlong r, jlong a, jlong b) {
   if (!threeValidHandles(env, r, a, b)) return;
-  BN_sub(toBigNum(r), toBigNum(a), toBigNum(b));
-  throwExceptionIfNecessary(env);
+  if (!BN_sub(toBigNum(r), toBigNum(a), toBigNum(b))) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_gcd(JNIEnv* env, jclass, jlong r, jlong a, jlong b) {
   if (!threeValidHandles(env, r, a, b)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_gcd(toBigNum(r), toBigNum(a), toBigNum(b), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_gcd(toBigNum(r), toBigNum(a), toBigNum(b), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_mul(JNIEnv* env, jclass, jlong r, jlong a, jlong b) {
   if (!threeValidHandles(env, r, a, b)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_mul(toBigNum(r), toBigNum(a), toBigNum(b), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_mul(toBigNum(r), toBigNum(a), toBigNum(b), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_exp(JNIEnv* env, jclass, jlong r, jlong a, jlong p) {
   if (!threeValidHandles(env, r, a, p)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_exp(toBigNum(r), toBigNum(a), toBigNum(p), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_exp(toBigNum(r), toBigNum(a), toBigNum(p), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_div(JNIEnv* env, jclass, jlong dv, jlong rem, jlong m, jlong d) {
   if (!fourValidHandles(env, (rem ? rem : dv), (dv ? dv : rem), m, d)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_div(toBigNum(dv), toBigNum(rem), toBigNum(m), toBigNum(d), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_div(toBigNum(dv), toBigNum(rem), toBigNum(m), toBigNum(d), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_nnmod(JNIEnv* env, jclass, jlong r, jlong a, jlong m) {
   if (!threeValidHandles(env, r, a, m)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_nnmod(toBigNum(r), toBigNum(a), toBigNum(m), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_nnmod(toBigNum(r), toBigNum(a), toBigNum(m), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_mod_exp(JNIEnv* env, jclass, jlong r, jlong a, jlong p, jlong m) {
   if (!fourValidHandles(env, r, a, p, m)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_mod_exp(toBigNum(r), toBigNum(a), toBigNum(p), toBigNum(m), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_mod_exp(toBigNum(r), toBigNum(a), toBigNum(p), toBigNum(m), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_mod_inverse(JNIEnv* env, jclass, jlong ret, jlong a, jlong n) {
   if (!threeValidHandles(env, ret, a, n)) return;
   Unique_BN_CTX ctx(BN_CTX_new());
-  BN_mod_inverse(toBigNum(ret), toBigNum(a), toBigNum(n), ctx.get());
-  throwExceptionIfNecessary(env);
+  if (!BN_mod_inverse(toBigNum(ret), toBigNum(a), toBigNum(n), ctx.get())) {
+    throwException(env);
+  }
 }
 
 static void NativeBN_BN_generate_prime_ex(JNIEnv* env, jclass, jlong ret, int bits,
-                                          jboolean safe, jlong add, jlong rem, jlong cb) {
+                                          jboolean safe, jlong add, jlong rem) {
   if (!oneValidHandle(env, ret)) return;
-  BN_generate_prime_ex(toBigNum(ret), bits, safe, toBigNum(add), toBigNum(rem),
-                       reinterpret_cast<BN_GENCB*>(cb));
-  throwExceptionIfNecessary(env);
+  if (!BN_generate_prime_ex(toBigNum(ret), bits, safe, toBigNum(add), toBigNum(rem),
+                            NULL)) {
+    throwException(env);
+  }
 }
 
-static jboolean NativeBN_BN_is_prime_ex(JNIEnv* env, jclass, jlong p, int nchecks, jlong cb) {
-  if (!oneValidHandle(env, p)) return JNI_FALSE;
+static jboolean NativeBN_BN_primality_test(JNIEnv* env, jclass, jlong candidate, int checks,
+                                           jboolean do_trial_decryption) {
+  if (!oneValidHandle(env, candidate)) return JNI_FALSE;
   Unique_BN_CTX ctx(BN_CTX_new());
-  return BN_is_prime_ex(toBigNum(p), nchecks, ctx.get(), reinterpret_cast<BN_GENCB*>(cb));
+  int is_probably_prime;
+  if (!BN_primality_test(&is_probably_prime, toBigNum(candidate), checks, ctx.get(),
+                         do_trial_decryption, NULL)) {
+    throwException(env);
+    return JNI_FALSE;
+  }
+  return is_probably_prime ? JNI_TRUE : JNI_FALSE;
 }
 
 static JNINativeMethod gMethods[] = {
@@ -593,10 +643,10 @@
    NATIVE_METHOD(NativeBN, BN_exp, "(JJJ)V"),
    NATIVE_METHOD(NativeBN, BN_free, "(J)V"),
    NATIVE_METHOD(NativeBN, BN_gcd, "(JJJ)V"),
-   NATIVE_METHOD(NativeBN, BN_generate_prime_ex, "(JIZJJJ)V"),
+   NATIVE_METHOD(NativeBN, BN_generate_prime_ex, "(JIZJJ)V"),
    NATIVE_METHOD(NativeBN, BN_hex2bn, "(JLjava/lang/String;)I"),
    NATIVE_METHOD(NativeBN, BN_is_bit_set, "(JI)Z"),
-   NATIVE_METHOD(NativeBN, BN_is_prime_ex, "(JIJ)Z"),
+   NATIVE_METHOD(NativeBN, BN_primality_test, "(JIZ)Z"),
    NATIVE_METHOD(NativeBN, BN_mod_exp, "(JJJJ)V"),
    NATIVE_METHOD(NativeBN, BN_mod_inverse, "(JJJ)V"),
    NATIVE_METHOD(NativeBN, BN_mod_word, "(JI)I"),