Add a test verifying that the first argument passed to the verify callback is the Connection being verified, and make it pass.
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index fbb18f0..40fe5a6 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -161,7 +161,7 @@
class _VerifyHelper(object):
- def __init__(self, connection, callback):
+ def __init__(self, callback):
self._problems = []
@wraps(callback)
@@ -171,6 +171,10 @@
error_number = _lib.X509_STORE_CTX_get_error(store_ctx)
error_depth = _lib.X509_STORE_CTX_get_error_depth(store_ctx)
+ index = _lib.SSL_get_ex_data_X509_STORE_CTX_idx()
+ ssl = _lib.X509_STORE_CTX_get_ex_data(store_ctx, index)
+ connection = Connection._reverse_mapping[ssl]
+
try:
result = callback(connection, cert, error_number, error_depth, ok)
except Exception as e:
@@ -542,7 +546,7 @@
if not callable(callback):
raise TypeError("callback must be callable")
- self._verify_helper = _VerifyHelper(self, callback)
+ self._verify_helper = _VerifyHelper(callback)
self._verify_callback = self._verify_helper.callback
_lib.SSL_CTX_set_verify(self._context, mode, self._verify_callback)
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index bfe3114..369b1b6 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -277,6 +277,19 @@
write.bio_write(dirty)
+    def _handshakeInMemory(self, client_conn, server_conn):
+        """Start both handshakes, then hand off to _interactInMemory."""
+        client_conn.set_connect_state()
+        server_conn.set_accept_state()
+
+        for side in (client_conn, server_conn):
+            try:
+                side.do_handshake()
+            except WantReadError:
+                pass  # expected: this side needs bytes from its peer first
+        self._interactInMemory(client_conn, server_conn)
+
+
class VersionTests(TestCase):
"""
@@ -981,6 +994,34 @@
pass
+    def test_set_verify_callback_connection_argument(self):
+        """
+        The first argument passed to the verify callback is the
+        :py:class:`Connection` instance for which verification is taking place.
+        """
+        serverContext = Context(TLSv1_METHOD)
+        serverContext.use_privatekey(
+            load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))
+        serverContext.use_certificate(
+            load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
+        serverConnection = Connection(serverContext, None)
+
+        class VerifyCallback(object):
+            connection = None  # stays None if the callback is never invoked
+            def callback(self, connection, *args):
+                self.connection = connection
+                return 1
+
+        verify = VerifyCallback()
+        clientContext = Context(TLSv1_METHOD)
+        clientContext.set_verify(VERIFY_PEER, verify.callback)
+        clientConnection = Connection(clientContext, None)
+
+        # _handshakeInMemory itself puts clientConnection into connect state.
+        self._handshakeInMemory(clientConnection, serverConnection)
+        self.assertIdentical(verify.connection, clientConnection)
+
+
def test_set_verify_callback_exception(self):
"""
If the verify callback passed to :py:obj:`Context.set_verify` raises an