blob: dd42a10d25a438822d7f49d4f30eee2d5172f9dc [file] [log] [blame]
# Copyright (C) Jean-Paul Calderone 2008, All rights reserved
"""
Unit tests for L{OpenSSL.SSL}.
"""
from unittest import TestCase
from tempfile import mktemp
from socket import socket
from os import makedirs, symlink
from os.path import join
from OpenSSL.crypto import TYPE_RSA, FILETYPE_PEM, PKey, dump_privatekey, load_certificate, load_privatekey
from OpenSSL.SSL import WantReadError, Context, Connection, Error
from OpenSSL.SSL import SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD
from OpenSSL.SSL import VERIFY_PEER
from OpenSSL.test.test_crypto import _Python23TestCaseHelper, cleartextCertificatePEM, cleartextPrivateKeyPEM
class ContextTests(TestCase, _Python23TestCaseHelper):
    """
    Unit tests for L{OpenSSL.SSL.Context}.
    """
    def mktemp(self):
        """
        Pathetic substitute for twisted.trial.unittest.TestCase.mktemp.

        @return: A path in the current directory which did not exist when
            the name was generated.  Nothing is created at that path.
        """
        return mktemp(dir=".")


    def _write_file(self, path, data):
        """
        Write C{data} to a new file at C{path}, closing the file even if the
        write fails.
        """
        # open() instead of the Python 2-only file() builtin, and try/finally
        # instead of "with" so that old interpreters can still parse this.
        fObj = open(path, 'w')
        try:
            fObj.write(data)
        finally:
            fObj.close()


    def _loopback(self):
        """
        Create a pair of connected, non-blocking TCP sockets.

        The temporary listening socket is closed before returning.

        @return: A two-tuple C{(client, server)}.  The caller is responsible
            for closing both sockets.
        """
        port = socket()
        try:
            port.bind(('', 0))
            port.listen(1)
            client = socket()
            established = False
            try:
                client.setblocking(False)
                # Non-blocking connect: the result code is deliberately
                # ignored, accept() below waits for the connection.
                client.connect_ex(port.getsockname())
                server, ignored = port.accept()
                server.setblocking(False)
                established = True
            finally:
                if not established:
                    client.close()
        finally:
            port.close()
        return client, server


    def _handshake_round(self, *conns):
        """
        Call C{do_handshake} once on each of C{conns}, in order, ignoring
        L{WantReadError} (the peer simply has not produced its data yet).
        Any other exception propagates.
        """
        for conn in conns:
            try:
                conn.do_handshake()
            except WantReadError:
                pass


    def test_method(self):
        """
        L{Context} can be instantiated with one of L{SSLv2_METHOD},
        L{SSLv3_METHOD}, L{SSLv23_METHOD}, or L{TLSv1_METHOD}.
        """
        for meth in [SSLv2_METHOD, SSLv3_METHOD, SSLv23_METHOD, TLSv1_METHOD]:
            Context(meth)
        self.assertRaises(TypeError, Context, "")
        self.assertRaises(ValueError, Context, 10)


    def test_use_privatekey(self):
        """
        L{Context.use_privatekey} takes an L{OpenSSL.crypto.PKey} instance.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 128)
        ctx = Context(TLSv1_METHOD)
        ctx.use_privatekey(key)
        self.assertRaises(TypeError, ctx.use_privatekey, "")


    def test_set_passwd_cb(self):
        """
        L{Context.set_passwd_cb} accepts a callable which will be invoked when
        a private key is loaded from an encrypted PEM.
        """
        key = PKey()
        key.generate_key(TYPE_RSA, 128)
        pemFile = self.mktemp()
        passphrase = "foobar"
        self._write_file(
            pemFile,
            dump_privatekey(FILETYPE_PEM, key, "blowfish", passphrase))

        calledWith = []
        def passphraseCallback(maxlen, verify, extra):
            calledWith.append((maxlen, verify, extra))
            return passphrase

        context = Context(TLSv1_METHOD)
        context.set_passwd_cb(passphraseCallback)
        context.use_privatekey_file(pemFile)

        # assertEqual, not assertTrue: with assertTrue the second argument
        # is only the failure message, so any non-zero call count passed.
        self.assertEqual(len(calledWith), 1)
        self.assertTrue(isinstance(calledWith[0][0], int))
        self.assertTrue(isinstance(calledWith[0][1], int))
        self.assertEqual(calledWith[0][2], None)


    def test_set_info_callback(self):
        """
        L{Context.set_info_callback} accepts a callable which will be invoked
        when certain information about an SSL connection is available.
        """
        client, server = self._loopback()
        try:
            clientSSL = Connection(Context(TLSv1_METHOD), client)
            clientSSL.set_connect_state()

            called = []
            def info(conn, where, ret):
                called.append((conn, where, ret))

            context = Context(TLSv1_METHOD)
            context.set_info_callback(info)
            context.use_certificate(
                load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
            context.use_privatekey(
                load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))

            serverSSL = Connection(context, server)
            serverSSL.set_accept_state()

            # Bounded rather than "while not called": if the callback never
            # fires the test must fail, not spin forever.
            for ignored in range(10):
                if called:
                    break
                self._handshake_round(clientSSL, serverSSL)

            # Kind of lame.  Just make sure it got called somehow.
            self.assertTrue(called)
        finally:
            client.close()
            server.close()


    def _load_verify_locations_test(self, *args):
        """
        Create a client context which requires the peer certificate to
        verify, pass C{args} to its C{load_verify_locations} method, then
        handshake with a local server presenting the test certificate and
        check the subject of the certificate the client received.

        @param args: Positional arguments for
            L{Context.load_verify_locations}.
        """
        client, server = self._loopback()
        try:
            clientContext = Context(TLSv1_METHOD)
            clientContext.load_verify_locations(*args)
            # Require that the server certificate verify properly or the
            # connection will fail.
            clientContext.set_verify(
                VERIFY_PEER,
                lambda conn, cert, errno, depth, preverify_ok: preverify_ok)

            clientSSL = Connection(clientContext, client)
            clientSSL.set_connect_state()

            serverContext = Context(TLSv1_METHOD)
            serverContext.use_certificate(
                load_certificate(FILETYPE_PEM, cleartextCertificatePEM))
            serverContext.use_privatekey(
                load_privatekey(FILETYPE_PEM, cleartextPrivateKeyPEM))

            serverSSL = Connection(serverContext, server)
            serverSSL.set_accept_state()

            for ignored in range(3):
                # Without load_verify_locations above, the handshake
                # will fail:
                # Error: [('SSL routines', 'SSL3_GET_SERVER_CERTIFICATE',
                #          'certificate verify failed')]
                self._handshake_round(clientSSL, serverSSL)

            cert = clientSSL.get_peer_certificate()
            self.assertEqual(cert.get_subject().CN, 'pyopenssl.sf.net')
        finally:
            client.close()
            server.close()


    def test_load_verify_file(self):
        """
        L{Context.load_verify_locations} accepts a file name and uses the
        certificates within for verification purposes.
        """
        cafile = self.mktemp()
        self._write_file(cafile, cleartextCertificatePEM)
        self._load_verify_locations_test(cafile)


    def test_load_verify_invalid_file(self):
        """
        L{Context.load_verify_locations} raises L{Error} when passed a
        non-existent cafile.
        """
        clientContext = Context(TLSv1_METHOD)
        self.assertRaises(
            Error, clientContext.load_verify_locations, self.mktemp())


    def test_load_verify_directory(self):
        """
        L{Context.load_verify_locations} accepts a directory name and uses
        the certificates within for verification purposes.
        """
        capath = self.mktemp()
        makedirs(capath)
        self._write_file(join(capath, 'cert.pem'), cleartextCertificatePEM)

        # Hash value computed manually with c_rehash to avoid depending on
        # c_rehash in the test suite.
        symlink('cert.pem', join(capath, '07497d9e.0'))

        self._load_verify_locations_test(None, capath)


    def test_set_default_verify_paths(self):
        """
        L{Context.set_default_verify_paths} causes the platform-specific CA
        certificate locations to be used for verification purposes.
        """
        # Testing this requires a server with a certificate signed by one of
        # the CAs in the platform CA location.  Getting one of those costs
        # money.  Fortunately (or unfortunately, depending on your
        # perspective), it's easy to think of a public server on the
        # internet which has such a certificate.  Connecting to the network
        # in a unit test is bad, but it's the only way I can think of to
        # really test this. -exarkun

        # Arg, verisign.com doesn't speak TLSv1
        context = Context(SSLv3_METHOD)
        context.set_default_verify_paths()
        context.set_verify(
            VERIFY_PEER,
            lambda conn, cert, errno, depth, preverify_ok: preverify_ok)

        client = socket()
        try:
            client.connect(('verisign.com', 443))
            clientSSL = Connection(context, client)
            clientSSL.set_connect_state()
            clientSSL.do_handshake()
            clientSSL.send('GET / HTTP/1.0\r\n\r\n')
            self.assertTrue(clientSSL.recv(1024))
        finally:
            client.close()