Add SSL_set_SSL_CTX wrapper, Connection.set_context.
diff --git a/OpenSSL/ssl/connection.c b/OpenSSL/ssl/connection.c
index a3ec0f0..a8dfa58 100755
--- a/OpenSSL/ssl/connection.c
+++ b/OpenSSL/ssl/connection.c
@@ -263,6 +263,45 @@
return (PyObject *)self->context;
}
+static char ssl_Connection_set_context_doc[] = "\n\
+Switch this connection to a new session context\n\
+\n\
+@param context: A L{Context} instance giving the new session context to use.\n\
+@return: None\n\
+";
+static PyObject *
+ssl_Connection_set_context(ssl_ConnectionObj *self, PyObject *args) {
+    ssl_ContextObj *ctx;
+    ssl_ContextObj *old;
+
+    if (!PyArg_ParseTuple(args, "O!:set_context", &ssl_Context_Type, &ctx)) {
+        return NULL;
+    }
+
+    /* Switch the underlying SSL object before touching any Python reference
+     * counts, so a failure leaves this Connection unchanged.  Newer OpenSSL
+     * releases return NULL from SSL_set_SSL_CTX when the switch fails.
+     */
+    if (SSL_set_SSL_CTX(self->ssl, ctx->ctx) == NULL) {
+        PyErr_SetString(PyExc_RuntimeError, "SSL_set_SSL_CTX failed");
+        return NULL;
+    }
+
+    /* The Connection keeps the new context alive (get_context returns it).
+     * Take the new reference before dropping the old one so that passing
+     * the context already in use never lets its count reach zero.  The old
+     * context is released only after self->context no longer points at it.
+     */
+    Py_INCREF(ctx);
+    old = self->context;
+    self->context = ctx;
+    Py_DECREF(old);
+
+    Py_INCREF(Py_None);
+    return Py_None;
+}
+
+
static char ssl_Connection_pending_doc[] = "\n\
Get the number of bytes that can be safely read from the connection\n\
\n\
@@ -1181,6 +1220,7 @@
static PyMethodDef ssl_Connection_methods[] =
{
ADD_METHOD(get_context),
+ ADD_METHOD(set_context),
ADD_METHOD(pending),
ADD_METHOD(send),
ADD_ALIAS (write, send),
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 2761cec..24a08b0 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -924,6 +924,41 @@
self.assertRaises(TypeError, connection.get_context, None)
+    def test_set_context_wrong_args(self):
+        """
+        L{Connection.set_context} raises L{TypeError} if called with a
+        non-L{Context} instance argument or with any number of arguments other
+        than 1.
+        """
+        original = Context(TLSv1_METHOD)
+        conn = Connection(original, None)
+        bad_calls = [
+            (), (object(),), ("hello",), (1,), (1, 2),
+            (Context(TLSv1_METHOD), 2),
+        ]
+        for bad_args in bad_calls:
+            self.assertRaises(TypeError, conn.set_context, *bad_args)
+        # None of the rejected calls may have replaced the original context.
+        self.assertIdentical(original, conn.get_context())
+
+
+    def test_set_context(self):
+        """
+        L{Connection.set_context} specifies a new L{Context} instance to be used
+        for the connection and releases its reference to the old one.
+        """
+        import gc, sys
+        original = Context(SSLv23_METHOD)
+        replacement = Context(TLSv1_METHOD)
+        baseline = sys.getrefcount(original)
+        connection = Connection(original, None)
+        connection.set_context(replacement)
+        self.assertIdentical(replacement, connection.get_context())
+        self.assertEqual(baseline, sys.getrefcount(original))
+        del original, replacement  # the Connection must keep replacement alive
+        gc.collect()
+
+
def test_pending(self):
"""
L{Connection.pending} returns the number of bytes available for