All of ContextTests passes
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 6b8d78a..36605a8 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -1,12 +1,14 @@
from functools import wraps
+from itertools import count
from OpenSSL.xSSL import *
from tls.c import api as _api
from OpenSSL.crypto import (
- FILETYPE_PEM, _PassphraseHelper, PKey, X509, _raise_current_error)
+ FILETYPE_PEM, _PassphraseHelper, PKey, X509, _raise_current_error,
+ _new_mem_buf)
_unspecified = object()
@@ -51,6 +53,32 @@
+class WantReadError(Error):
+ pass
+
+
+
+def _asFileDescriptor(obj):
+ fd = None
+
+ if not isinstance(obj, int):
+ meth = getattr(obj, "fileno", None)
+ if meth is not None:
+ obj = meth()
+
+ if isinstance(obj, int):
+ fd = obj
+
+ if not isinstance(fd, int):
+ raise TypeError("argument must be an int, or have a fileno() method.")
+ elif fd < 0:
+ raise ValueError(
+ "file descriptor cannot be a negative integer (%i)" % (fd,))
+
+ return fd
+
+
+
def SSLeay_version(type):
"""
Return a string describing the version of OpenSSL in use.
@@ -133,10 +161,10 @@
def _wrap_callback(self, callback):
@wraps(callback)
- def wrapped(size, verify, userdata):
+ def wrapper(size, verify, userdata):
return callback(size, verify, self._passphrase_userdata)
return _PassphraseHelper(
- FILETYPE_PEM, wrapped, more_args=True, truncate=True)
+ FILETYPE_PEM, wrapper, more_args=True, truncate=True)
def set_passwd_cb(self, callback, userdata=None):
@@ -164,6 +192,11 @@
:return: None
"""
+ set_result = _api.SSL_CTX_set_default_verify_paths(self._context)
+ if not set_result:
+ 1/0
+ _raise_current_error(Error)
+
def use_certificate_chain_file(self, certfile):
"""
@@ -172,6 +205,13 @@
:param certfile: The name of the certificate chain file
:return: None
"""
+ if not isinstance(certfile, bytes):
+ raise TypeError("certfile must be a byte string")
+
+ result = _api.SSL_CTX_use_certificate_chain_file(self._context, certfile)
+ if not result:
+ _raise_current_error(Error)
+
def use_certificate_file(self, certfile, filetype=_unspecified):
"""
@@ -189,6 +229,13 @@
:param cert: The X509 object
:return: None
"""
+ if not isinstance(cert, X509):
+ raise TypeError("cert must be an X509 instance")
+
+ use_result = _api.SSL_CTX_use_certificate(self._context, cert._x509)
+ if not use_result:
+ 1/0
+
def add_extra_chain_cert(self, certobj):
"""
@@ -313,8 +360,27 @@
if not callable(callback):
raise TypeError("callback must be callable")
- callback = _api.ffi.callback("int(*)(int, X509_STORE_CTX*)", callback)
- _api.SSL_CTX_set_verify(self._context, mode, callback)
+ @wraps(callback)
+ def wrapper(ok, store_ctx):
+ cert = X509.__new__(X509)
+ cert._x509 = _api.X509_STORE_CTX_get_current_cert(store_ctx)
+ error_number = _api.X509_STORE_CTX_get_error(store_ctx)
+ error_depth = _api.X509_STORE_CTX_get_error_depth(store_ctx)
+
+ try:
+ result = callback(self, cert, error_number, error_depth, ok)
+ except Exception as e:
+ # TODO
+ pass
+ else:
+ if result:
+ _api.X509_STORE_CTX_set_error(store_ctx, _api.X509_V_OK)
+ return 1
+ else:
+ return 0
+
+ self._verify_callback = _api.ffi.callback("verify_callback", wrapper)
+ _api.SSL_CTX_set_verify(self._context, mode, self._verify_callback)
def set_verify_depth(self, depth):
@@ -375,6 +441,13 @@
:param cipher_list: A cipher list, see ciphers(1)
:return: None
"""
+ if not isinstance(cipher_list, bytes):
+ raise TypeError("cipher_list must be a byte string")
+
+ result = _api.SSL_CTX_set_cipher_list(self._context, cipher_list)
+ if not result:
+ _raise_current_error(Error)
+
def set_client_ca_list(self, certificate_authorities):
"""
@@ -427,6 +500,12 @@
:param callback: The Python callback to use
:return: None
"""
+ @wraps(callback)
+ def wrapper(ssl, where, return_code):
+ callback(self, where, return_code)
+ self._info_callback = _api.callback('info_callback', wrapper)
+ _api.SSL_CTX_set_info_callback(self._context, self._info_callback)
+
def get_app_data(self):
"""
@@ -489,3 +568,421 @@
"""
ContextType = Context
+
+
+
+class Connection(object):
+ """
+ """
+ def __init__(self, context, socket=None):
+ """
+ Create a new Connection object, using the given OpenSSL.SSL.Context
+ instance and socket.
+
+ :param context: An SSL Context to use for this connection
+ :param socket: The socket to use for transport layer
+ """
+ if not isinstance(context, Context):
+ raise TypeError("context must be a Context instance")
+
+ self._ssl = _api.SSL_new(context._context)
+
+ if socket is None:
+ self._socket = None
+ self._into_ssl = _new_mem_buf()
+ self._from_ssl = _new_mem_buf()
+
+ if self._into_ssl == _api.NULL or self._from_ssl == _api.NULL:
+ 1/0
+
+ _api.SSL_set_bio(self._ssl, self._into_ssl, self._from_ssl)
+ else:
+ self._socket = socket
+ set_result = _api.SSL_set_fd(self._ssl, _asFileDescriptor(self._socket))
+ if not set_result:
+ 1/0
+
+
+ def __getattr__(self, name):
+ """
+ Look up attributes on the wrapped socket object if they are not found on
+ the Connection object.
+ """
+ return getattr(self._socket, name)
+
+
+ def _raise_ssl_error(self, error, result):
+ if error == _api.SSL_ERROR_WANT_READ:
+ raise WantReadError()
+ else:
+ _raise_current_error(Error)
+
+
+ def get_context(self):
+ """
+ Get session context
+ """
+
+
+ def set_context(self):
+ """
+ Switch this connection to a new session context
+
+ :param context: A :py:class:`Context` instance giving the new session context to use.
+ """
+
+ def get_servername(self):
+ """
+ Retrieve the servername extension value if provided in the client hello
+ message, or None if there wasn't one.
+
+ :return: A byte string giving the server name or :py:data:`None`.
+ """
+
+ def set_tlsext_host_name(self):
+ """
+ Set the value of the servername extension to send in the client hello.
+
+ :param name: A byte string giving the name.
+ """
+
+ def pending(self):
+ """
+ Get the number of bytes that can be safely read from the connection
+
+ :return: The number of bytes available in the receive buffer.
+ """
+
+
+ def _handle_SSL_result(self, ssl, result):
+ error = _api.SSL_get_error(ssl, result)
+ if error != _api.SSL_ERROR_NONE:
+ self._raise_ssl_error(error, result)
+
+
+ def send(self, data, flags=None):
+ """
+ Send data on the connection. NOTE: If you get one of the WantRead,
+ WantWrite or WantX509Lookup exceptions on this, you have to call the
+ method again with the SAME buffer.
+
+ :param buf: The string to send
+ :param flags: (optional) Included for compatibility with the socket
+ API, the value is ignored
+ :return: The number of bytes written
+ """
+ result = _api.SSL_write(self._ssl, data, len(data))
+ self._handle_SSL_result(self._ssl, result)
+ return result
+ write = send
+
+
+ def sendall(self):
+ """
+ Send \"all\" data on the connection. This calls send() repeatedly until
+ all data is sent. If an error occurs, it's impossible to tell how much data
+ has been sent.
+
+ :param buf: The string to send
+ :param flags: (optional) Included for compatibility with the socket
+ API, the value is ignored
+ :return: The number of bytes written
+ """
+
+ def recv(self, bufsiz, flags=None):
+ """
+ Receive data on the connection. NOTE: If you get one of the WantRead,
+ WantWrite or WantX509Lookup exceptions on this, you have to call the
+ method again with the SAME buffer.
+
+ :param bufsiz: The maximum number of bytes to read
+ :param flags: (optional) Included for compatibility with the socket
+ API, the value is ignored
+ :return: The string read from the Connection
+ """
+ buf = _api.new("char[]", bufsiz)
+ result = _api.SSL_read(self._ssl, buf, bufsiz)
+ self._handle_SSL_result(self._ssl, result)
+ return _api.buffer(buf, result)
+ read = recv
+
+
+ def bio_read(self):
+ """
+ When using non-socket connections this function reads
+ the \"dirty\" data that would have traveled away on the network.
+
+ :param bufsiz: The maximum number of bytes to read
+ :return: The string read.
+ """
+
+ def bio_write(self):
+ """
+ When using non-socket connections this function sends
+ \"dirty\" data that would have traveled in on the network.
+
+ :param buf: The string to put into the memory BIO.
+ :return: The number of bytes written
+ """
+
+ def renegotiate(self):
+ """
+ Renegotiate the session
+
+ :return: True if the renegotiation can be started, false otherwise
+ """
+
+ def do_handshake(self):
+ """
+ Perform an SSL handshake (usually called after renegotiate() or one of
+ set_*_state()). This can raise the same exceptions as send and recv.
+
+ :return: None.
+ """
+ result = _api.SSL_do_handshake(self._ssl)
+ self._handle_SSL_result(self._ssl, result)
+
+
+ def renegotiate_pending(self):
+ """
+ Check if there's a renegotiation in progress, it will return false once
+ a renegotiation is finished.
+
+ :return: Whether there's a renegotiation in progress
+ """
+
+ def total_renegotiations(self):
+ """
+ Find out the total number of renegotiations.
+
+ :return: The number of renegotiations.
+ """
+
+ def connect(self):
+ """
+ Connect to remote host and set up client-side SSL
+
+ :param addr: A remote address
+ :return: What the socket's connect method returns
+ """
+
+ def connect_ex(self):
+ """
+ Connect to remote host and set up client-side SSL. Note that if the socket's
+ connect_ex method doesn't return 0, SSL won't be initialized.
+
+ :param addr: A remove address
+ :return: What the socket's connect_ex method returns
+ """
+
+ def accept(self):
+ """
+ Accept incoming connection and set up SSL on it
+
+ :return: A (conn,addr) pair where conn is a Connection and addr is an
+ address
+ """
+
+
+ def bio_shutdown(self):
+ """
+ When using non-socket connections this function signals end of
+ data on the input for this connection.
+
+ :return: None
+ """
+
+ def shutdown(self):
+ """
+ Send closure alert
+
+ :return: True if the shutdown completed successfully (i.e. both sides
+ have sent closure alerts), false otherwise (i.e. you have to
+ wait for a ZeroReturnError on a recv() method call
+ """
+ result = _api.SSL_shutdown(self._ssl)
+ if result < 0:
+ 1/0
+ elif result > 0:
+ 2/0
+ else:
+ return False
+
+
+ def get_cipher_list(self):
+ """
+ Get the session cipher list
+
+ :return: A list of cipher strings
+ """
+ ciphers = []
+ for i in count():
+ result = _api.SSL_get_cipher_list(self._ssl, i)
+ if result == _api.NULL:
+ break
+ ciphers.append(_api.string(result))
+ return ciphers
+
+
+ def get_client_ca_list(self):
+ """
+ Get CAs whose certificates are suggested for client authentication.
+
+ :return: If this is a server connection, a list of X509Names representing
+ the acceptable CAs as set by :py:meth:`OpenSSL.SSL.Context.set_client_ca_list` or
+ :py:meth:`OpenSSL.SSL.Context.add_client_ca`. If this is a client connection,
+ the list of such X509Names sent by the server, or an empty list if that
+ has not yet happened.
+ """
+
+ def makefile(self):
+ """
+ The makefile() method is not implemented, since there is no dup semantics
+ for SSL connections
+
+ :raise NotImplementedError
+ """
+
+
+ def get_app_data(self):
+ """
+ Get application data
+
+ :return: The application data
+ """
+
+ def set_app_data(self):
+ """
+ Set application data
+
+ :param data - The application data
+ :return: None
+ """
+
+ def get_shutdown(self):
+ """
+ Get shutdown state
+
+ :return: The shutdown state, a bitvector of SENT_SHUTDOWN, RECEIVED_SHUTDOWN.
+ """
+ return _api.SSL_get_shutdown(self._ssl)
+
+
+ def set_shutdown(self):
+ """
+ Set shutdown state
+
+ :param state - bitvector of SENT_SHUTDOWN, RECEIVED_SHUTDOWN.
+ :return: None
+ """
+
+ def state_string(self):
+ """
+ Get a verbose state description
+
+ :return: A string representing the state
+ """
+
+ def server_random(self):
+ """
+ Get a copy of the server hello nonce.
+
+ :return: A string representing the state
+ """
+
+ def client_random(self):
+ """
+ Get a copy of the client hello nonce.
+
+ :return: A string representing the state
+ """
+
+ def master_key(self):
+ """
+ Get a copy of the master key.
+
+ :return: A string representing the state
+ """
+
+ def sock_shutdown(self):
+ """
+ See shutdown(2)
+
+ :return: What the socket's shutdown() method returns
+ """
+
+ def get_peer_certificate(self):
+ """
+ Retrieve the other side's certificate (if any)
+
+ :return: The peer's certificate
+ """
+ cert = _api.SSL_get_peer_certificate(self._ssl)
+ if cert != _api.NULL:
+ pycert = X509.__new__(X509)
+ pycert._x509 = _api.ffi.gc(cert, _api.X509_free)
+ return pycert
+ return None
+
+
+ def get_peer_cert_chain(self):
+ """
+ Retrieve the other side's certificate (if any)
+
+ :return: A list of X509 instances giving the peer's certificate chain,
+ or None if it does not have one.
+ """
+
+ def want_read(self):
+ """
+ Checks if more data has to be read from the transport layer to complete an
+ operation.
+
+ :return: True iff more data has to be read
+ """
+
+ def want_write(self):
+ """
+ Checks if there is data to write to the transport layer to complete an
+ operation.
+
+ :return: True iff there is data to write
+ """
+
+ def set_accept_state(self):
+ """
+ Set the connection to work in server mode. The handshake will be handled
+ automatically by read/write.
+
+ :return: None
+ """
+ _api.SSL_set_accept_state(self._ssl)
+
+
+ def set_connect_state(self):
+ """
+ Set the connection to work in client mode. The handshake will be handled
+ automatically by read/write.
+
+ :return: None
+ """
+ _api.SSL_set_connect_state(self._ssl)
+
+
+ def get_session(self):
+ """
+ Returns the Session currently used.
+
+ @return: An instance of :py:class:`OpenSSL.SSL.Session` or :py:obj:`None` if
+ no session exists.
+ """
+
+ def set_session(self):
+ """
+ Set the session to be used when the TLS/SSL connection is established.
+
+ :param session: A Session instance representing the session to use.
+ :returns: None
+ """
+
+ConnectionType = Connection
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index c5a0222..80e54a7 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -859,6 +859,22 @@
self._handshake_test(serverContext, clientContext)
+
+ def test_use_certificate_chain_file_wrong_args(self):
+ """
+ :py:obj:`Context.use_certificate_chain_file` raises :py:obj:`TypeError`
+ if passed zero or more than one argument or when passed a non-byte
+ string single argument. It also raises :py:obj:`OpenSSL.SSL.Error` when
+ passed a bad chain file name (for example, the name of a file which does
+ not exist).
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.use_certificate_chain_file)
+ self.assertRaises(TypeError, context.use_certificate_chain_file, object())
+ self.assertRaises(TypeError, context.use_certificate_chain_file, b"foo", object())
+
+ self.assertRaises(Error, context.use_certificate_chain_file, self.mktemp())
+
# XXX load_client_ca
# XXX set_session_id
@@ -929,6 +945,21 @@
self.assertEquals(conn.get_cipher_list(), ["EXP-RC4-MD5"])
+ def test_set_cipher_list_wrong_args(self):
+ """
+ :py:obj:`Context.set_cipher_list` raises :py:obj:`TypeError` when passed
+ zero arguments or more than one argument or when passed a non-byte
+ string single argument and raises :py:obj:`OpenSSL.SSL.Error` when
+ passed an incorrect cipher list string.
+ """
+ context = Context(TLSv1_METHOD)
+ self.assertRaises(TypeError, context.set_cipher_list)
+ self.assertRaises(TypeError, context.set_cipher_list, object())
+ self.assertRaises(TypeError, context.set_cipher_list, b"EXP-RC4-MD5", object())
+
+ self.assertRaises(Error, context.set_cipher_list, b"imaginary-cipher")
+
+
def test_set_session_cache_mode_wrong_args(self):
"""
L{Context.set_session_cache_mode} raises L{TypeError} if called with