use _ffi.from_buffer() to support bytearray (#852)
* use _ffi.from_buffer(buf) in send() to support bytearray
* add bytearray test
* update CHANGELOG.rst
* move from_buffer before 'buffer too long' check
* use from_buffer as a context manager; apply black formatting
* don't shadow buf in send()
* test return count for sendall
* test sending an array
* fix test
* also use from_buffer in bio_write
* undo the reformatting of _util.py
* formatting
* add simple bio_write tests
* wrap line
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index 6b9422c..16767e9 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -2087,6 +2087,29 @@
with pytest.raises(TypeError):
Connection(bad_context)
+    @pytest.mark.parametrize('bad_bio', [object(), None, 1, [1, 2, 3]])
+    def test_bio_write_wrong_args(self, bad_bio):
+        """
+        `Connection.bio_write` raises `TypeError` if called with an
+        argument that is not bytes, bytearray, or text.
+        """
+        context = Context(TLSv1_METHOD)
+        connection = Connection(context, None)
+        with pytest.raises(TypeError):
+            connection.bio_write(bad_bio)
+
+    def test_bio_write(self):
+        """
+        `Connection.bio_write` accepts bytes and bytearray without raising,
+        and emits a `DeprecationWarning` when given text.
+        """
+        conn = Connection(Context(TLSv1_METHOD), None)
+        for payload in (b'xy', bytearray(b'za')):
+            # Neither bytes nor bytearray input should raise.
+            conn.bio_write(payload)
+        with pytest.warns(DeprecationWarning):
+            conn.bio_write(u'deprecated')
+
def test_get_context(self):
"""
`Connection.get_context` returns the `Context` instance used to
@@ -2807,6 +2830,8 @@
connection = Connection(Context(TLSv1_METHOD), None)
with pytest.raises(TypeError):
connection.send(object())
+ with pytest.raises(TypeError):
+ connection.send([1, 2, 3])
def test_short_bytes(self):
"""
@@ -2845,6 +2870,16 @@
assert count == 2
assert client.recv(2) == b'xy'
+    def test_short_bytearray(self):
+        """
+        `Connection.send` sends every byte of a short bytearray and
+        reports how many bytes were written.
+        """
+        sender, receiver = loopback()
+        payload = bytearray(b'xy')
+        assert sender.send(payload) == len(payload)
+        assert receiver.recv(len(payload)) == bytes(payload)
+
@skip_if_py3
def test_short_buffer(self):
"""
@@ -3015,6 +3050,8 @@
connection = Connection(Context(TLSv1_METHOD), None)
with pytest.raises(TypeError):
connection.sendall(object())
+ with pytest.raises(TypeError):
+ connection.sendall([1, 2, 3])
def test_short(self):
"""
@@ -3056,8 +3093,9 @@
`Connection.sendall` transmits all of them.
"""
server, client = loopback()
- server.sendall(buffer(b'x'))
- assert client.recv(1) == b'x'
+ count = server.sendall(buffer(b'xy'))
+ assert count == 2
+ assert client.recv(2) == b'xy'
def test_long(self):
"""