Fix a threading bug in the info callback support for context objects.

Also add some tests for Context.set_info_callback.
diff --git a/ChangeLog b/ChangeLog
index f5f89ff..2d0f388 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,11 +1,15 @@
 2008-04-26  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
 
-	* src/ssl/context.c: Change global_passphrase_callback so that it
-	  acquires the GIL before invoking any CPython APIs and does not
-	  release it until after it is finished invoking all of them.
+	* src/ssl/context.c: Change global_passphrase_callback and
+	  global_info_callback so that they acquire the GIL before
+	  invoking any CPython APIs and do not release it until after they
+	  are finished invoking all of them (based heavily on on patch
+	  from Dan Williams).
 	* test/test_crypto.py: Add tests for load_privatekey and
 	  dump_privatekey when a passphrase or a passphrase callback is
 	  supplied.
+	* test/test_ssl.py: Add tests for Context.set_passwd_cb and
+	  Context.set_info_callback.
 
 2008-04-11  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
 
diff --git a/leakcheck/context-info-callback.py b/leakcheck/context-info-callback.py
new file mode 100644
index 0000000..d4c9fa5
--- /dev/null
+++ b/leakcheck/context-info-callback.py
@@ -0,0 +1,96 @@
+# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
+#
+# Stress tester for thread-related bugs in global_info_callback in
+# src/ssl/context.c.  In 0.7 and earlier, this will somewhat reliably
+# segfault or abort after a few dozen to a few thousand iterations on an SMP
+# machine (generally not on a UP machine) due to uses of Python/C API
+# without holding the GIL.
+
+from itertools import count
+from threading import Thread
+from socket import socket
+
+from OpenSSL.SSL import Context, TLSv1_METHOD, Connection, WantReadError
+from OpenSSL.crypto import FILETYPE_PEM, load_certificate, load_privatekey
+
+cleartextPrivateKeyPEM = (
+    "-----BEGIN RSA PRIVATE KEY-----\n"
+    "MIICXAIBAAKBgQDaemNe1syksAbFFpF3aoOrZ18vB/IQNZrAjFqXPv9iieJm7+Tc\n"
+    "g+lA/v0qmoEKrpT2xfwxXmvZwBNM4ZhyRC3DPIFEyJV7/3IA1p5iuMY/GJI1VIgn\n"
+    "aikQCnrsyxtaRpsMBeZRniaVzcUJ+XnEdFGEjlo+k0xlwfVclDEMwgpXAQIDAQAB\n"
+    "AoGBALi0a7pMQqqgnriVAdpBVJveQtxSDVWi2/gZMKVZfzNheuSnv4amhtaKPKJ+\n"
+    "CMZtHkcazsE2IFvxRN/kgato9H3gJqq8nq2CkdpdLNVKBoxiCtkLfutdY4SQLtoY\n"
+    "USN7exk131pchsAJXYlR6mCW+ZP+E523cNwpPgsyKxVbmXSBAkEA9470fy2W0jFM\n"
+    "taZFslpntKSzbvn6JmdtjtvWrM1bBaeeqFiGBuQFYg46VaCUaeRWYw02jmYAsDYh\n"
+    "ZQavmXThaQJBAOHtlAQ0IJJEiMZr6vtVPH32fmbthSv1AUSYPzKqdlQrUnOXPQXu\n"
+    "z70cFoLG1TvPF5rBxbOkbQ/s8/ka5ZjPfdkCQCeC7YsO36+UpsWnUCBzRXITh4AC\n"
+    "7eYLQ/U1KUJTVF/GrQ/5cQrQgftwgecAxi9Qfmk4xqhbp2h4e0QAmS5I9WECQH02\n"
+    "0QwrX8nxFeTytr8pFGezj4a4KVCdb2B3CL+p3f70K7RIo9d/7b6frJI6ZL/LHQf2\n"
+    "UP4pKRDkgKsVDx7MELECQGm072/Z7vmb03h/uE95IYJOgY4nfmYs0QKA9Is18wUz\n"
+    "DpjfE33p0Ha6GO1VZRIQoqE24F8o5oimy3BEjryFuw4=\n"
+    "-----END RSA PRIVATE KEY-----\n")
+
+
+cleartextCertificatePEM = (
+    "-----BEGIN CERTIFICATE-----\n"
+    "MIICfTCCAeYCAQEwDQYJKoZIhvcNAQEEBQAwgYYxCzAJBgNVBAYTAlVTMRkwFwYD\n"
+    "VQQDExBweW9wZW5zc2wuc2YubmV0MREwDwYDVQQHEwhOZXcgWW9yazESMBAGA1UE\n"
+    "ChMJUHlPcGVuU1NMMREwDwYDVQQIEwhOZXcgWW9yazEQMA4GCSqGSIb3DQEJARYB\n"
+    "IDEQMA4GA1UECxMHVGVzdGluZzAeFw0wODAzMjUxOTA0MTNaFw0wOTAzMjUxOTA0\n"
+    "MTNaMIGGMQswCQYDVQQGEwJVUzEZMBcGA1UEAxMQcHlvcGVuc3NsLnNmLm5ldDER\n"
+    "MA8GA1UEBxMITmV3IFlvcmsxEjAQBgNVBAoTCVB5T3BlblNTTDERMA8GA1UECBMI\n"
+    "TmV3IFlvcmsxEDAOBgkqhkiG9w0BCQEWASAxEDAOBgNVBAsTB1Rlc3RpbmcwgZ8w\n"
+    "DQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANp6Y17WzKSwBsUWkXdqg6tnXy8H8hA1\n"
+    "msCMWpc+/2KJ4mbv5NyD6UD+/SqagQqulPbF/DFea9nAE0zhmHJELcM8gUTIlXv/\n"
+    "cgDWnmK4xj8YkjVUiCdqKRAKeuzLG1pGmwwF5lGeJpXNxQn5ecR0UYSOWj6TTGXB\n"
+    "9VyUMQzCClcBAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAmm0Vzvv1O91WLl2LnF2P\n"
+    "q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5\n"
+    "RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q\n"
+    "UBj849/xpszEM7BhwKE0GiQ=\n"
+    "-----END CERTIFICATE-----\n")
+
+count = count()
+def go():
+    port = socket()
+    port.bind(('', 0))
+    port.listen(1)
+
+    called = []
+    def info(conn, where, ret):
+        print count.next()
+        called.append(None)
+    context = Context(TLSv1_METHOD)
+    context.set_info_callback(info)
+    context.use_certificate(
+        load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
+    context.use_privatekey(
+        load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
+
+    while 1:
+        client = socket()
+        client.setblocking(False)
+        client.connect_ex(port.getsockname())
+
+        clientSSL = Connection(Context(TLSv1_METHOD), client)
+        clientSSL.set_connect_state()
+
+        server, ignored = port.accept()
+        server.setblocking(False)
+
+        serverSSL = Connection(context, server)
+        serverSSL.set_accept_state()
+
+        del called[:]
+        while not called:
+            for ssl in clientSSL, serverSSL:
+                try:
+                    ssl.do_handshake()
+                except WantReadError:
+                    pass
+
+
+threads = [Thread(target=go, args=()) for i in xrange(2)]
+for th in threads:
+    th.start()
+for th in threads:
+    th.join()
diff --git a/src/ssl/context.c b/src/ssl/context.c
index 1f9d512..6497304 100644
--- a/src/ssl/context.c
+++ b/src/ssl/context.c
@@ -75,9 +75,6 @@
     PyObject *argv, *ret = NULL;
     ssl_ContextObj *ctx = (ssl_ContextObj *)arg;
 
-    /* The Python callback is called with a (maxlen,verify,userdata) tuple */
-    argv = Py_BuildValue("(iiO)", maxlen, verify, ctx->passphrase_userdata);
-
     /*
      * GIL isn't held yet.  First things first - acquire it, or any Python API
      * we invoke might segfault or blow up the sun.  The reverse will be done
@@ -85,6 +82,9 @@
      */
     MY_END_ALLOW_THREADS(ctx->tstate);
 
+    /* The Python callback is called with a (maxlen,verify,userdata) tuple */
+    argv = Py_BuildValue("(iiO)", maxlen, verify, ctx->passphrase_userdata);
+
     /*
      * XXX Didn't check argv to see if it was NULL. -exarkun
      */
@@ -157,19 +157,19 @@
     ssl_ConnectionObj *conn;
     crypto_X509Obj *cert;
     int errnum, errdepth, c_ret, use_thread_state;
-        
+
     // Get Connection object to check thread state
     ssl = (SSL *)X509_STORE_CTX_get_app_data(x509_ctx);
     conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl);
-    
+
     use_thread_state = conn->tstate != NULL;
     if (use_thread_state)
         MY_END_ALLOW_THREADS(conn->tstate);
-    
+
     cert = crypto_X509_New(X509_STORE_CTX_get_current_cert(x509_ctx), 0);
     errnum = X509_STORE_CTX_get_error(x509_ctx);
     errdepth = X509_STORE_CTX_get_error_depth(x509_ctx);
-    
+
     argv = Py_BuildValue("(OOiii)", (PyObject *)conn, (PyObject *)cert,
                                     errnum, errdepth, ok);
     Py_DECREF(cert);
@@ -190,7 +190,9 @@
 }
 
 /*
- * Globally defined info callback
+ * Globally defined info callback.  This is called from OpenSSL internally.
+ * The GIL will not be held when this function is invoked.  It must not be held
+ * when the function returns.
  *
  * Arguments: ssl   - The Connection
  *            where - The part of the SSL code that called us
@@ -203,28 +205,30 @@
     ssl_ConnectionObj *conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl);
     PyObject *argv, *ret;
 
+    /*
+     * GIL isn't held yet.  First things first - acquire it, or any Python API
+     * we invoke might segfault or blow up the sun.  The reverse will be done
+     * before returning.
+     */
+    MY_END_ALLOW_THREADS(conn->tstate);
+
     argv = Py_BuildValue("(Oii)", (PyObject *)conn, where, _ret);
