Try for a real test of Context.add_extra_chain_cert; not working
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 5ca8964..5eb18dd 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -4,6 +4,8 @@
Unit tests for L{OpenSSL.SSL}.
"""
+from twisted.internet.ssl import *
+
from errno import ECONNREFUSED, EINPROGRESS
from sys import platform
from socket import error, socket
@@ -11,7 +13,7 @@
from os.path import join
from unittest import main
-from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey, load_certificate, load_privatekey
+from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, X509, X509Extension, dump_privatekey, load_certificate, dump_certificate, load_privatekey, X509_verify_cert_error_string
from OpenSSL.SSL import SysCallError, WantReadError, WantWriteError, ZeroReturnError, Context, ContextType, Connection, ConnectionType, Error
from OpenSSL.SSL import SENT_SHUTDOWN, RECEIVED_SHUTDOWN
from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
@@ -34,6 +36,10 @@
OP_NO_TICKET = None
+def verify_cb(conn, cert, errnum, depth, ok):
+ print conn, cert, X509_verify_cert_error_string(errnum), depth, ok
+ return ok
+
def socket_pair():
"""
Establish and return a pair of network sockets connected to each other.
@@ -64,7 +70,79 @@
-class ContextTests(TestCase):
+class _LoopbackMixin:
+ def _loopback(self):
+ (server, client) = socket_pair()
+
+ ctx = Context(TLSv1_METHOD)
+ ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+ server = Connection(ctx, server)
+ server.set_accept_state()
+ client = Connection(Context(TLSv1_METHOD), client)
+ client.set_connect_state()
+
+ for i in range(3):
+ for conn in [client, server]:
+ try:
+ conn.do_handshake()
+ except WantReadError:
+ pass
+
+ server.setblocking(True)
+ client.setblocking(True)
+ return server, client
+
+
+ def _interactInMemory(self, client_conn, server_conn):
+ """
+ Try to read application bytes from each of the two L{Connection}
+ objects. Copy bytes back and forth between their send/receive buffers
+ for as long as there is anything to copy. When there is nothing more
+ to copy, return C{None}. If one of them actually manages to deliver
+ some application bytes, return a two-tuple of the connection from which
+ the bytes were read and the bytes themselves.
+ """
+ wrote = True
+ while wrote:
+ # Loop until neither side has anything to say
+ wrote = False
+
+ # Copy stuff from each side's send buffer to the other side's
+ # receive buffer.
+ for (read, write) in [(client_conn, server_conn),
+ (server_conn, client_conn)]:
+
+ # Give the side a chance to generate some more bytes, or
+ # succeed.
+ try:
+ bytes = read.recv(2 ** 16)
+ except WantReadError:
+ # It didn't succeed, so we'll hope it generated some
+ # output.
+ pass
+ else:
+ # It did succeed, so we'll stop now and let the caller deal
+ # with it.
+ return (read, bytes)
+
+ while True:
+ # Keep copying as long as there's more stuff there.
+ try:
+ dirty = read.bio_read(4096)
+ except WantReadError:
+ # Okay, nothing more waiting to be sent. Stop
+ # processing this send buffer.
+ break
+ else:
+ # Keep track of the fact that someone generated some
+ # output.
+ wrote = True
+ write.bio_write(dirty)
+
+
+
+class ContextTests(TestCase, _LoopbackMixin):
"""
Unit tests for L{OpenSSL.SSL.Context}.
"""
@@ -462,38 +540,118 @@
"""
L{Context.add_extra_chain_cert} accepts an L{X509} instance to add to
the certificate chain.
+
+ Verify this by constructing::
+
+ 1. A new self-signed certificate authority certificate (cacert)
+ 2. A new intermediate certificate signed by cacert (icert)
+ 3. A new server certificate signed by icert (scert)
+
+ Then starting a server with scert and connecting to it with a client
+ which trusts cacert and requires verification to succeed.
"""
- context = Context(TLSv1_METHOD)
- context.add_extra_chain_cert(load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
- # XXX Oh no, actually asserting something about its behavior would be really hard.
- # See #477521.
+ serverSocket, clientSocket = socket_pair()
+ caext = X509Extension('basicConstraints', True, 'CA:true')
+ # Step 1
+ cakey = PKey()
+ cakey.generate_key(TYPE_RSA, 512)
+ cacert = X509()
+ cacert.get_subject().commonName = "CA Certificate"
+ cacert.set_issuer(cacert.get_subject())
+ cacert.set_pubkey(cakey)
+ cacert.gmtime_adj_notBefore(0)
+ cacert.gmtime_adj_notAfter(6000)
+ cacert.add_extensions([caext])
+ cacert.sign(cakey, "sha")
+ # Step 2
+ ikey = PKey()
+ ikey.generate_key(TYPE_RSA, 512)
+ icert = X509()
+ icert.get_subject().commonName = "Intermediate Certificate"
+ icert.set_issuer(cacert.get_subject())
+ icert.set_pubkey(ikey)
+ icert.gmtime_adj_notBefore(0)
+ icert.gmtime_adj_notAfter(6000)
+ icert.add_extensions([caext])
+ icert.sign(cakey, "sha")
-class _LoopbackMixin:
- def _loopback(self):
- (server, client) = socket_pair()
+ # Step 3
+ skey = PKey()
+ skey.generate_key(TYPE_RSA, 512)
+ scert = X509()
+ scert.get_subject().commonName = "Server Certificate"
+ scert.set_issuer(icert.get_subject())
+ scert.set_pubkey(skey)
+ scert.gmtime_adj_notBefore(0)
+ scert.gmtime_adj_notAfter(6000)
+ scert.sign(ikey, "sha")
- ctx = Context(TLSv1_METHOD)
- ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
- ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
- server = Connection(ctx, server)
+ # # Step 1
+ # cakey = KeyPair.generate()
+ # cacert = cakey.selfSignedCert(1, commonName="CA Certificate")
+
+ # # Step 2
+ # ikey = KeyPair.generate()
+ # ireq = ikey.certificateRequest(DN(commonName="Intermediate Certificate"))
+ # icert = PrivateCertificate.load(
+ # cacert.signCertificateRequest(ireq, lambda dn: True, 1),
+ # ikey)
+
+ # # Step 3
+ # skey = KeyPair.generate()
+ # sreq = skey.certificateRequest(DN(commonName="Server Certificate"))
+ # scert = PrivateCertificate.load(
+ # icert.signCertificateRequest(sreq, lambda dn: True, 2),
+ # skey)
+
+ # cakey = cakey.original
+ # cacert = cacert.original
+ # ikey = ikey.original
+ # icert = icert.original
+ # skey = skey.original
+ # scert = scert.original
+
+ # Create the server
+ serverContext = Context(TLSv1_METHOD)
+ serverContext.use_privatekey(skey)
+ serverContext.use_certificate(scert)
+ serverContext.add_extra_chain_cert(cacert)
+ serverContext.add_extra_chain_cert(icert)
+ server = Connection(serverContext, serverSocket)
server.set_accept_state()
- client = Connection(Context(TLSv1_METHOD), client)
+
+ # Dump the CA certificate to a file because that's the only way to load
+ # it as a trusted CA in the client context.
+ for cert, name in [(cacert, 'ca.pem'), (icert, 'i.pem'), (scert, 's.pem')]:
+ fObj = file(name, 'w')
+ fObj.write(dump_certificate(FILETYPE_PEM, cert))
+ fObj.close()
+
+ for key, name in [(cakey, 'ca.key'), (ikey, 'i.key'), (skey, 's.key')]:
+ fObj = file(name, 'w')
+ fObj.write(dump_privatekey(FILETYPE_PEM, key))
+ fObj.close()
+
+ # Create the client
+ clientContext = Context(TLSv1_METHOD)
+ clientContext.set_verify(
+ VERIFY_PEER | VERIFY_FAIL_IF_NO_PEER_CERT, verify_cb)
+ clientContext.load_verify_locations('ca.pem')
+ client = Connection(clientContext, clientSocket)
client.set_connect_state()
+ # Make them talk to each other.
+ # self._interactInMemory(client, server)
for i in range(3):
- for conn in [client, server]:
+ for s in [client, server]:
try:
- conn.do_handshake()
+ s.do_handshake()
except WantReadError:
pass
- server.setblocking(True)
- client.setblocking(True)
- return server, client
-
class ConnectionTests(TestCase, _LoopbackMixin):
@@ -805,10 +963,7 @@
-def verify_cb(conn, cert, errnum, depth, ok):
- return ok
-
-class MemoryBIOTests(TestCase):
+class MemoryBIOTests(TestCase, _LoopbackMixin):
"""
Tests for L{OpenSSL.SSL.Connection} using a memory BIO.
"""
@@ -854,53 +1009,6 @@
return client_conn
- def _loopback(self, client_conn, server_conn):
- """
- Try to read application bytes from each of the two L{Connection}
- objects. Copy bytes back and forth between their send/receive buffers
- for as long as there is anything to copy. When there is nothing more
- to copy, return C{None}. If one of them actually manages to deliver
- some application bytes, return a two-tuple of the connection from which
- the bytes were read and the bytes themselves.
- """
- wrote = True
- while wrote:
- # Loop until neither side has anything to say
- wrote = False
-
- # Copy stuff from each side's send buffer to the other side's
- # receive buffer.
- for (read, write) in [(client_conn, server_conn),
- (server_conn, client_conn)]:
-
- # Give the side a chance to generate some more bytes, or
- # succeed.
- try:
- bytes = read.recv(2 ** 16)
- except WantReadError:
- # It didn't succeed, so we'll hope it generated some
- # output.
- pass
- else:
- # It did succeed, so we'll stop now and let the caller deal
- # with it.
- return (read, bytes)
-
- while True:
- # Keep copying as long as there's more stuff there.
- try:
- dirty = read.bio_read(4096)
- except WantReadError:
- # Okay, nothing more waiting to be sent. Stop
- # processing this send buffer.
- break
- else:
- # Keep track of the fact that someone generated some
- # output.
- wrote = True
- write.bio_write(dirty)
-
-
def test_memoryConnect(self):
"""
Two L{Connection}s which use memory BIOs can be manually connected by
@@ -919,7 +1027,8 @@
# First, the handshake needs to happen. We'll deliver bytes back and
# forth between the client and server until neither of them feels like
# speaking any more.
- self.assertIdentical(self._loopback(client_conn, server_conn), None)
+ self.assertIdentical(
+ self._interactInMemory(client_conn, server_conn), None)
# Now that the handshake is done, there should be a key and nonces.
self.assertNotIdentical(server_conn.master_key(), None)
@@ -935,12 +1044,12 @@
server_conn.write(important_message)
self.assertEquals(
- self._loopback(client_conn, server_conn),
+ self._interactInMemory(client_conn, server_conn),
(client_conn, important_message))
client_conn.write(important_message[::-1])
self.assertEquals(
- self._loopback(client_conn, server_conn),
+ self._interactInMemory(client_conn, server_conn),
(server_conn, important_message[::-1]))
@@ -1008,7 +1117,7 @@
server = self._server(None)
client = self._client(None)
- self._loopback(client, server)
+ self._interactInMemory(client, server)
size = 2 ** 15
sent = client.send("x" * size)
@@ -1017,7 +1126,7 @@
# meaningless.
self.assertTrue(sent < size)
- receiver, received = self._loopback(client, server)
+ receiver, received = self._interactInMemory(client, server)
self.assertIdentical(receiver, server)
# We can rely on all of these bytes being received at once because
@@ -1057,7 +1166,7 @@
expected = func(ctx)
self.assertEqual(client.get_client_ca_list(), [])
self.assertEqual(server.get_client_ca_list(), expected)
- self._loopback(client, server)
+ self._interactInMemory(client, server)
self.assertEqual(client.get_client_ca_list(), expected)
self.assertEqual(server.get_client_ca_list(), expected)
diff --git a/OpenSSL/test/util.py b/OpenSSL/test/util.py
index d195d95..34585ee 100644
--- a/OpenSSL/test/util.py
+++ b/OpenSSL/test/util.py
@@ -27,7 +27,7 @@
Subclasses must invoke this method if they override it or the
cleanup will not occur.
"""
- if self._temporaryFiles is not None:
+ if False and self._temporaryFiles is not None:
for temp in self._temporaryFiles:
if os.path.isdir(temp):
shutil.rmtree(temp)