A test for use_certificate_chain_file which fails just like add_extra_chain_cert
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 5eb18dd..99cd47b 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -288,6 +288,17 @@
return pemFile
+ def test_set_passwd_cb_wrong_args(self):
+ """
+ L{Context.set_passwd_cb} raises L{TypeError} if called with the
+ wrong arguments or with a non-callable first argument.
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.set_passwd_cb)
+ self.assertRaises(TypeError, context.set_passwd_cb, None)
+ self.assertRaises(TypeError, context.set_passwd_cb, lambda: None, None, None)
+
+
def test_set_passwd_cb(self):
"""
L{Context.set_passwd_cb} accepts a callable which will be invoked when
@@ -481,6 +492,18 @@
self._load_verify_locations_test(None, capath)
+ def test_load_verify_locations_wrong_args(self):
+ """
+ L{Context.load_verify_locations} raises L{TypeError} if called with
+ the wrong number of arguments or with non-C{str} arguments.
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.load_verify_locations)
+ self.assertRaises(TypeError, context.load_verify_locations, object())
+ self.assertRaises(TypeError, context.load_verify_locations, object(), object())
+ self.assertRaises(TypeError, context.load_verify_locations, None, None, None)
+
+
if platform == "win32":
"set_default_verify_paths appears not to work on Windows. "
"See LP#404343 and LP#404344."
@@ -524,6 +547,7 @@
self.assertRaises(TypeError, context.set_default_verify_paths, 1)
self.assertRaises(TypeError, context.set_default_verify_paths, "")
+
def test_add_extra_chain_cert_invalid_cert(self):
"""
L{Context.add_extra_chain_cert} raises L{TypeError} if called with
@@ -536,22 +560,14 @@
self.assertRaises(TypeError, context.add_extra_chain_cert, object(), object())
- def test_add_extra_chain_cert(self):
+ def _create_certificate_chain(self):
"""
- L{Context.add_extra_chain_cert} accepts an L{X509} instance to add to
- the certificate chain.
-
- Verify this by constructing::
+ Construct and return a chain of certificates.
1. A new self-signed certificate authority certificate (cacert)
2. A new intermediate certificate signed by cacert (icert)
3. A new server certificate signed by icert (scert)
-
- Then starting a server with scert and connecting to it with a client
- which trusts cacert and requires verification to succeed.
"""
- serverSocket, clientSocket = socket_pair()
-
caext = X509Extension('basicConstraints', True, 'CA:true')
# Step 1
@@ -613,16 +629,47 @@
# icert = icert.original
# skey = skey.original
# scert = scert.original
+ return [(cakey, cacert), (ikey, icert), (skey, scert)]
- # Create the server
- serverContext = Context(TLSv1_METHOD)
- serverContext.use_privatekey(skey)
- serverContext.use_certificate(scert)
- serverContext.add_extra_chain_cert(cacert)
- serverContext.add_extra_chain_cert(icert)
+
+ def _handshake_test(self, serverContext, clientContext):
+ """
+ Verify that a client and server created with the given contexts can
+ successfully handshake and communicate.
+ """
+ serverSocket, clientSocket = socket_pair()
+
server = Connection(serverContext, serverSocket)
server.set_accept_state()
+ client = Connection(clientContext, clientSocket)
+ client.set_connect_state()
+
+ # Make them talk to each other.
+ # self._interactInMemory(client, server)
+ for i in range(3):
+ for s in [client, server]:
+ try:
+ s.do_handshake()
+ except WantReadError:
+ pass
+
+
+ def test_add_extra_chain_cert(self):
+ """
+ L{Context.add_extra_chain_cert} accepts an L{X509} instance to add to
+ the certificate chain.
+
+ See L{_create_certificate_chain} for the details of the certificate
+ chain tested.
+
+ The chain is tested by starting a server with scert and connecting
+ to it with a client which trusts cacert and requires verification to
+ succeed.
+ """
+ chain = self._create_certificate_chain()
+ [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
+
# Dump the CA certificate to a file because that's the only way to load
# it as a trusted CA in the client context.
for cert, name in [(cacert, 'ca.pem'), (icert, 'i.pem'), (scert, 's.pem')]:
@@ -635,22 +682,61 @@
fObj.write(dump_privatekey(FILETYPE_PEM, key))
fObj.close()
+ # Create the server context
+ serverContext = Context(TLSv1_METHOD)
+ serverContext.use_privatekey(skey)
+ serverContext.use_certificate(scert)
+ serverContext.add_extra_chain_cert(cacert)
+ serverContext.add_extra_chain_cert(icert)
+
# Create the client
clientContext = Context(TLSv1_METHOD)
clientContext.set_verify(
VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
clientContext.load_verify_locations('ca.pem')
- client = Connection(clientContext, clientSocket)
- client.set_connect_state()
- # Make them talk to each other.
- # self._interactInMemory(client, server)
- for i in range(3):
- for s in [client, server]:
- try:
- s.do_handshake()
- except WantReadError:
- pass
+ # Try it out.
+ self._handshake_test(serverContext, clientContext)
+
+
+
+ def test_use_certificate_chain_file(self):
+ """
+ L{Context.use_certificate_chain_file} reads a certificate chain from
+ the specified file.
+
+ The chain is tested by starting a server with scert and connecting
+ to it with a client which trusts cacert and requires verification to
+ succeed.
+ """
+ chain = self._create_certificate_chain()
+ [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
+
+ # Write out the chain file.
+ chainFile = self.mktemp()
+ fObj = file(chainFile, 'w')
+ # Most specific to least general.
+ fObj.write(dump_certificate(FILETYPE_PEM, scert))
+ fObj.write(dump_certificate(FILETYPE_PEM, icert))
+ fObj.write(dump_certificate(FILETYPE_PEM, cacert))
+ fObj.close()
+
+ serverContext = Context(TLSv1_METHOD)
+ serverContext.use_certificate_chain_file(chainFile)
+ serverContext.use_privatekey(skey)
+
+ fObj = file('ca.pem', 'w')
+ fObj.write(dump_certificate(FILETYPE_PEM, cacert))
+ fObj.close()
+
+ clientContext = Context(TLSv1_METHOD)
+ clientContext.set_verify(
+ VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
+ clientContext.load_verify_locations('ca.pem')
+
+ self._handshake_test(serverContext, clientContext)
+
+