Add MSG_PEEK support to Connection.recv and Connection.recv_into via SSL_peek
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 8c87c34..9b27013 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -1,3 +1,4 @@
+import socket
from sys import platform
from functools import wraps, partial
from itertools import count, chain
@@ -1311,12 +1312,15 @@
method again with the SAME buffer.
:param bufsiz: The maximum number of bytes to read
- :param flags: (optional) Included for compatibility with the socket
- API, the value is ignored
+ :param flags: (optional) The only supported flag is ``MSG_PEEK``,
+ all other flags are ignored.
:return: The string read from the Connection
"""
buf = _ffi.new("char[]", bufsiz)
- result = _lib.SSL_read(self._ssl, buf, bufsiz)
+ if flags is not None and flags & socket.MSG_PEEK:
+ result = _lib.SSL_peek(self._ssl, buf, bufsiz)
+ else:
+ result = _lib.SSL_read(self._ssl, buf, bufsiz)
self._raise_ssl_error(self._ssl, result)
return _ffi.buffer(buf, result)[:]
read = recv
@@ -1332,8 +1336,8 @@
buffer. If not present, defaults to the size of the buffer. If
larger than the size of the buffer, is reduced to the size of the
buffer.
- :param flags: (optional) Included for compatibility with the socket
- API, the value is ignored.
+ :param flags: (optional) The only supported flag is ``MSG_PEEK``,
+ all other flags are ignored.
:return: The number of bytes read into the buffer.
"""
if nbytes is None:
@@ -1345,7 +1349,10 @@
# better if we could pass memoryviews straight into the SSL_read call,
# but right now we can't. Revisit this if CFFI gets that ability.
buf = _ffi.new("char[]", nbytes)
- result = _lib.SSL_read(self._ssl, buf, nbytes)
+ if flags is not None and flags & socket.MSG_PEEK:
+ result = _lib.SSL_peek(self._ssl, buf, nbytes)
+ else:
+ result = _lib.SSL_read(self._ssl, buf, nbytes)
self._raise_ssl_error(self._ssl, result)
# This strange line is all to avoid a memory copy. The buffer protocol
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index e586537..787d636 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -8,7 +8,7 @@
from gc import collect, get_referrers
from errno import ECONNREFUSED, EINPROGRESS, EWOULDBLOCK, EPIPE, ESHUTDOWN
from sys import platform, getfilesystemencoding
-from socket import SHUT_RDWR, error, socket
+from socket import MSG_PEEK, SHUT_RDWR, error, socket
from os import makedirs
from os.path import join
from unittest import main
@@ -2172,6 +2172,17 @@
self.assertRaises(TypeError, connection.pending, None)
+    def test_peek(self):
+        """
+        ``MSG_PEEK`` makes :py:obj:`Connection.recv` read without consuming.
+        """
+        server, client = self._loopback()
+        server.send(b('xy'))
+        for _ in range(2):
+            self.assertEqual(client.recv(2, MSG_PEEK), b('xy'))
+        self.assertEqual(client.recv(2), b('xy'))
+
+
def test_connect_wrong_args(self):
"""
:py:obj:`Connection.connect` raises :py:obj:`TypeError` if called with a non-address
@@ -2999,6 +3010,27 @@
self._doesnt_overfill_test(bytearray)
+    def test_peek(self):
+        """
+        :py:obj:`Connection.recv_into` peeks into the connection if
+        :py:obj:`socket.MSG_PEEK` is passed, leaving the data available
+        to a later, non-peeking read.
+        """
+        server, client = self._loopback()
+        server.send(b('xy'))
+
+        for _ in range(2):
+            output_buffer = bytearray(5)
+            self.assertEqual(
+                client.recv_into(output_buffer, flags=MSG_PEEK), 2)
+            self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00')))
+
+        # Peeking must not consume anything: a normal read still sees it.
+        output_buffer = bytearray(5)
+        self.assertEqual(client.recv_into(output_buffer), 2)
+        self.assertEqual(output_buffer, bytearray(b('xy\x00\x00\x00')))
+
+
try:
memoryview
except NameError: