SF patch 936813: fast modular exponentiation

This checkin is adapted from part 2 (of 3) of Trevor Perrin's patch set.

BACKWARD INCOMPATIBILITY:  SHIFT must now be divisible by 5.  AFAIK,
nobody will care.  If it ever becomes necessary, long_pow() could be
made more complicated to work around that.

long_pow():
  - BUGFIX:  This leaked the base and power when the power was negative
    (and so the computation delegated to float pow).
  - Instead of doing right-to-left exponentiation, do left-to-right.  This
    is more efficient for small bases, which is the common case.
  - In addition, if the exponent is large (more than FIVEARY_CUTOFF
    digits), precompute [a**i % c for i in range(32)], and go left to
    right 5 bits at a time.
l_divmod():
  - The signature changed so that callers who don't want the quotient,
    or don't want the remainder, can pass NULL in the slot they don't
    want.  This saves them from having to declare a variable for the
    unwanted result and from having to remember to decref it.
long_mod(), long_div(), long_classic_div():
  - Adjust to new l_divmod() signature, and simplified as a result.
diff --git a/Objects/longobject.c b/Objects/longobject.c
index 2f6d103..05c42ad 100644
--- a/Objects/longobject.c
+++ b/Objects/longobject.c
@@ -15,6 +15,13 @@
 #define KARATSUBA_CUTOFF 70
 #define KARATSUBA_SQUARE_CUTOFF (2 * KARATSUBA_CUTOFF)
 
+/* For exponentiation, use the binary left-to-right algorithm unless
+ * the exponent has more than FIVEARY_CUTOFF digits; above that, consume
+ * the exponent 5 bits at a time.  The price is a table of 2**5 precomputed
+ * powers, and the 5-ary loop requires SHIFT to be a multiple of 5.
+ */
+#define FIVEARY_CUTOFF 8
+
 #define ABS(x) ((x) < 0 ? -(x) : (x))
 
 #undef MIN
@@ -2136,6 +2143,12 @@
    have different signs.  We then subtract one from the 'div'
    part of the outcome to keep the invariant intact. */
 
+/* Compute
+ *     *pdiv, *pmod = divmod(v, w)
+ * Either pdiv or pmod may be NULL, in which case that part of the
+ * result is discarded.  When 0 is returned, the caller owns a new
+ * reference to each result it requested (each non-NULL slot).
+ */
 static int
 l_divmod(PyLongObject *v, PyLongObject *w,
 	 PyLongObject **pdiv, PyLongObject **pmod)
@@ -2167,44 +2180,43 @@
 		Py_DECREF(div);
 		div = temp;
 	}
-	*pdiv = div;
-	*pmod = mod;
+	if (pdiv != NULL)
+		*pdiv = div;	/* caller now owns this reference */
+	else
+		Py_DECREF(div);	/* quotient not wanted */
+
+	if (pmod != NULL)
+		*pmod = mod;	/* caller now owns this reference */
+	else
+		Py_DECREF(mod);	/* remainder not wanted */
+
 	return 0;
 }
 
 static PyObject *
 long_div(PyObject *v, PyObject *w)
 {
-	PyLongObject *a, *b, *div, *mod;
+	PyLongObject *a, *b, *div;
 
 	CONVERT_BINOP(v, w, &a, &b);
-
-	if (l_divmod(a, b, &div, &mod) < 0) {
-		Py_DECREF(a);
-		Py_DECREF(b);
-		return NULL;
-	}
+	if (l_divmod(a, b, &div, NULL) < 0)	/* quotient only */
+		div = NULL;	/* error: a and b are still released below */
 	Py_DECREF(a);
 	Py_DECREF(b);
-	Py_DECREF(mod);
 	return (PyObject *)div;
 }
 
 static PyObject *
 long_classic_div(PyObject *v, PyObject *w)
 {
-	PyLongObject *a, *b, *div, *mod;
+	PyLongObject *a, *b, *div;
 
 	CONVERT_BINOP(v, w, &a, &b);
-
 	if (Py_DivisionWarningFlag &&
 	    PyErr_Warn(PyExc_DeprecationWarning, "classic long division") < 0)
 		div = NULL;
-	else if (l_divmod(a, b, &div, &mod) < 0)
+	else if (l_divmod(a, b, &div, NULL) < 0)	/* quotient only */
 		div = NULL;
-	else
-		Py_DECREF(mod);
-
 	Py_DECREF(a);
 	Py_DECREF(b);
 	return (PyObject *)div;
@@ -2255,18 +2267,14 @@
 static PyObject *
 long_mod(PyObject *v, PyObject *w)
 {
-	PyLongObject *a, *b, *div, *mod;
+	PyLongObject *a, *b, *mod;
 
 	CONVERT_BINOP(v, w, &a, &b);
 
-	if (l_divmod(a, b, &div, &mod) < 0) {
-		Py_DECREF(a);
-		Py_DECREF(b);
-		return NULL;
-	}
+	if (l_divmod(a, b, NULL, &mod) < 0)	/* remainder only */
+		mod = NULL;	/* error: a and b are still released below */
 	Py_DECREF(a);
 	Py_DECREF(b);
-	Py_DECREF(div);
 	return (PyObject *)mod;
 }
 
@@ -2297,22 +2305,36 @@
 	return z;
 }
 