-    if (conn->tstate != NULL)
-    {
-        /* We need to get back our thread state before calling the callback */
-        MY_END_ALLOW_THREADS(conn->tstate);
-        ret = PyEval_CallObject(conn->context->info_callback, argv);
-        if (ret == NULL)
-            PyErr_Clear();
-        else
-            Py_DECREF(ret);
-        MY_BEGIN_ALLOW_THREADS(conn->tstate);
-    }
-    else
-    {
-        ret = PyEval_CallObject(conn->context->info_callback, argv);
-        if (ret == NULL)
-            PyErr_Clear();
-        else
-            Py_DECREF(ret);
-    }
+    ret = PyEval_CallObject(conn->context->info_callback, argv);
     Py_DECREF(argv);
 
+    if (ret == NULL) {
+        /*
+         * XXX - This should be reported somehow. -exarkun
+         */
+        PyErr_Clear();
+    } else {
+        Py_DECREF(ret);
+    }
+
+    /*
+     * This function is returning into OpenSSL.  Release the GIL again.
+     */
+    MY_BEGIN_ALLOW_THREADS(conn->tstate);
     return;
 }
 
diff --git a/test/test_crypto.py b/test/test_crypto.py
index 7d46f91..318f470 100644
--- a/test/test_crypto.py
+++ b/test/test_crypto.py
@@ -31,6 +31,25 @@
     "-----END RSA PRIVATE KEY-----\n")
 
 
