Fix exception propagation and error handling for the private key passphrase callback.
diff --git a/ChangeLog b/ChangeLog
index e52ad96..6d1440f 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,13 @@
+2011-09-14  Žiga Seilnacht <lp:ziga-seilnacht>
+
+	* OpenSSL/crypto/crypto.c: Allow exceptions raised by passphrase
+	  callbacks to propagate out of load_privatekey.
+	* OpenSSL/crypto/crypto.c: Raise an exception when a too-long
+	  passphrase is returned from a passphrase callback, instead of
+	  silently truncating it.
+	* OpenSSL/crypto/crypto.c: Fix a memory leak when a passphrase
+	  callback returns the wrong type.
+
 2011-09-13  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
 
 	* OpenSSL/crypto/crl.c: Add error handling for the use of
diff --git a/OpenSSL/crypto/crypto.c b/OpenSSL/crypto/crypto.c
index 1442b5a..ad35ce9 100644
--- a/OpenSSL/crypto/crypto.c
+++ b/OpenSSL/crypto/crypto.c
@@ -45,22 +45,72 @@
 
     func = (PyObject *)cb_arg;
     argv = Py_BuildValue("(i)", rwflag);
+    if (argv == NULL) {
+        return 0;
+    }
     ret = PyEval_CallObject(func, argv);
     Py_DECREF(argv);
-    if (ret == NULL)
+    if (ret == NULL) {
         return 0;
-    if (!PyBytes_Check(ret))
-    {
+    }
+    if (!PyBytes_Check(ret)) {
+        Py_DECREF(ret);
         PyErr_SetString(PyExc_ValueError, "String expected");
         return 0;
     }
     nchars = PyBytes_Size(ret);
-    if (nchars > len)
-        nchars = len;
+    if (nchars > len) {
+        Py_DECREF(ret);
+        PyErr_SetString(PyExc_ValueError,
+                        "passphrase returned by callback is too long");
+        return 0;
+    }
     strncpy(buf, PyBytes_AsString(ret), nchars);
+    Py_DECREF(ret);
     return nchars;
 }
 
+/* Report a failed OpenSSL call as a Python exception; always returns NULL. */
+static PyObject *
+raise_current_error(void)
+{
+    if (!PyErr_Occurred()) {
+        exception_from_error_queue(crypto_Error);
+        return NULL;
+    }
+
+    /* A Python exception (e.g. from the passphrase callback) is more
+     * informative than OpenSSL's queued errors: keep it, drop those. */
+    flush_error_queue();
+    return NULL;
+}
+
+static int
+setup_callback(int type, PyObject *pw, pem_password_cb **cb, void **cb_arg)
+{
+    *cb = NULL;             /* outputs are defined on every return path */
+    *cb_arg = NULL;
+    if (pw == NULL) {
+        return 1;           /* no passphrase: nothing to configure */
+    }
+    if (type != X509_FILETYPE_PEM) {
+        PyErr_SetString(PyExc_ValueError,
+                        "only FILETYPE_PEM key format supports encryption");
+        return 0;           /* 0 means a Python exception is set */
+    }
+    if (PyBytes_Check(pw)) {
+        *cb_arg = PyBytes_AsString(pw);  /* passphrase bytes as userdata */
+    } else if (PyCallable_Check(pw)) {
+        *cb = global_passphrase_callback;
+        *cb_arg = pw;
+    } else {
+        PyErr_SetString(PyExc_TypeError,
+                        "Last argument must be string or callable");
+        return 0;
+    }
+    return 1;
+}
+
 static char crypto_load_privatekey_doc[] = "\n\
 Load a private key from a buffer\n\
 \n\
@@ -85,31 +135,20 @@
     BIO *bio;
     EVP_PKEY *pkey;
 
