Fix a threading bug in the info callback support for context objects.
Also add some tests for Context.set_info_callback.
diff --git a/ChangeLog b/ChangeLog
index f5f89ff..2d0f388 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,11 +1,15 @@
2008-04-26 Jean-Paul Calderone <exarkun@twistedmatrix.com>
- * src/ssl/context.c: Change global_passphrase_callback so that it
- acquires the GIL before invoking any CPython APIs and does not
- release it until after it is finished invoking all of them.
+ * src/ssl/context.c: Change global_passphrase_callback and
+ global_info_callback so that they acquire the GIL before
+ invoking any CPython APIs and do not release it until after they
+ are finished invoking all of them (based heavily on on patch
+ from Dan Williams).
* test/test_crypto.py: Add tests for load_privatekey and
dump_privatekey when a passphrase or a passphrase callback is
supplied.
+ * test/test_ssl.py: Add tests for Context.set_passwd_cb and
+ Context.set_info_callback.
2008-04-11 Jean-Paul Calderone <exarkun@twistedmatrix.com>
diff --git a/leakcheck/context-info-callback.py b/leakcheck/context-info-callback.py
new file mode 100644
index 0000000..d4c9fa5
--- /dev/null
+++ b/leakcheck/context-info-callback.py
@@ -0,0 +1,96 @@
+# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
+#
+# Stress tester for thread-related bugs in global_info_callback in
+# src/ssl/context.c. In 0.7 and earlier, this will somewhat reliably
+# segfault or abort after a few dozen to a few thousand iterations on an SMP
+# machine (generally not on a UP machine) due to uses of Python/C API
+# without holding the GIL.
+
+from itertools import count
+from threading import Thread
+from socket import socket
+
+from OpenSSL.SSL import Context, TLSv1_METHOD, Connection, WantReadError
+from OpenSSL.crypto import FILETYPE_PEM, load_certificate, load_privatekey
+
+cleartextPrivateKeyPEM = (
+ "-----BEGIN RSA PRIVATE KEY-----\n"
+ "MIICXAIBAAKBgQDaemNe1syksAbFFpF3aoOrZ18vB/IQNZrAjFqXPv9iieJm7+Tc\n"
+ "g+lA/v0qmoEKrpT2xfwxXmvZwBNM4ZhyRC3DPIFEyJV7/3IA1p5iuMY/GJI1VIgn\n"
+ "aikQCnrsyxtaRpsMBeZRniaVzcUJ+XnEdFGEjlo+k0xlwfVclDEMwgpXAQIDAQAB\n"
+ "AoGBALi0a7pMQqqgnriVAdpBVJveQtxSDVWi2/gZMKVZfzNheuSnv4amhtaKPKJ+\n"
+ "CMZtHkcazsE2IFvxRN/kgato9H3gJqq8nq2CkdpdLNVKBoxiCtkLfutdY4SQLtoY\n"
+ "USN7exk131pchsAJXYlR6mCW+ZP+E523cNwpPgsyKxVbmXSBAkEA9470fy2W0jFM\n"
+ "taZFslpntKSzbvn6JmdtjtvWrM1bBaeeqFiGBuQFYg46VaCUaeRWYw02jmYAsDYh\n"
+ "ZQavmXThaQJBAOHtlAQ0IJJEiMZr6vtVPH32fmbthSv1AUSYPzKqdlQrUnOXPQXu\n"
+ "z70cFoLG1TvPF5rBxbOkbQ/s8/ka5ZjPfdkCQCeC7YsO36+UpsWnUCBzRXITh4AC\n"
+ "7eYLQ/U1KUJTVF/GrQ/5cQrQgftwgecAxi9Qfmk4xqhbp2h4e0QAmS5I9WECQH02\n"
+ "0QwrX8nxFeTytr8pFGezj4a4KVCdb2B3CL+p3f70K7RIo9d/7b6frJI6ZL/LHQf2\n"
+ "UP4pKRDkgKsVDx7MELECQGm072/Z7vmb03h/uE95IYJOgY4nfmYs0QKA9Is18wUz\n"
+ "DpjfE33p0Ha6GO1VZRIQoqE24F8o5oimy3BEjryFuw4=\n"
+ "-----END RSA PRIVATE KEY-----\n")
+
+
+cleartextCertificatePEM = (
+ "-----BEGIN CERTIFICATE-----\n"
+ "MIICfTCCAeYCAQEwDQYJKoZIhvcNAQEEBQAwgYYxCzAJBgNVBAYTAlVTMRkwFwYD\n"
+ "VQQDExBweW9wZW5zc2wuc2YubmV0MREwDwYDVQQHEwhOZXcgWW9yazESMBAGA1UE\n"
+ "ChMJUHlPcGVuU1NMMREwDwYDVQQIEwhOZXcgWW9yazEQMA4GCSqGSIb3DQEJARYB\n"
+ "IDEQMA4GA1UECxMHVGVzdGluZzAeFw0wODAzMjUxOTA0MTNaFw0wOTAzMjUxOTA0\n"
+ "MTNaMIGGMQswCQYDVQQGEwJVUzEZMBcGA1UEAxMQcHlvcGVuc3NsLnNmLm5ldDER\n"
+ "MA8GA1UEBxMITmV3IFlvcmsxEjAQBgNVBAoTCVB5T3BlblNTTDERMA8GA1UECBMI\n"
+ "TmV3IFlvcmsxEDAOBgkqhkiG9w0BCQEWASAxEDAOBgNVBAsTB1Rlc3RpbmcwgZ8w\n"
+ "DQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANp6Y17WzKSwBsUWkXdqg6tnXy8H8hA1\n"
+ "msCMWpc+/2KJ4mbv5NyD6UD+/SqagQqulPbF/DFea9nAE0zhmHJELcM8gUTIlXv/\n"
+ "cgDWnmK4xj8YkjVUiCdqKRAKeuzLG1pGmwwF5lGeJpXNxQn5ecR0UYSOWj6TTGXB\n"
+ "9VyUMQzCClcBAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAmm0Vzvv1O91WLl2LnF2P\n"
+ "q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5\n"
+ "RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q\n"
+ "UBj849/xpszEM7BhwKE0GiQ=\n"
+ "-----END CERTIFICATE-----\n")
+
+count = count()
+def go():
+ port = socket()
+ port.bind(('', 0))
+ port.listen(1)
+
+ called = []
+ def info(conn, where, ret):
+ print count.next()
+ called.append(None)
+ context = Context(TLSv1_METHOD)
+ context.set_info_callback(info)
+ context.use_certificate(
+ load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
+ context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
+
+ while 1:
+ client = socket()
+ client.setblocking(False)
+ client.connect_ex(port.getsockname())
+
+ clientSSL = Connection(Context(TLSv1_METHOD), client)
+ clientSSL.set_connect_state()
+
+ server, ignored = port.accept()
+ server.setblocking(False)
+
+ serverSSL = Connection(context, server)
+ serverSSL.set_accept_state()
+
+ del called[:]
+ while not called:
+ for ssl in clientSSL, serverSSL:
+ try:
+ ssl.do_handshake()
+ except WantReadError:
+ pass
+
+
+threads = [Thread(target=go, args=()) for i in xrange(2)]
+for th in threads:
+ th.start()
+for th in threads:
+ th.join()
diff --git a/src/ssl/context.c b/src/ssl/context.c
index 1f9d512..6497304 100644
--- a/src/ssl/context.c
+++ b/src/ssl/context.c
@@ -75,9 +75,6 @@
PyObject *argv, *ret = NULL;
ssl_ContextObj *ctx = (ssl_ContextObj *)arg;
- /* The Python callback is called with a (maxlen,verify,userdata) tuple */
- argv = Py_BuildValue("(iiO)", maxlen, verify, ctx->passphrase_userdata);
-
/*
* GIL isn't held yet. First things first - acquire it, or any Python API
* we invoke might segfault or blow up the sun. The reverse will be done
@@ -85,6 +82,9 @@
*/
MY_END_ALLOW_THREADS(ctx->tstate);
+ /* The Python callback is called with a (maxlen,verify,userdata) tuple */
+ argv = Py_BuildValue("(iiO)", maxlen, verify, ctx->passphrase_userdata);
+
/*
* XXX Didn't check argv to see if it was NULL. -exarkun
*/
@@ -157,19 +157,19 @@
ssl_ConnectionObj *conn;
crypto_X509Obj *cert;
int errnum, errdepth, c_ret, use_thread_state;
-
+
// Get Connection object to check thread state
ssl = (SSL *)X509_STORE_CTX_get_app_data(x509_ctx);
conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl);
-
+
use_thread_state = conn->tstate != NULL;
if (use_thread_state)
MY_END_ALLOW_THREADS(conn->tstate);
-
+
cert = crypto_X509_New(X509_STORE_CTX_get_current_cert(x509_ctx), 0);
errnum = X509_STORE_CTX_get_error(x509_ctx);
errdepth = X509_STORE_CTX_get_error_depth(x509_ctx);
-
+
argv = Py_BuildValue("(OOiii)", (PyObject *)conn, (PyObject *)cert,
errnum, errdepth, ok);
Py_DECREF(cert);
@@ -190,7 +190,9 @@
}
/*
- * Globally defined info callback
+ * Globally defined info callback. This is called from OpenSSL internally.
+ * The GIL will not be held when this function is invoked. It must not be held
+ * when the function returns.
*
* Arguments: ssl - The Connection
* where - The part of the SSL code that called us
@@ -203,28 +205,30 @@
ssl_ConnectionObj *conn = (ssl_ConnectionObj *)SSL_get_app_data(ssl);
PyObject *argv, *ret;
+ /*
+ * GIL isn't held yet. First things first - acquire it, or any Python API
+ * we invoke might segfault or blow up the sun. The reverse will be done
+ * before returning.
+ */
+ MY_END_ALLOW_THREADS(conn->tstate);
+
argv = Py_BuildValue("(Oii)", (PyObject *)conn, where, _ret);
- if (conn->tstate != NULL)
- {
- /* We need to get back our thread state before calling the callback */
- MY_END_ALLOW_THREADS(conn->tstate);
- ret = PyEval_CallObject(conn->context->info_callback, argv);
- if (ret == NULL)
- PyErr_Clear();
- else
- Py_DECREF(ret);
- MY_BEGIN_ALLOW_THREADS(conn->tstate);
- }
- else
- {
- ret = PyEval_CallObject(conn->context->info_callback, argv);
- if (ret == NULL)
- PyErr_Clear();
- else
- Py_DECREF(ret);
- }
+ ret = PyEval_CallObject(conn->context->info_callback, argv);
Py_DECREF(argv);
+ if (ret == NULL) {
+ /*
+ * XXX - This should be reported somehow. -exarkun
+ */
+ PyErr_Clear();
+ } else {
+ Py_DECREF(ret);
+ }
+
+ /*
+ * This function is returning into OpenSSL. Release the GIL again.
+ */
+ MY_BEGIN_ALLOW_THREADS(conn->tstate);
return;
}
diff --git a/test/test_crypto.py b/test/test_crypto.py
index 7d46f91..318f470 100644
--- a/test/test_crypto.py
+++ b/test/test_crypto.py
@@ -31,6 +31,25 @@
"-----END RSA PRIVATE KEY-----\n")
+cleartextCertificatePEM = (
+ "-----BEGIN CERTIFICATE-----\n"
+ "MIICfTCCAeYCAQEwDQYJKoZIhvcNAQEEBQAwgYYxCzAJBgNVBAYTAlVTMRkwFwYD\n"
+ "VQQDExBweW9wZW5zc2wuc2YubmV0MREwDwYDVQQHEwhOZXcgWW9yazESMBAGA1UE\n"
+ "ChMJUHlPcGVuU1NMMREwDwYDVQQIEwhOZXcgWW9yazEQMA4GCSqGSIb3DQEJARYB\n"
+ "IDEQMA4GA1UECxMHVGVzdGluZzAeFw0wODAzMjUxOTA0MTNaFw0wOTAzMjUxOTA0\n"
+ "MTNaMIGGMQswCQYDVQQGEwJVUzEZMBcGA1UEAxMQcHlvcGVuc3NsLnNmLm5ldDER\n"
+ "MA8GA1UEBxMITmV3IFlvcmsxEjAQBgNVBAoTCVB5T3BlblNTTDERMA8GA1UECBMI\n"
+ "TmV3IFlvcmsxEDAOBgkqhkiG9w0BCQEWASAxEDAOBgNVBAsTB1Rlc3RpbmcwgZ8w\n"
+ "DQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANp6Y17WzKSwBsUWkXdqg6tnXy8H8hA1\n"
+ "msCMWpc+/2KJ4mbv5NyD6UD+/SqagQqulPbF/DFea9nAE0zhmHJELcM8gUTIlXv/\n"
+ "cgDWnmK4xj8YkjVUiCdqKRAKeuzLG1pGmwwF5lGeJpXNxQn5ecR0UYSOWj6TTGXB\n"
+ "9VyUMQzCClcBAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAmm0Vzvv1O91WLl2LnF2P\n"
+ "q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5\n"
+ "RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q\n"
+ "UBj849/xpszEM7BhwKE0GiQ=\n"
+ "-----END CERTIFICATE-----\n")
+
+
encryptedPrivateKeyPEM = (
"-----BEGIN RSA PRIVATE KEY-----\n"
"Proc-Type: 4,ENCRYPTED\n"
@@ -53,7 +72,6 @@
encryptedPrivateKeyPEMPassphrase = "foobar"
-
class _Python23TestCaseHelper:
# Python 2.3 compatibility.
def assertTrue(self, *a, **kw):
@@ -424,24 +442,7 @@
"""
Tests for L{OpenSSL.crypto.X509}.
"""
- pemData = """
------BEGIN CERTIFICATE-----
-MIICfTCCAeYCAQEwDQYJKoZIhvcNAQEEBQAwgYYxCzAJBgNVBAYTAlVTMRkwFwYD
-VQQDExBweW9wZW5zc2wuc2YubmV0MREwDwYDVQQHEwhOZXcgWW9yazESMBAGA1UE
-ChMJUHlPcGVuU1NMMREwDwYDVQQIEwhOZXcgWW9yazEQMA4GCSqGSIb3DQEJARYB
-IDEQMA4GA1UECxMHVGVzdGluZzAeFw0wODAzMjUxOTA0MTNaFw0wOTAzMjUxOTA0
-MTNaMIGGMQswCQYDVQQGEwJVUzEZMBcGA1UEAxMQcHlvcGVuc3NsLnNmLm5ldDER
-MA8GA1UEBxMITmV3IFlvcmsxEjAQBgNVBAoTCVB5T3BlblNTTDERMA8GA1UECBMI
-TmV3IFlvcmsxEDAOBgkqhkiG9w0BCQEWASAxEDAOBgNVBAsTB1Rlc3RpbmcwgZ8w
-DQYJKoZIhvcNAQEBBQADgY0AMIGJAoGBANp6Y17WzKSwBsUWkXdqg6tnXy8H8hA1
-msCMWpc+/2KJ4mbv5NyD6UD+/SqagQqulPbF/DFea9nAE0zhmHJELcM8gUTIlXv/
-cgDWnmK4xj8YkjVUiCdqKRAKeuzLG1pGmwwF5lGeJpXNxQn5ecR0UYSOWj6TTGXB
-9VyUMQzCClcBAgMBAAEwDQYJKoZIhvcNAQEEBQADgYEAmm0Vzvv1O91WLl2LnF2P
-q55LJdOnJbCCXIgxLdoVmvYAz1ZJq1eGKgKWI5QLgxiSzJLEU7KK//aVfiZzoCd5
-RipBiEEMEV4eAY317bHPwPP+4Bj9t0l8AsDLseC5vLRHgxrLEu3bn08DYx6imB5Q
-UBj849/xpszEM7BhwKE0GiQ=
------END CERTIFICATE-----
-""" + cleartextPrivateKeyPEM
+ pemData = cleartextCertificatePEM + cleartextPrivateKeyPEM
def signable(self):
"""
diff --git a/test/test_ssl.py b/test/test_ssl.py
index cb3d1b3..33264e5 100644
--- a/test/test_ssl.py
+++ b/test/test_ssl.py
@@ -6,11 +6,14 @@
from unittest import TestCase
from tempfile import mktemp
+from socket import socket
-from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey
-from OpenSSL.SSL import Context
+from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey, load_certificate, load_privatekey
+from OpenSSL.SSL import WantReadError, Context, Connection
from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
+from OpenSSL.test.test_crypto import cleartextCertificatePEM, cleartextPrivateKeyPEM
+
class ContextTests(TestCase):
"""
@@ -69,3 +72,46 @@
self.assertTrue(isinstance(calledWith[0][0], int))
self.assertTrue(isinstance(calledWith[0][1], int))
self.assertEqual(calledWith[0][2], None)
+
+
+ def test_set_info_callback(self):
+ """
+ L{Context.set_info_callback} accepts a callable which will be invoked
+ when certain information about an SSL connection is available.
+ """
+ port = socket()
+ port.bind(('', 0))
+ port.listen(1)
+
+ client = socket()
+ client.setblocking(False)
+ client.connect_ex(port.getsockname())
+
+ clientSSL = Connection(Context(TLSv1_METHOD), client)
+ clientSSL.set_connect_state()
+
+ called = []
+ def info(conn, where, ret):
+ called.append((conn, where, ret))
+ context = Context(TLSv1_METHOD)
+ context.set_info_callback(info)
+ context.use_certificate(
+ load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
+ context.use_privatekey(
+ load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
+
+ server, ignored = port.accept()
+ server.setblocking(False)
+
+ serverSSL = Connection(context, server)
+ serverSSL.set_accept_state()
+
+ while not called:
+ for ssl in clientSSL, serverSSL:
+ try:
+ ssl.do_handshake()
+ except WantReadError:
+ pass
+
+ # Kind of lame. Just make sure it got called somehow.
+ self.assertTrue(called)