+cleartextCertificatePEM = (
+    "-----BEGIN CERTIFICATE-----\n"
+    "MIICfTCCAeYCAQEwDQYJKoZIhvcNAQEEBQAwgYYxCzAJBgNVBAYTAlVTMRkwFwYD\n"
+    "VQQDExBweW9wZW5zc2wuc2YubmV0MREwDwYDVQQHEwhOZXcgWW9yazESMBAGA1UE\n"
+    "ChMJUHlPcGVuU1NMMREwDwYDVQQIEwhOZXcgWW9yazEQMA4GCSqGSIb3DQEJARYB\n"
+    "IDEQMA4GA1UECxMHVGVzdGluZzAeFw0wODAzMjUxOTA0MTNaFw0wOTAzMjUxOTA0\n"
+    "MTNaMIGGMQswCQYDVQQGEwJVUzEZMBcGA1UEAxMQcHlvcGVuc3NsLnNmLm5ldDER\n"
+    "MA8GA1UEBxMITmV3IFlvcmsxEjAQBgNVBAoTCVB5T3BlblNTTDERMA8GA1UECBMI\n"
+    "TmV3IFlvcmsxEDAOBgkqhkiG9w0BCQEWASAxEDAOBgNVBAsTB1Rlc3RpbmcwgZ8w\n"
+    "DQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANp6Y17WzKSwBsUWkXdqg6tnXy8H8hA1\n"
+    "msCMWpc+/2KJ4mbv5NyD6UD+/SqagQqulPbF/DFea9nAE0zhmHJELcM8gUTIlXv/\n"
+    "cgDWnmK4xj8YkjVUiCdqKRAKeuzLG1pGmwwF5lGeJpXNxQn5ecR0UYSOWj6TTGXB\n"
+    "9VyUMQzCClcBAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAmm0Vzvv1O91WLl2LnF2P\n"
+    "q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5\n"
+    "RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q\n"
+    "UBj849/xpszEM7BhwKE0GiQ=\n"
+    "-----END CERTIFICATE-----\n")
+
+
 encryptedPrivateKeyPEM = (
     "-----BEGIN RSA PRIVATE KEY-----\n"
     "Proc-Type: 4,ENCRYPTED\n"
@@ -53,7 +72,6 @@
 encryptedPrivateKeyPEMPassphrase = "foobar"
 
 
-
 class _Python23TestCaseHelper:
     # Python 2.3 compatibility.
     def assertTrue(self, *a, **kw):
@@ -424,24 +442,7 @@
     """
     Tests for L{OpenSSL.crypto.X509}.
     """
-    pemData = """
------BEGIN CERTIFICATE-----
-MIICfTCCAeYCAQEwDQYJKoZIhvcNAQEEBQAwgYYxCzAJBgNVBAYTAlVTMRkwFwYD
-VQQDExBweW9wZW5zc2wuc2YubmV0MREwDwYDVQQHEwhOZXcgWW9yazESMBAGA1UE
-ChMJUHlPcGVuU1NMMREwDwYDVQQIEwhOZXcgWW9yazEQMA4GCSqGSIb3DQEJARYB
-IDEQMA4GA1UECxMHVGVzdGluZzAeFw0wODAzMjUxOTA0MTNaFw0wOTAzMjUxOTA0
-MTNaMIGGMQswCQYDVQQGEwJVUzEZMBcGA1UEAxMQcHlvcGVuc3NsLnNmLm5ldDER
-MA8GA1UEBxMITmV3IFlvcmsxEjAQBgNVBAoTCVB5T3BlblNTTDERMA8GA1UECBMI
-TmV3IFlvcmsxEDAOBgkqhkiG9w0BCQEWASAxEDAOBgNVBAsTB1Rlc3RpbmcwgZ8w
-DQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANp6Y17WzKSwBsUWkXdqg6tnXy8H8hA1
-msCMWpc+/2KJ4mbv5NyD6UD+/SqagQqulPbF/DFea9nAE0zhmHJELcM8gUTIlXv/
-cgDWnmK4xj8YkjVUiCdqKRAKeuzLG1pGmwwF5lGeJpXNxQn5ecR0UYSOWj6TTGXB
-9VyUMQzCClcBAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAmm0Vzvv1O91WLl2LnF2P
-q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5
-RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q
-UBj849/xpszEM7BhwKE0GiQ=
------END CERTIFICATE-----
-""" + cleartextPrivateKeyPEM
+    pemData = cleartextCertificatePEM + cleartextPrivateKeyPEM
 
     def signable(self):
         """
diff --git a/test/test_ssl.py b/test/test_ssl.py
index cb3d1b3..33264e5 100644
--- a/test/test_ssl.py
+++ b/test/test_ssl.py
@@ -6,11 +6,14 @@
 
 from unittest import TestCase
 from tempfile import mktemp
+from socket import socket
 
-from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey
-from OpenSSL.SSL import Context
+from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey, load_certificate, load_privatekey
+from OpenSSL.SSL import WantReadError, Context, Connection
 from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
 
+from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM
+
 
 class ContextTests(TestCase):
     """
@@ -69,3 +72,46 @@
         self.assertTrue(isinstance(calledWith[0][0], int))
         self.assertTrue(isinstance(calledWith[0][1], int))
         self.assertEqual(calledWith[0][2], None)
+
+
+    def test_set_info_callback(self):
+        """
+        L{Context.set_info_callback} accepts a callable which will be invoked
+        when certain information about an SSL connection is available.
+        """
+        port = socket()
+        port.bind(('', 0))
+        port.listen(1)
+
+        client = socket()
+        client.setblocking(False)
+        client.connect_ex(port.getsockname())
+
+        clientSSL = Connection(Context(TLSv1_METHOD), client)
+        clientSSL.set_connect_state()
+
+        called = []
+        def info(conn, where, ret):
+            called.append((conn, where, ret))
+        context = Context(TLSv1_METHOD)
+        context.set_info_callback(info)
+        context.use_certificate(
+            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
+        context.use_privatekey(
+            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
+
+        server, ignored = port.accept()
+        server.setblocking(False)
+
+        serverSSL = Connection(context, server)
+        serverSSL.set_accept_state()
+
+        while not called:
+            for ssl in clientSSL, serverSSL:
+                try:
+                    ssl.do_handshake()
+                except WantReadError:
+                    pass
+
+        # Kind of lame.  Just make sure it got called somehow.
+        self.assertTrue(called)