Refactor the existing memory bio tests a bit; add a new test which exercises the short-write case; change the error handling code in bio_read and bio_write to *not* use SSL_get_error, as these functions do nothing with SSL_* APIs
diff --git a/src/ssl/connection.c b/src/ssl/connection.c
index 8708e80..9abb736 100755
--- a/src/ssl/connection.c
+++ b/src/ssl/connection.c
@@ -125,6 +125,50 @@
}
/*
+ * Handle errors raised by BIO functions.
+ *
+ * Arguments: bio - The BIO object
+ * ret - The return value of the BIO_ function.
+ * Returns: None, the calling function should return NULL;
+ */
+static void
+handle_bio_errors(BIO* bio, int ret)
+{
+ if (BIO_should_retry(bio)) {
+ if (BIO_should_read(bio)) {
+ PyErr_SetNone(ssl_WantReadError);
+ } else if (BIO_should_write(bio)) {
+ PyErr_SetNone(ssl_WantWriteError);
+ } else if (BIO_should_io_special(bio)) {
+ /*
+ * It's somewhat unclear what this means. From the OpenSSL source,
+ * it seems like it should not be triggered by the memory BIO, so
+ * for the time being, this case shouldn't come up. The SSL BIO
+ * (which I think should be named the socket BIO) may trigger this
+ * case if its socket is not yet connected or it is busy doing
+ * something related to x509.
+ */
+ PyErr_SetString(PyExc_ValueError, "BIO_should_io_special");
+ } else {
+ /*
+ * I hope this is dead code. The BIO documentation suggests that
+ * one of the above three checks should always be true.
+ */
+ PyErr_SetString(PyExc_ValueError, "unknown bio failure");
+ }
+ } else {
+ /*
+ * If we aren't to retry, it's really an error, so fall back to the
+ * normal error reporting code. However, the BIO interface does not
+ * specify a uniform error reporting mechanism. We can only hope that
+ * the code which triggered the error also kindly pushed something onto
+ * the error stack.
+ */
+ exception_from_error_queue();
+ }
+}
+
+/*
* Handle errors raised by SSL I/O functions. NOTE: Not SSL_shutdown ;)
*
* Arguments: ssl - The SSL object
@@ -252,7 +296,7 @@
ssl_Connection_bio_write(ssl_ConnectionObj *self, PyObject *args)
{
char *buf;
- int len, ret, err;
+ int len, ret;
if(self->into_ssl == NULL)
{
@@ -271,16 +315,15 @@
return NULL;
}
- err = SSL_get_error(self->ssl, ret);
- if (err == SSL_ERROR_NONE)
- {
- return PyInt_FromLong((long)ret);
- }
- else
- {
- handle_ssl_errors(self->ssl, err, ret);
+ if (ret <= 0) {
+ /*
+ * There was a problem with the BIO_write of some sort.
+ */
+ handle_bio_errors(self->from_ssl, ret);
return NULL;
}
+
+ return PyInt_FromLong((long)ret);
}
static char ssl_Connection_send_doc[] = "\n\
@@ -440,7 +483,7 @@
static PyObject *
ssl_Connection_bio_read(ssl_ConnectionObj *self, PyObject *args)
{
- int bufsiz, ret, err;
+ int bufsiz, ret;
PyObject *buf;
if(self->from_ssl == NULL)
@@ -465,19 +508,24 @@
return NULL;
}
- err = SSL_get_error(self->ssl, ret);
- if (err == SSL_ERROR_NONE)
- {
- if (ret != bufsiz && _PyString_Resize(&buf, ret) < 0)
- return NULL;
- return buf;
- }
- else
- {
- handle_ssl_errors(self->ssl, err, ret);
+ if (ret <= 0) {
+ /*
+ * There was a problem with the BIO_read of some sort.
+ */
+ handle_bio_errors(self->from_ssl, ret);
Py_DECREF(buf);
return NULL;
}
+
+ /*
+ * Shrink the string to match the number of bytes we actually read.
+ */
+ if (ret != bufsiz && _PyString_Resize(&buf, ret) < 0)
+ {
+ Py_DECREF(buf);
+ return NULL;
+ }
+ return buf;
}
static char ssl_Connection_renegotiate_doc[] = "\n\
diff --git a/test/test_ssl.py b/test/test_ssl.py
index 8ade92e..f09819c 100644
--- a/test/test_ssl.py
+++ b/test/test_ssl.py
@@ -425,6 +425,40 @@
"""
Tests for L{OpenSSL.SSL.Connection} using a memory BIO.
"""
+ def _server(self):
+ # Create the server side Connection. This is mostly setup boilerplate
+ # - use TLSv1, use a particular certificate, etc.
+ server_ctx = Context(TLSv1_METHOD)
+ server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE )
+ server_ctx.set_verify(VERIFY_PEER|VERIFY_FAIL_IF_NO_PEER_CERT|VERIFY_CLIENT_ONCE, verify_cb)
+ server_store = server_ctx.get_cert_store()
+ server_ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+ server_ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+ server_ctx.check_privatekey()
+ server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
+ # Here the Connection is actually created. None is passed as the 2nd
+ # parameter, indicating a memory BIO should be created.
+ server_conn = Connection(server_ctx, None)
+ server_conn.set_accept_state()
+ return server_conn
+
+
+ def _client(self):
+ # Now create the client side Connection. Similar boilerplate to the above.
+ client_ctx = Context(TLSv1_METHOD)
+ client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE )
+ client_ctx.set_verify(VERIFY_PEER|VERIFY_FAIL_IF_NO_PEER_CERT|VERIFY_CLIENT_ONCE, verify_cb)
+ client_store = client_ctx.get_cert_store()
+ client_ctx.use_privatekey(load_privatekey(FILETYPE_PEM, client_key_pem))
+ client_ctx.use_certificate(load_certificate(FILETYPE_PEM, client_cert_pem))
+ client_ctx.check_privatekey()
+ client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
+ # Again, None to create a new memory BIO.
+ client_conn = Connection(client_ctx, None)
+ client_conn.set_connect_state()
+ return client_conn
+
+
def _loopback(self, client_conn, server_conn):
"""
Try to read application bytes from each of the two L{Connection}
@@ -447,7 +481,7 @@
# Give the side a chance to generate some more bytes, or
# succeed.
try:
- bytes = read.recv(1024)
+ bytes = read.recv(2 ** 16)
except WantReadError:
# It didn't succeed, so we'll hope it generated some
# output.
@@ -479,33 +513,8 @@
the other and in this way establish a connection and exchange
application-level bytes with each other.
"""
- # Create the server side Connection. This is mostly setup boilerplate
- # - use TLSv1, use a particular certificate, etc.
- server_ctx = Context(TLSv1_METHOD)
- server_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE )
- server_ctx.set_verify(VERIFY_PEER|VERIFY_FAIL_IF_NO_PEER_CERT|VERIFY_CLIENT_ONCE, verify_cb)
- server_store = server_ctx.get_cert_store()
- server_ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
- server_ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
- server_ctx.check_privatekey()
- server_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
- # Here the Connection is actually created. None is passed as the 2nd
- # parameter, indicating a memory BIO should be created.
- server_conn = Connection(server_ctx, None)
- server_conn.set_accept_state()
-
- # Now create the client side Connection. Similar boilerplate to the above.
- client_ctx = Context(TLSv1_METHOD)
- client_ctx.set_options(OP_NO_SSLv2 | OP_NO_SSLv3 | OP_SINGLE_DH_USE )
- client_ctx.set_verify(VERIFY_PEER|VERIFY_FAIL_IF_NO_PEER_CERT|VERIFY_CLIENT_ONCE, verify_cb)
- client_store = client_ctx.get_cert_store()
- client_ctx.use_privatekey(load_privatekey(FILETYPE_PEM, client_key_pem))
- client_ctx.use_certificate(load_certificate(FILETYPE_PEM, client_cert_pem))
- client_ctx.check_privatekey()
- client_store.add_cert(load_certificate(FILETYPE_PEM, root_cert_pem))
- # Again, None to create a new memory BIO.
- client_conn = Connection(client_ctx, None)
- client_conn.set_connect_state()
+ server_conn = self._server()
+ client_conn = self._client()
# There should be no key or nonces yet.
self.assertIdentical(server_conn.master_key(), None)
@@ -548,3 +557,30 @@
clientSSL = Connection(context, client)
self.assertRaises( TypeError, clientSSL.bio_read, 100)
self.assertRaises( TypeError, clientSSL.bio_write, "foo")
+
+
+ def test_outgoingOverflow(self):
+ """
+ If more bytes than can be written to the memory BIO are passed to
+ L{Connection.send} at once, the number of bytes which were written is
+ returned and that many bytes from the beginning of the input can be
+ read from the other end of the connection.
+ """
+ server = self._server()
+ client = self._client()
+
+ self._loopback(client, server)
+
+ size = 2 ** 15
+ sent = client.send("x" * size)
+ # Sanity check. We're trying to test what happens when the entire
+ # input can't be sent. If the entire input was sent, this test is
+ # meaningless.
+ self.assertTrue(sent < size)
+
+ receiver, received = self._loopback(client, server)
+ self.assertIdentical(receiver, server)
+
+ # We can rely on all of these bytes being received at once because
+ # _loopback passes 2 ** 16 to recv - more than 2 ** 15.
+ self.assertEquals(len(received), sent)