+/* pow(v, w, x) */
 static PyObject *
 long_pow(PyObject *v, PyObject *w, PyObject *x)
 {
-	PyLongObject *a, *b;
-	PyObject *c;
-	PyLongObject *z, *div, *mod;
-	int size_b, i;
+	PyLongObject *a, *b, *c; /* a,b,c = v,w,x */
+	int negativeOutput = 0;  /* if x<0 return negative output */

+	PyLongObject *z = NULL;  /* accumulated result */
+	int i, j, k;             /* counters */
+	PyLongObject *temp = NULL;
+
+	/* 5-ary values.  If the exponent is large enough, table is
+	 * precomputed so that table[i] == a**i % c for i in range(32).
+	 */
+	PyLongObject *table[32] = {0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,
+				   0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0};
+
+	/* a, b, c = v, w, x */
 	CONVERT_BINOP(v, w, &a, &b);
-	if (PyLong_Check(x) || Py_None == x) {
-		c = x;
+	if (PyLong_Check(x)) {
+		c = (PyLongObject *)x;
 		Py_INCREF(x);
 	}
-	else if (PyInt_Check(x)) {
-		c = PyLong_FromLong(PyInt_AS_LONG(x));
-	}
+	else if (PyInt_Check(x)) {
+		c = (PyLongObject *)PyLong_FromLong(PyInt_AS_LONG(x));
+		if (c == NULL)	/* must not be mistaken for x == None */
+			goto Error;
+	}
+	else if (x == Py_None)
+		c = NULL;
 	else {
 		Py_DECREF(a);
 		Py_DECREF(b);
@@ -2320,95 +2342,154 @@
 		return Py_NotImplemented;
 	}
 
-	if (c != Py_None && ((PyLongObject *)c)->ob_size == 0) {
-		PyErr_SetString(PyExc_ValueError,
-				"pow() 3rd argument cannot be 0");
-		z = NULL;
-		goto error;
-	}
-
-	size_b = b->ob_size;
-	if (size_b < 0) {
-		Py_DECREF(a);
-		Py_DECREF(b);
-		Py_DECREF(c);
-		if (x != Py_None) {
+	if (b->ob_size < 0) {  /* if exponent is negative */
+		if (c) {
 			PyErr_SetString(PyExc_TypeError, "pow() 2nd argument "
-			     "cannot be negative when 3rd argument specified");
-			return NULL;
+			    "cannot be negative when 3rd argument specified");
+			goto Error;	/* Done releases a, b and c */
 		}
-		/* Return a float.  This works because we know that
-		   this calls float_pow() which converts its
-		   arguments to double. */
-		return PyFloat_Type.tp_as_number->nb_power(v, w, x);
-	}
-	z = (PyLongObject *)PyLong_FromLong(1L);
-	for (i = 0; i < size_b; ++i) {
-		digit bi = b->ob_digit[i];
-		int j;
-
-		for (j = 0; j < SHIFT; ++j) {
-			PyLongObject *temp;
-
-			if (bi & 1) {
-				temp = (PyLongObject *)long_mul(z, a);
-				Py_DECREF(z);
-			 	if (c!=Py_None && temp!=NULL) {
-			 		if (l_divmod(temp,(PyLongObject *)c,
-							&div,&mod) < 0) {
-						Py_DECREF(temp);
-						z = NULL;
-						goto error;
-					}
-				 	Py_XDECREF(div);
-				 	Py_DECREF(temp);
-				 	temp = mod;
-				}
-			 	z = temp;
-				if (z == NULL)
-					break;
-			}
-			bi >>= 1;
-			if (bi == 0 && i+1 == size_b)
-				break;
-			temp = (PyLongObject *)long_mul(a, a);
-			Py_DECREF(a);
-		 	if (c!=Py_None && temp!=NULL) {
-			 	if (l_divmod(temp, (PyLongObject *)c, &div,
-							&mod) < 0) {
-					Py_DECREF(temp);
-					z = NULL;
-					goto error;
-				}
-			 	Py_XDECREF(div);
-			 	Py_DECREF(temp);
-			 	temp = mod;
-			}
-			a = temp;
-			if (a == NULL) {
-				Py_DECREF(z);
-				z = NULL;
-				break;
-			}
-		}
-		if (a == NULL || z == NULL)
-			break;
-	}
-	if (c!=Py_None && z!=NULL) {
-		if (l_divmod(z, (PyLongObject *)c, &div, &mod) < 0) {
-			Py_DECREF(z);
-			z = NULL;
-		}
 		else {
-			Py_XDECREF(div);
-			Py_DECREF(z);
-			z = mod;
+			/* else return a float.  This works because we know
+			   that this calls float_pow() which converts its
+			   arguments to double. */
+			Py_DECREF(a);
+			Py_DECREF(b);
+			return PyFloat_Type.tp_as_number->nb_power(v, w, x);
 		}
 	}
-  error:
+
+	if (c) {
+		/* if modulus == 0:
+		       raise ValueError() */
+		if (c->ob_size == 0) {
+			PyErr_SetString(PyExc_ValueError,
+					"pow() 3rd argument cannot be 0");
+			goto Done;
+		}
+
+		/* if modulus < 0:
+		       negativeOutput = True
+		       modulus = -modulus */
+		if (c->ob_size < 0) {
+			negativeOutput = 1;
+			temp = (PyLongObject *)_PyLong_Copy(c);
+			if (temp == NULL)
+				goto Error;
+			Py_DECREF(c);
+			c = temp;
+			temp = NULL;
+			c->ob_size = - c->ob_size;
+		}
+
+		/* if modulus == 1:
+		       return 0 */
+		if ((c->ob_size == 1) && (c->ob_digit[0] == 1)) {
+			z = (PyLongObject *)PyLong_FromLong(0L);
+			goto Done;
+		}
+
+		/* if base < 0:
+		       base = base % modulus
+		   Having the base positive just makes things easier. */
+		if (a->ob_size < 0) {
+			if (l_divmod(a, c, NULL, &temp) < 0)
+				goto Done;
+			Py_DECREF(a);
+			a = temp;
+			temp = NULL;
+		}
+	}
+
+	/* At this point a, b, and c are guaranteed non-negative UNLESS
+	   c is NULL, in which case a may be negative. */
+
+	z = (PyLongObject *)PyLong_FromLong(1L);
+	if (z == NULL)
+		goto Error;
+
+	/* Perform a modular reduction, X = X % c, but leave X alone if c
+	 * is NULL.
+	 */
+#define REDUCE(X)					\
+	if (c != NULL) {				\
+		if (l_divmod(X, c, NULL, &temp) < 0)	\
+			goto Error;			\
+		Py_XDECREF(X);				\
+		X = temp;				\
+		temp = NULL;				\
+	}
+
+	/* Multiply two values, then reduce the result:
+	   result = X*Y % c.  If c is NULL, skip the mod. */
+#define MULT(X, Y, result)				\
+{							\
+	temp = (PyLongObject *)long_mul(X, Y);		\
+	if (temp == NULL)				\
+		goto Error;				\
+	Py_XDECREF(result);				\
+	result = temp;					\
+	temp = NULL;					\
+	REDUCE(result)					\
+}
+
+	if (b->ob_size <= FIVEARY_CUTOFF) {
+		/* Left-to-right binary exponentiation (HAC Algorithm 14.79) */
+		/* http://www.cacr.math.uwaterloo.ca/hac/about/chap14.pdf    */
+		for (i = b->ob_size - 1; i >= 0; --i) {
+			digit bi = b->ob_digit[i];
+
+			for (j = 1 << (SHIFT-1); j != 0; j >>= 1) {
+				MULT(z, z, z)
+				if (bi & j)
+					MULT(z, a, z)
+			}
+		}
+	}
+	else {
+		/* Left-to-right 5-ary exponentiation (HAC Algorithm 14.82) */
+		Py_INCREF(z);	/* still holds 1L */
+		table[0] = z;
+		for (i = 1; i < 32; ++i)
+			MULT(table[i-1], a, table[i])
+
+		for (i = b->ob_size - 1; i >= 0; --i) {
+			const digit bi = b->ob_digit[i];
+
+			for (j = SHIFT - 5; j >= 0; j -= 5) {
+				const int index = (bi >> j) & 0x1f;
+				for (k = 0; k < 5; ++k)
+					MULT(z, z, z)
+				if (index)
+					MULT(z, table[index], z)
+			}
+		}
+	}
+
+	if (negativeOutput && (z->ob_size != 0)) {
+		temp = (PyLongObject *)long_sub(z, c);
+		if (temp == NULL)
+			goto Error;
+		Py_DECREF(z);
+		z = temp;
+		temp = NULL;
+	}
+	goto Done;
+
+ Error:
+	if (z != NULL) {
+		Py_DECREF(z);
+		z = NULL;
+	}
+	/* fall through */
+ Done:
 	Py_XDECREF(a);
-	Py_DECREF(b);
-	Py_DECREF(c);
+	/* table[] is all NULL unless the 5-ary path filled it, so release
+	   it first and unconditionally: b must not be read once dropped. */
+	for (i = 0; i < 32; ++i)
+		Py_XDECREF(table[i]);
+	Py_XDECREF(b);
+	Py_XDECREF(c);
+	Py_XDECREF(temp);
 	return (PyObject *)z;
 }