Fix a threading bug in passphrase callback support for context objects.

Also add a bunch of unit tests for loading and dumping private keys with passphrases.
diff --git a/ChangeLog b/ChangeLog
index e7cab13..f5f89ff 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,12 @@
+2008-04-26  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
+
+	* src/ssl/context.c: Change global_passphrase_callback so that it
+	  acquires the GIL before invoking any CPython APIs and does not
+	  release it until after it is finished invoking all of them.
+	* test/test_crypto.py: Add tests for load_privatekey and
+	  dump_privatekey when a passphrase or a passphrase callback is
+	  supplied.
+
 2008-04-11  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
 
 	* Release 0.7
@@ -102,7 +111,7 @@
 2004-07-19  Martin Sjögren  <msjogren@gmail.com>
 
 	* doc/pyOpenSSL.tex: Fix the errors regarding X509Name's field names.
-	
+
 2004-07-18  Martin Sjögren  <msjogren@gmail.com>
 
 	* examples/certgen.py: Fixed wrong attributes in doc string, thanks
diff --git a/leakcheck/context-passphrase-callback.py b/leakcheck/context-passphrase-callback.py
new file mode 100644
index 0000000..0f0933c
--- /dev/null
+++ b/leakcheck/context-passphrase-callback.py
@@ -0,0 +1,33 @@
+# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
+#
+# Stress tester for thread-related bugs in global_passphrase_callback in
+# src/ssl/context.c.  In 0.7 and earlier, this will somewhat reliably
+# segfault or abort after a few dozen to a few thousand iterations on an SMP
+# machine (generally not on a UP machine) due to uses of Python/C API
+# without holding the GIL.
+
+from itertools import count
+from threading import Thread
+
+from OpenSSL.SSL import Context, TLSv1_METHOD
+from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey
+
+k = PKey()
+k.generate_key(TYPE_RSA, 128)
+file('pkey.pem', 'w').write(dump_privatekey(FILETYPE_PEM, k, "blowfish", "foobar"))
+
+count = count()
+def go():
+    def cb(a, b, c):
+        print count.next()
+        return "foobar"
+    c = Context(TLSv1_METHOD)
+    c.set_passwd_cb(cb)
+    while 1:
+        c.use_privatekey_file('pkey.pem')
+
+threads = [Thread(target=go, args=()) for i in xrange(2)]
+for th in threads:
+    th.start()
+for th in threads:
+    th.join()
diff --git a/src/ssl/context.c b/src/ssl/context.c
index ceeb8fb..1f9d512 100644
--- a/src/ssl/context.c
+++ b/src/ssl/context.c
@@ -51,7 +51,9 @@
  */
 
 /*
- * Globally defined passphrase callback.
+ * Globally defined passphrase callback.  This is called from OpenSSL
+ * internally.  The GIL will not be held when this function is invoked.  It
+ * must not be held when the function returns.
  *
  * Arguments: buf    - Buffer to store the returned passphrase in
  *            maxlen - Maximum length of the passphrase
@@ -64,49 +66,77 @@
 static int
 global_passphrase_callback(char *buf, int maxlen, int verify, void *arg)
 {
-    int len;
+    /*
+     * Initialize len here because we're always going to return it, and we
+     * might jump to the return before it gets initialized in any other way.
+     */
+    int len = 0;
     char *str;
     PyObject *argv, *ret = NULL;
     ssl_ContextObj *ctx = (ssl_ContextObj *)arg;
 
     /* The Python callback is called with a (maxlen,verify,userdata) tuple */
     argv = Py_BuildValue("(iiO)", maxlen, verify, ctx->passphrase_userdata);
