Merge send-memoryview
diff --git a/ChangeLog b/ChangeLog
index c286c03..006ee54 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,9 @@
+2011-01-22 Jean-Paul Calderone <exarkun@twistedmatrix.com>
+
+ * OpenSSL/ssl/connection.c: Add support for new-style
+ buffers (primarily memoryviews) to Connection.send and
+ Connection.sendall.
+
2010-11-01 Jean-Paul Calderone <exarkun@twistedmatrix.com>
* Release 0.11
diff --git a/OpenSSL/ssl/connection.c b/OpenSSL/ssl/connection.c
index 1d45926..f2881d3 100755
--- a/OpenSSL/ssl/connection.c
+++ b/OpenSSL/ssl/connection.c
@@ -331,18 +331,32 @@
@return: The number of bytes written\n\
";
static PyObject *
-ssl_Connection_send(ssl_ConnectionObj *self, PyObject *args)
-{
- char *buf;
+ssl_Connection_send(ssl_ConnectionObj *self, PyObject *args) {
int len, ret, err, flags;
+ char *buf;
+
+#if PY_VERSION_HEX >= 0x02060000
+ Py_buffer pbuf;
+
+ if (!PyArg_ParseTuple(args, "s*|i:send", &pbuf, &flags))
+ return NULL;
+
+ buf = pbuf.buf;
+ len = pbuf.len;
+#else
if (!PyArg_ParseTuple(args, "s#|i:send", &buf, &len, &flags))
return NULL;
+#endif
MY_BEGIN_ALLOW_THREADS(self->tstate)
ret = SSL_write(self->ssl, buf, len);
MY_END_ALLOW_THREADS(self->tstate)
+#if PY_VERSION_HEX >= 0x02060000
+ PyBuffer_Release(&pbuf);
+#endif
+
if (PyErr_Occurred())
{
flush_error_queue();
@@ -378,8 +392,18 @@
int len, ret, err, flags;
PyObject *pyret = Py_None;
+#if PY_VERSION_HEX >= 0x02060000
+ Py_buffer pbuf;
+
+ if (!PyArg_ParseTuple(args, "s*|i:sendall", &pbuf, &flags))
+ return NULL;
+
+ buf = pbuf.buf;
+ len = pbuf.len;
+#else
if (!PyArg_ParseTuple(args, "s#|i:sendall", &buf, &len, &flags))
return NULL;
+#endif
do {
MY_BEGIN_ALLOW_THREADS(self->tstate)
@@ -403,9 +427,13 @@
handle_ssl_errors(self->ssl, err, ret);
pyret = NULL;
break;
- }
+ }
} while (len > 0);
+#if PY_VERSION_HEX >= 0x02060000
+ PyBuffer_Release(&pbuf);
+#endif
+
Py_XINCREF(pyret);
return pyret;
}
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 6c8579b..bd5a92b 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -1074,6 +1074,49 @@
+class ConnectionSendTests(TestCase, _LoopbackMixin):
+ """
+ Tests for L{Connection.send}
+ """
+ def test_wrong_args(self):
+ """
+ When called with arguments other than a single string,
+ L{Connection.send} raises L{TypeError}.
+ """
+ connection = Connection(Context(TLSv1_METHOD), None)
+ self.assertRaises(TypeError, connection.send)
+ self.assertRaises(TypeError, connection.send, object())
+ self.assertRaises(TypeError, connection.send, "foo", "bar")
+
+
+ def test_short_bytes(self):
+ """
+ When passed a short byte string, L{Connection.send} transmits all of it
+ and returns the number of bytes sent.
+ """
+ server, client = self._loopback()
+ count = server.send(b('xy'))
+ self.assertEquals(count, 2)
+ self.assertEquals(client.recv(2), b('xy'))
+
+ try:
+ memoryview
+ except NameError:
+ "cannot test sending memoryview without memoryview"
+ else:
+ def test_short_memoryview(self):
+ """
+ When passed a memoryview onto a small number of bytes,
+ L{Connection.send} transmits all of them and returns the number of
+ bytes sent.
+ """
+ server, client = self._loopback()
+ count = server.send(memoryview(b('xy')))
+ self.assertEquals(count, 2)
+ self.assertEquals(client.recv(2), b('xy'))
+
+
+
class ConnectionSendallTests(TestCase, _LoopbackMixin):
"""
Tests for L{Connection.sendall}.
@@ -1099,6 +1142,21 @@
self.assertEquals(client.recv(1), b('x'))
+ try:
+ memoryview
+ except NameError:
+ "cannot test sending memoryview without memoryview"
+ else:
+ def test_short_memoryview(self):
+ """
+ When passed a memoryview onto a small number of bytes,
+ L{Connection.sendall} transmits all of them.
+ """
+ server, client = self._loopback()
+ server.sendall(memoryview(b('x')))
+ self.assertEquals(client.recv(1), b('x'))
+
+
def test_long(self):
"""
L{Connection.sendall} transmits all of the bytes in the string passed to