Use TLS for the PyThreadState pointer instead of a field on Connection; this allows a Connection to be used safely in multiple threads concurrently
diff --git a/leakcheck/dhparam.pem b/leakcheck/dhparam.pem
new file mode 100644
index 0000000..9d33a4a
--- /dev/null
+++ b/leakcheck/dhparam.pem
@@ -0,0 +1,4 @@
+-----BEGIN DH PARAMETERS-----
+MEYCQQDM2LbvAjF5ahXHOUdDR09Vw/7kxjF/euWhNKBqUQQYT7FDSAMCCMq+Jhno
+BKxWEDhlxR1Q1VZ4H/NVTAGtWai7AgEC
+-----END DH PARAMETERS-----
diff --git a/leakcheck/thread-crash.py b/leakcheck/thread-crash.py
new file mode 100644
index 0000000..26048a5
--- /dev/null
+++ b/leakcheck/thread-crash.py
@@ -0,0 +1,70 @@
+# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
+#
+# Stress tester for thread-related bugs in ssl_Connection_send and
+# ssl_Connection_recv in src/ssl/connection.c for usage of a single
+# Connection object simultaneously in multiple threads. In 0.7 and earlier,
+# this will somewhat reliably cause Python to abort with a "tstate mix-up"
+# almost immediately, due to the incorrect sharing between threads of the
+# `tstate` field of the connection object.
+
+
+from socket import socket
+from threading import Thread
+
+from OpenSSL.SSL import Connection, Context, TLSv1_METHOD
+
+def send(conn):
+ while 1:
+ for i in xrange(1024 * 32):
+ conn.send('x')
+ print 'Sent 32KB on', hex(id(conn))
+
+
+def recv(conn):
+ while 1:
+ for i in xrange(1024 * 64):
+ conn.recv(1)
+ print 'Received 64KB on', hex(id(conn))
+
+
+def main():
+ port = socket()
+ port.bind(('', 0))
+ port.listen(5)
+
+ client = socket()
+ client.setblocking(False)
+ client.connect_ex(port.getsockname())
+ client.setblocking(True)
+
+ server = port.accept()[0]
+
+ clientCtx = Context(TLSv1_METHOD)
+ clientCtx.set_cipher_list('ALL:ADH')
+ clientCtx.load_tmp_dh('dhparam.pem')
+
+ sslClient = Connection(clientCtx, client)
+ sslClient.set_connect_state()
+
+ serverCtx = Context(TLSv1_METHOD)
+ serverCtx.set_cipher_list('ALL:ADH')
+ serverCtx.load_tmp_dh('dhparam.pem')
+
+ sslServer = Connection(serverCtx, server)
+ sslServer.set_accept_state()
+
+ t1 = Thread(target=send, args=(sslClient,))
+ t2 = Thread(target=send, args=(sslServer,))
+ t3 = Thread(target=recv, args=(sslClient,))
+ t4 = Thread(target=recv, args=(sslServer,))
+
+ t1.start()
+ t2.start()
+ t3.start()
+ t4.start()
+ t1.join()
+ t2.join()
+ t3.join()
+ t4.join()
+
+main()
diff --git a/src/ssl/connection.h b/src/ssl/connection.h
index dedb73e..13f42f0 100644
--- a/src/ssl/connection.h
+++ b/src/ssl/connection.h
@@ -42,7 +42,7 @@
SSL *ssl;
ssl_ContextObj *context;
PyObject *socket;
- PyThreadState *tstate;
+ PyThreadState *tstate; /* This field is no longer used. */
PyObject *app_data;
} ssl_ConnectionObj;
diff --git a/src/ssl/ssl.c b/src/ssl/ssl.c
index 7f58771..1f8cbcc 100644
--- a/src/ssl/ssl.c
+++ b/src/ssl/ssl.c
@@ -32,6 +32,8 @@
void **crypto_API;
+int _pyOpenSSL_tstate_key;
+
/* Exceptions defined by the SSL submodule */
PyObject *ssl_Error, /* Base class */
*ssl_ZeroReturnError, /* Used with SSL_get_error */
@@ -201,6 +203,13 @@
if (!init_ssl_connection(dict))
goto error;
-error:
+#ifdef WITH_THREAD
+ /*
+ * Initialize this module's threading support structures.
+ */
+ _pyOpenSSL_tstate_key = PyThread_create_key();
+#endif
+
+ error:
;
}
diff --git a/src/ssl/ssl.h b/src/ssl/ssl.h
index e8d3e93..9cf0186 100644
--- a/src/ssl/ssl.h
+++ b/src/ssl/ssl.h
@@ -14,6 +14,7 @@
#define PyOpenSSL_SSL_H_
#include <Python.h>
+#include <pythread.h>
#include "context.h"
#include "connection.h"
#include "../util.h"
@@ -45,6 +46,10 @@
#define ssl_API_pointers 2
+#ifdef WITH_THREAD
+extern int _pyOpenSSL_tstate_key;
+#endif /* WITH_THREAD */
+
#ifdef SSL_MODULE
extern ssl_Context_New_RETURN ssl_Context_New ssl_Context_New_PROTO;
diff --git a/src/util.h b/src/util.h
index 592660e..b95e75b 100644
--- a/src/util.h
+++ b/src/util.h
@@ -30,10 +30,21 @@
* WHERE to save the thread state.
*/
#ifdef WITH_THREAD
-# define MY_BEGIN_ALLOW_THREADS(st) \
- { st = PyEval_SaveThread(); }
-# define MY_END_ALLOW_THREADS(st) \
- { PyEval_RestoreThread(st); st = NULL; }
+
+/*
+ * Get the current Python threadstate and put it somewhere any code running
+ * in this thread can get it, if it needs to restore the threadstate to run
+ * some Python.
+ */
+# define MY_BEGIN_ALLOW_THREADS(ignored) \
+ PyThread_set_key_value(_pyOpenSSL_tstate_key, PyEval_SaveThread());
+
+/*
+ * Get the previous Python threadstate and restore it.
+ */
+# define MY_END_ALLOW_THREADS(ignored) \
+ PyEval_RestoreThread(PyThread_get_key_value(_pyOpenSSL_tstate_key));
+
#else
# define MY_BEGIN_ALLOW_THREADS(st)
# define MY_END_ALLOW_THREADS(st) { st = NULL; }