-    if (!PyArg_ParseTuple(args, "is#|O:load_privatekey", &type, &buffer, &len, &pw))
+    if (!PyArg_ParseTuple(args, "is#|O:load_privatekey",
+                          &type, &buffer, &len, &pw)) {
         return NULL;
-
-    if (pw != NULL)
-    {
-        if (PyBytes_Check(pw))
-        {
-            cb = NULL;
-            cb_arg = PyBytes_AsString(pw);
-        }
-        else if (PyCallable_Check(pw))
-        {
-            cb = global_passphrase_callback;
-            cb_arg = pw;
-        }
-        else
-        {
-            PyErr_SetString(PyExc_TypeError, "Last argument must be string or callable");
-            return NULL;
-        }
+    }
+    if (!setup_callback(type, pw, &cb, &cb_arg)) {
+        return NULL;
     }
 
     bio = BIO_new_mem_buf(buffer, len);
-    switch (type)
-    {
+    if (bio == NULL) {
+        exception_from_error_queue(crypto_Error);
+        return NULL;
+    }
+    switch (type) {
         case X509_FILETYPE_PEM:
             pkey = PEM_read_bio_PrivateKey(bio, NULL, cb, cb_arg);
             break;
@@ -125,10 +164,8 @@
     }
     BIO_free(bio);
 
-    if (pkey == NULL)
-    {
-        exception_from_error_queue(crypto_Error);
-        return NULL;
+    if (pkey == NULL) {
+        return raise_current_error();
     }
 
     return (PyObject *)crypto_PKey_New(pkey, 1);
@@ -164,49 +201,32 @@
     crypto_PKeyObj *pkey;
 
     if (!PyArg_ParseTuple(args, "iO!|sO:dump_privatekey", &type,
-			  &crypto_PKey_Type, &pkey, &cipher_name, &pw))
+                          &crypto_PKey_Type, &pkey, &cipher_name, &pw)) {
         return NULL;
-
-    if (cipher_name != NULL && pw == NULL)
-    {
+    }
+    if (cipher_name != NULL && pw == NULL) {
         PyErr_SetString(PyExc_ValueError, "Illegal number of arguments");
         return NULL;
     }
-    if (cipher_name != NULL)
-    {
+    if (cipher_name != NULL) {
         cipher = EVP_get_cipherbyname(cipher_name);
-        if (cipher == NULL)
-        {
+        if (cipher == NULL) {
             PyErr_SetString(PyExc_ValueError, "Invalid cipher name");
             return NULL;
         }
-        if (PyBytes_Check(pw))
-        {
-            cb = NULL;
-            cb_arg = PyBytes_AsString(pw);
-        }
-        else if (PyCallable_Check(pw))
-        {
-            cb = global_passphrase_callback;
-            cb_arg = pw;
-        }
-        else
-        {
-            PyErr_SetString(PyExc_TypeError, "Last argument must be string or callable");
+        if (!setup_callback(type, pw, &cb, &cb_arg)) {
             return NULL;
         }
     }
 
     bio = BIO_new(BIO_s_mem());
-    switch (type)
-    {
+    if (bio == NULL) {
+        exception_from_error_queue(crypto_Error);
+        return NULL;
+    }
+    switch (type) {
         case X509_FILETYPE_PEM:
             ret = PEM_write_bio_PrivateKey(bio, pkey->pkey, cipher, NULL, 0, cb, cb_arg);
-            if (PyErr_Occurred())
-            {
-                BIO_free(bio);
-                return NULL;
-            }
             break;
 
         case X509_FILETYPE_ASN1:
@@ -215,8 +235,12 @@
 
         case X509_FILETYPE_TEXT:
             rsa = EVP_PKEY_get1_RSA(pkey->pkey);
+            if (rsa == NULL) {
+                ret = 0;
+                break;
+            }
             ret = RSA_print(bio, rsa, 0);
-            RSA_free(rsa); 
+            RSA_free(rsa);
             break;
 
         default:
@@ -225,11 +249,9 @@
             return NULL;
     }
 
-    if (ret == 0)
-    {
+    if (ret == 0) {
         BIO_free(bio);
-        exception_from_error_queue(crypto_Error);
-        return NULL;
+        return raise_current_error();
     }
 
     buf_len = BIO_get_mem_data(bio, &temp);
@@ -509,8 +531,8 @@
     if (!PyArg_ParseTuple(args, "is#:load_pkcs7_data", &type, &buffer, &len))
         return NULL;
 
-    /* 
-     * Try to read the pkcs7 data from the bio 
+    /*
+     * Try to read the pkcs7 data from the bio
      */
     bio = BIO_new_mem_buf(buffer, len);
     switch (type)
diff --git a/OpenSSL/test/test_crypto.py b/OpenSSL/test/test_crypto.py
index 56638e8..e0d7b27 100644
--- a/OpenSSL/test/test_crypto.py
+++ b/OpenSSL/test/test_crypto.py
@@ -7,7 +7,7 @@
 
 from unittest import main
 
-import os, re
+import os, re, sys
 from subprocess import PIPE, Popen
 from datetime import datetime, timedelta
 
@@ -1979,6 +1979,18 @@
             load_privatekey, FILETYPE_PEM, encryptedPrivateKeyPEM, b("quack"))
 
 
+    def test_load_privatekey_passphraseWrongType(self):
+        """
+        :py:obj:`load_privatekey` raises :py:obj:`ValueError` when it is passed
+        a passphrase along with a private key in a format that doesn't
+        support encryption.
+        """
+        key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)
+        blob = dump_privatekey(FILETYPE_ASN1, key)
+        self.assertRaises(ValueError,
+            load_privatekey, FILETYPE_ASN1, blob, b("secret"))
+
+
     def test_load_privatekey_passphrase(self):
         """
         :py:obj:`load_privatekey` can create a :py:obj:`PKey` object from an encrypted PEM
@@ -1990,16 +2002,28 @@
         self.assertTrue(isinstance(key, PKeyType))
 
 
+    def test_load_privatekey_passphrase_exception(self):
+        """
+        An exception raised by the passphrase callback propagates out of
+        :py:obj:`load_privatekey` instead of being replaced by ``Error``.
+        """
+        def cb(ignored):
+            raise ArithmeticError
+
+        self.assertRaises(ArithmeticError,
+            load_privatekey, FILETYPE_PEM, encryptedPrivateKeyPEM, cb)
+
+
     def test_load_privatekey_wrongPassphraseCallback(self):
         """
-        :py:obj:`load_privatekey` raises :py:obj:`OpenSSL.crypto.Error` when it is passed an
-        encrypted PEM and a passphrase callback which returns an incorrect
-        passphrase.
+        :py:obj:`load_privatekey` raises :py:obj:`OpenSSL.crypto.Error` when it
+        is passed an encrypted PEM and a passphrase callback which returns an
+        incorrect passphrase.
         """
         called = []
         def cb(*a):
             called.append(None)
-            return "quack"
+            return b("quack")
         self.assertRaises(
             Error,
             load_privatekey, FILETYPE_PEM, encryptedPrivateKeyPEM, cb)
@@ -2021,20 +2045,15 @@
         self.assertEqual(called, [False])
 
 
-    def test_load_privatekey_passphrase_exception(self):
+    def test_load_privatekey_passphrase_wrong_return_type(self):
         """
-        An exception raised by the passphrase callback passed to
-        :py:obj:`load_privatekey` causes :py:obj:`OpenSSL.crypto.Error` to be raised.
-
-        This isn't as nice as just letting the exception pass through.  The
-        behavior might be changed to that eventually.
+        :py:obj:`load_privatekey` raises :py:obj:`ValueError` if the passphrase
+        callback returns something other than a byte string.
         """
-        def broken(ignored):
-            raise RuntimeError("This is not working.")
         self.assertRaises(
-            Error,
+            ValueError,
             load_privatekey,
-            FILETYPE_PEM, encryptedPrivateKeyPEM, broken)
+            FILETYPE_PEM, encryptedPrivateKeyPEM, lambda *args: 3)
 
 
     def test_dump_privatekey_wrong_args(self):
@@ -2043,6 +2062,9 @@
         of arguments.
         """
         self.assertRaises(TypeError, dump_privatekey)
+        # If cipher name is given, password is required.
+        self.assertRaises(
+            ValueError, dump_privatekey, FILETYPE_PEM, PKey(), "foo")
 
 
     def test_dump_privatekey_unknown_cipher(self):
@@ -2079,6 +2101,18 @@
         self.assertRaises(ValueError, dump_privatekey, 100, key)
 
 
+    def test_load_privatekey_passphraseCallbackLength(self):
+        """
+        :py:obj:`crypto.load_privatekey` raises :py:obj:`ValueError` rather than
+        silently truncating when the callback's passphrase is too long.
+        """
+        def cb(ignored):
+            return b("a") * 1025  # bytes, so the length check is what fails
+
+        self.assertRaises(ValueError,
+            load_privatekey, FILETYPE_PEM, encryptedPrivateKeyPEM, cb)
+
+
     def test_dump_privatekey_passphrase(self):
         """
         :py:obj:`dump_privatekey` writes an encrypted PEM when given a passphrase.
@@ -2093,6 +2127,17 @@
         self.assertEqual(loadedKey.bits(), key.bits())
 
 
+    def test_dump_privatekey_passphraseWrongType(self):
+        """
+        :py:obj:`dump_privatekey` raises :py:obj:`ValueError` when it is passed
+        a passphrase along with a private key format that doesn't support
+        encryption.
+        """
+        key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)
+        self.assertRaises(ValueError,
+            dump_privatekey, FILETYPE_ASN1, key, "blowfish", b("secret"))
+
+
     def test_dump_certificate(self):
         """
         :py:obj:`dump_certificate` writes PEM, DER, and text.
@@ -2171,6 +2216,32 @@
         self.assertEqual(loadedKey.bits(), key.bits())
 
 
+    def test_dump_privatekey_passphrase_exception(self):
+        """
+        An exception raised by the passphrase callback propagates out of
+        :py:obj:`dump_privatekey` instead of being replaced by ``Error``.
+        """
+        def cb(ignored):
+            raise ArithmeticError
+
+        key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)
+        self.assertRaises(ArithmeticError,
+            dump_privatekey, FILETYPE_PEM, key, "blowfish", cb)
+
+
+    def test_dump_privatekey_passphraseCallbackLength(self):
+        """
+        :py:obj:`crypto.dump_privatekey` raises :py:obj:`ValueError` rather than
+        silently truncating when the callback's passphrase is too long.
+        """
+        def cb(ignored):
+            return b("a") * 1025  # bytes, so the length check is what fails
+
+        key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)
+        self.assertRaises(ValueError,
+            dump_privatekey, FILETYPE_PEM, key, "blowfish", cb)
+
+
     def test_load_pkcs7_data(self):
         """
         :py:obj:`load_pkcs7_data` accepts a PKCS#7 string and returns an instance of
diff --git a/leakcheck/crypto.py b/leakcheck/crypto.py
index 07b77e5..6a9af92 100644
--- a/leakcheck/crypto.py
+++ b/leakcheck/crypto.py
@@ -3,16 +3,21 @@
 
 import sys
 
-from OpenSSL.crypto import TYPE_DSA, Error, PKey, X509
+from OpenSSL.crypto import (
+    FILETYPE_PEM, TYPE_DSA, Error, PKey, X509, load_privatekey)
 
-class Checker_X509_get_pubkey(object):
-    """
-    Leak checks for L{X509.get_pubkey}.
-    """
+
+
+class BaseChecker(object):
     def __init__(self, iterations):
         self.iterations = iterations
 
 
+
+class Checker_X509_get_pubkey(BaseChecker):
+    """
+    Leak checks for L{X509.get_pubkey}.
+    """
     def check_exception(self):
         """
         Call the method repeatedly such that it will raise an exception.
@@ -40,6 +45,62 @@
                     cert.get_pubkey()
 
 
+
+class Checker_load_privatekey(BaseChecker):
+    """
+    Leak checks for :py:obj:`load_privatekey` with passphrase callbacks.
+    """
+    ENCRYPTED_PEM = """\
+-----BEGIN RSA PRIVATE KEY-----
+Proc-Type: 4,ENCRYPTED
+DEK-Info: BF-CBC,3763C340F9B5A1D0
+
+a/DO10mLjHLCAOG8/Hc5Lbuh3pfjvcTZiCexShP+tupkp0VxW2YbZjML8uoXrpA6
+fSPUo7cEC+r96GjV03ZIVhjmsxxesdWMpfkzXRpG8rUbWEW2KcCJWdSX8bEkuNW3
+uvAXdXZwiOrm56ANDo/48gj27GcLwnlA8ld39+ylAzkUJ1tcMVzzTjfcyd6BMFpR
+Yjg23ikseug6iWEsZQormdl0ITdYzmFpM+YYsG7kmmmi4UjCEYfb9zFaqJn+WZT2
+qXxmo2ZPFzmEVkuB46mf5GCqMwLRN2QTbIZX2+Dljj1Hfo5erf5jROewE/yzcTwO
+FCB5K3c2kkTv2KjcCAimjxkE+SBKfHg35W0wB0AWkXpVFO5W/TbHg4tqtkpt/KMn
+/MPnSxvYr/vEqYMfW4Y83c45iqK0Cyr2pwY60lcn8Kk=
+-----END RSA PRIVATE KEY-----
+"""
+    def check_load_privatekey_callback(self):
+        """
+        Repeatedly load ENCRYPTED_PEM using a passphrase callback.
+        """
+        for i in xrange(self.iterations * 10):
+            load_privatekey(
+                FILETYPE_PEM, self.ENCRYPTED_PEM, lambda *args: "hello, secret")
+
+
+    def check_load_privatekey_callback_incorrect(self):
+        """
+        Repeatedly load ENCRYPTED_PEM with a callback returning the wrong
+        passphrase, ignoring the resulting :py:obj:`Error`.
+        """
+        for i in xrange(self.iterations * 10):
+            try:
+                load_privatekey(
+                    FILETYPE_PEM, self.ENCRYPTED_PEM,
+                    lambda *args: "hello, public")
+            except Error:
+                pass
+
+
+    def check_load_privatekey_callback_wrong_type(self):
+        """
+        Repeatedly load ENCRYPTED_PEM with a callback returning a non-string,
+        ignoring the resulting :py:obj:`ValueError`.
+        """
+        for i in xrange(self.iterations * 10):
+            try:
+                load_privatekey(
+                    FILETYPE_PEM, self.ENCRYPTED_PEM,
+                    lambda *args: {})
+            except ValueError:
+                pass
+
+
 def vmsize():
     return [x for x in file('/proc/self/status').readlines() if 'VmSize' in x]