-    if (ctx->tstate != NULL)
-    {
-        /* We need to get back our thread state before calling the callback */
-        MY_END_ALLOW_THREADS(ctx->tstate);
-        ret = PyEval_CallObject(ctx->passphrase_callback, argv);
-        MY_BEGIN_ALLOW_THREADS(ctx->tstate);
-    }
-    else
-    {
-        ret = PyEval_CallObject(ctx->passphrase_callback, argv);
-    }
+
+    /*
+     * GIL isn't held yet.  First things first - acquire it, or any Python API
+     * we invoke might segfault or blow up the sun.  The reverse will be done
+     * before returning.
+     */
+    MY_END_ALLOW_THREADS(ctx->tstate);
+
+    /*
+     * XXX Didn't check argv to see if it was NULL. -exarkun
+     */
+    ret = PyEval_CallObject(ctx->passphrase_callback, argv);
     Py_DECREF(argv);
 
-    if (ret == NULL)
-        return 0;
-
-    if (!PyObject_IsTrue(ret))
-    {
-        Py_DECREF(ret);
-	return 0;
+    if (ret == NULL) {
+        /*
+         * XXX The callback raised an exception.  At the very least, it should
+         * be printed out here.  An *actual* solution would be to raise it up
+         * through OpenSSL.  That might be a bit tricky, but it's probably
+         * possible. -exarkun
+         */
+        goto out;
     }
 
-    if (!PyString_Check(ret))
-    {
+    if (!PyObject_IsTrue(ret)) {
+        /*
+         * Returned "", or None, or something.  Treat it as no passphrase.
+         */
         Py_DECREF(ret);
-        return 0;
+	goto out;
+    }
+
+    if (!PyString_Check(ret)) {
+        /*
+         * XXX Returned something that wasn't a string.  This is bogus.  We
+         * should report an error or raise an exception (again, through OpenSSL
+         * - tricky). -exarkun
+         */
+        Py_DECREF(ret);
+        goto out;
     }
 
     len = PyString_Size(ret);
-    if (len > maxlen)
+    if (len > maxlen) {
+        /*
+         * XXX Returned more than we said they were allowed to return.  Report
+         * an error or raise an exception (tricky blah blah). -exarkun
+         */
         len = maxlen;
+    }
 
     str = PyString_AsString(ret);
     strncpy(buf, str, len);
     Py_XDECREF(ret);
 
+  out:
+    /*
+     * This function is returning into OpenSSL.  Release the GIL again.
+     */
+    MY_BEGIN_ALLOW_THREADS(ctx->tstate);
     return len;
 }
 
diff --git a/test/test_crypto.py b/test/test_crypto.py
index b44f345..7d46f91 100644
--- a/test/test_crypto.py
+++ b/test/test_crypto.py
@@ -9,7 +9,50 @@
 from OpenSSL.crypto import TYPE_RSA, TYPE_DSA, Error, PKey, PKeyType
 from OpenSSL.crypto import X509, X509Type, X509Name, X509NameType
 from OpenSSL.crypto import X509Req, X509ReqType
-from OpenSSL.crypto import FILETYPE_PEM, load_certificate
+from OpenSSL.crypto import FILETYPE_PEM, load_certificate, load_privatekey
+from OpenSSL.crypto import dump_privatekey
+
+
+cleartextPrivateKeyPEM = (
+    "-----BEGIN RSA PRIVATE KEY-----\n"
+    "MIICXAIBAAKBgQDaemNe1syksAbFFpF3aoOrZ18vB/IQNZrAjFqXPv9iieJm7+Tc\n"
+    "g+lA/v0qmoEKrpT2xfwxXmvZwBNM4ZhyRC3DPIFEyJV7/3IA1p5iuMY/GJI1VIgn\n"
+    "aikQCnrsyxtaRpsMBeZRniaVzcUJ+XnEdFGEjlo+k0xlwfVclDEMwgpXAQIDAQAB\n"
+    "AoGBALi0a7pMQqqgnriVAdpBVJveQtxSDVWi2/gZMKVZfzNheuSnv4amhtaKPKJ+\n"
+    "CMZtHkcazsE2IFvxRN/kgato9H3gJqq8nq2CkdpdLNVKBoxiCtkLfutdY4SQLtoY\n"
+    "USN7exk131pchsAJXYlR6mCW+ZP+E523cNwpPgsyKxVbmXSBAkEA9470fy2W0jFM\n"
+    "taZFslpntKSzbvn6JmdtjtvWrM1bBaeeqFiGBuQFYg46VaCUaeRWYw02jmYAsDYh\n"
+    "ZQavmXThaQJBAOHtlAQ0IJJEiMZr6vtVPH32fmbthSv1AUSYPzKqdlQrUnOXPQXu\n"
+    "z70cFoLG1TvPF5rBxbOkbQ/s8/ka5ZjPfdkCQCeC7YsO36+UpsWnUCBzRXITh4AC\n"
+    "7eYLQ/U1KUJTVF/GrQ/5cQrQgftwgecAxi9Qfmk4xqhbp2h4e0QAmS5I9WECQH02\n"
+    "0QwrX8nxFeTytr8pFGezj4a4KVCdb2B3CL+p3f70K7RIo9d/7b6frJI6ZL/LHQf2\n"
+    "UP4pKRDkgKsVDx7MELECQGm072/Z7vmb03h/uE95IYJOgY4nfmYs0QKA9Is18wUz\n"
+    "DpjfE33p0Ha6GO1VZRIQoqE24F8o5oimy3BEjryFuw4=\n"
+    "-----END RSA PRIVATE KEY-----\n")
+
+
+encryptedPrivateKeyPEM = (
+    "-----BEGIN RSA PRIVATE KEY-----\n"
+    "Proc-Type: 4,ENCRYPTED\n"
+    "DEK-Info: BF-CBC,8306665233D056B1\n"
+    "\n"
+    "BwxghOcX1F+M108qRGBfpUBrfaeKOszDEV18OjEE55p0yGsiDxvdol3c4bwI5ITy\n"
+    "ltP8w9O33CDUCjr+Ymj8xLpPP60TTfr/aHq+2fEuG4TfkeHb5fVYm0mgVnaOhJs3\n"
+    "a2n5IL/KNCdP3zMZa0IaMJ0M+VK90SLpq5nzXOWkufLyZL1+n8srkk06gepmHS7L\n"
+    "rH3rALNboG8yTH1qjE8PwcMrJAQfRMd4/4RTQv+4pUuKj7I2en+YwSQ/gomy7qN1\n"
+    "3s/gMgV/2GUbEcTVch4thZ9l3WsX18V76rBQkiZ7yrJkxwNMv+Qc2GfHtBnsXAyA\n"
+    "0nIE4Mm/OQqX8h7EJ4c2s1DMGVS0YZGU+75HN0A3iD01h8C5utqSScWzBA45j/Vy\n"
+    "3aypQVqQeW7kBMQlpc6pHvJ1EsjiAJRCto7tZNLxRdjMKBV4w75JNLaAFSraqA+R\n"
+    "/WPcdcXAQuhmCeh31fzmVOHJGRF7/5pAR/b7AnFTD4YbYVcglNis/jpdiI9k2AYP\n"
+    "wZNwXOIh6Ibq5hMvyV4/pySyLbgDOrfrOGpi8N6lBbzewByYQKiXwUEZf+Y5499/\n"
+    "CckajBhgYynPpe6mgsSeklWGc845iIwAtzavBNZIkn1hKP1P+TFjbl2O75u/9JLJ\n"
+    "6i4IFYCyQmwiHX8nTR717SpCN2gyZ2HrX7z2mKP/KokkAX2yidwoKh9FMUV5lOGO\n"
+    "JPc4MfPo4lPB7SP30AtOh7y7zlS3x8Uo0+0wCg5Z5Fn/73x3W+p5nyI0G9n7RGzL\n"
+    "ZeCWLdG/Cm6ZyIpYZGbZ5m+U3Fr6/El9V6LSxrB1TB+8G1NTdLlbeA==\n"
+    "-----END RSA PRIVATE KEY-----\n")
+encryptedPrivateKeyPEMPassphrase = "foobar"
+
+
 
 class _Python23TestCaseHelper:
     # Python 2.3 compatibility.
@@ -398,22 +441,7 @@
 RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q
 UBj849/xpszEM7BhwKE0GiQ=
 -----END CERTIFICATE-----
------BEGIN RSA PRIVATE KEY-----
-MIICXAIBAAKBgQDaemNe1syksAbFFpF3aoOrZ18vB/IQNZrAjFqXPv9iieJm7+Tc
-g+lA/v0qmoEKrpT2xfwxXmvZwBNM4ZhyRC3DPIFEyJV7/3IA1p5iuMY/GJI1VIgn
-aikQCnrsyxtaRpsMBeZRniaVzcUJ+XnEdFGEjlo+k0xlwfVclDEMwgpXAQIDAQAB
-AoGBALi0a7pMQqqgnriVAdpBVJveQtxSDVWi2/gZMKVZfzNheuSnv4amhtaKPKJ+
-CMZtHkcazsE2IFvxRN/kgato9H3gJqq8nq2CkdpdLNVKBoxiCtkLfutdY4SQLtoY
-USN7exk131pchsAJXYlR6mCW+ZP+E523cNwpPgsyKxVbmXSBAkEA9470fy2W0jFM
-taZFslpntKSzbvn6JmdtjtvWrM1bBaeeqFiGBuQFYg46VaCUaeRWYw02jmYAsDYh
-ZQavmXThaQJBAOHtlAQ0IJJEiMZr6vtVPH32fmbthSv1AUSYPzKqdlQrUnOXPQXu
-z70cFoLG1TvPF5rBxbOkbQ/s8/ka5ZjPfdkCQCeC7YsO36+UpsWnUCBzRXITh4AC
-7eYLQ/U1KUJTVF/GrQ/5cQrQgftwgecAxi9Qfmk4xqhbp2h4e0QAmS5I9WECQH02
-0QwrX8nxFeTytr8pFGezj4a4KVCdb2B3CL+p3f70K7RIo9d/7b6frJI6ZL/LHQf2
-UP4pKRDkgKsVDx7MELECQGm072/Z7vmb03h/uE95IYJOgY4nfmYs0QKA9Is18wUz
-DpjfE33p0Ha6GO1VZRIQoqE24F8o5oimy3BEjryFuw4=
------END RSA PRIVATE KEY-----
-"""
+""" + cleartextPrivateKeyPEM
 
     def signable(self):
         """
@@ -534,3 +562,93 @@
         self.assertEqual(
             cert.digest("md5"),
             "A8:EB:07:F8:53:25:0A:F2:56:05:C5:A5:C4:C4:C7:15")
+
+
+
+class FunctionTests(TestCase):
+    """
+    Tests for free-functions in the L{OpenSSL.crypto} module.
+    """
+    def test_load_privatekey_wrongPassphrase(self):
+        """
+        L{load_privatekey} raises L{OpenSSL.crypto.Error} when it is passed an
+        encrypted PEM and an incorrect passphrase.
+        """
+        self.assertRaises(
+            Error,
+            load_privatekey, FILETYPE_PEM, encryptedPrivateKeyPEM, "quack")
+
+
+    def test_load_privatekey_passphrase(self):
+        """
+        L{load_privatekey} can create a L{PKey} object from an encrypted PEM
+        string if given the passphrase.
+        """
+        key = load_privatekey(
+            FILETYPE_PEM, encryptedPrivateKeyPEM,
+            encryptedPrivateKeyPEMPassphrase)
+        self.assertTrue(isinstance(key, PKeyType))
+
+
+    def test_load_privatekey_wrongPassphraseCallback(self):
+        """
+        L{load_privatekey} raises L{OpenSSL.crypto.Error} when it is passed an
+        encrypted PEM and a passphrase callback which returns an incorrect
+        passphrase.
+        """
+        called = []
+        def cb(*a):
+            called.append(None)
+            return "quack"
+        self.assertRaises(
+            Error,
+            load_privatekey, FILETYPE_PEM, encryptedPrivateKeyPEM, cb)
+        self.assertTrue(called)
+
+    def test_load_privatekey_passphraseCallback(self):
+        """
+        L{load_privatekey} can create a L{PKey} object from an encrypted PEM
+        string if given a passphrase callback which returns the correct
+        password.
+        """
+        called = []
+        def cb(writing):
+            called.append(writing)
+            return encryptedPrivateKeyPEMPassphrase
+        key = load_privatekey(FILETYPE_PEM, encryptedPrivateKeyPEM, cb)
+        self.assertTrue(isinstance(key, PKeyType))
+        self.assertEqual(called, [False])
+
+
+    def test_dump_privatekey_passphrase(self):
+        """
+        L{dump_privatekey} writes an encrypted PEM when given a passphrase.
+        """
+        passphrase = "foo"
+        key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)
+        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase)
+        self.assertTrue(isinstance(pem, str))
+        loadedKey = load_privatekey(FILETYPE_PEM, pem, passphrase)
+        self.assertTrue(isinstance(loadedKey, PKeyType))
+        self.assertEqual(loadedKey.type(), key.type())
+        self.assertEqual(loadedKey.bits(), key.bits())
+
+
+    def test_dump_privatekey_passphraseCallback(self):
+        """
+        L{dump_privatekey} writes an encrypted PEM when given a callback which
+        returns the correct passphrase.
+        """
+        passphrase = "foo"
+        called = []
+        def cb(writing):
+            called.append(writing)
+            return passphrase
+        key = load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM)
+        pem = dump_privatekey(FILETYPE_PEM, key, "blowfish", cb)
+        self.assertTrue(isinstance(pem, str))
+        self.assertEqual(called, [True])
+        loadedKey = load_privatekey(FILETYPE_PEM, pem, passphrase)
+        self.assertTrue(isinstance(loadedKey, PKeyType))
+        self.assertEqual(loadedKey.type(), key.type())
+        self.assertEqual(loadedKey.bits(), key.bits())
diff --git a/test/test_ssl.py b/test/test_ssl.py
index 762aee5..cb3d1b3 100644
--- a/test/test_ssl.py
+++ b/test/test_ssl.py
@@ -5,8 +5,9 @@
 """
 
 from unittest import TestCase
+from tempfile import mktemp
 
-from OpenSSL.crypto import TYPE_RSA, PKey
+from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey
 from OpenSSL.SSL import Context
 from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
 
@@ -15,6 +16,13 @@
     """
     Unit tests for L{OpenSSL.SSL.Context}.
     """
+    def mktemp(self):
+        """
+        Pathetic substitute for twisted.trial.unittest.TestCase.mktemp.
+        """
+        return mktemp(dir=".")
+
+
     def test_method(self):
         """
         L{Context} can be instantiated with one of L{SSLv2_METHOD},
@@ -35,3 +43,29 @@
         ctx = Context(TLSv1_METHOD)
         ctx.use_privatekey(key)
         self.assertRaises(TypeError, ctx.use_privatekey, "")
+
+
+    def test_set_passwd_cb(self):
+        """
+        L{Context.set_passwd_cb} accepts a callable which will be invoked when
+        a private key is loaded from an encrypted PEM.
+        """
+        key = PKey()
+        key.generate_key(TYPE_RSA, 128)
+        pemFile = self.mktemp()
+        fObj = file(pemFile, 'w')
+        passphrase = "foobar"
+        fObj.write(dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase))
+        fObj.close()
+
+        calledWith = []
+        def passphraseCallback(maxlen, verify, extra):
+            calledWith.append((maxlen, verify, extra))
+            return passphrase
+        context = Context(TLSv1_METHOD)
+        context.set_passwd_cb(passphraseCallback)
+        context.use_privatekey_file(pemFile)
+        self.assertTrue(len(calledWith), 1)
+        self.assertTrue(isinstance(calledWith[0][0], int))
+        self.assertTrue(isinstance(calledWith[0][1], int))
+        self.assertEqual(calledWith[0][2], None)