Make NPN markups.
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index e754a7e..bd68767 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -1,11 +1,12 @@
from sys import platform
from functools import wraps, partial
-from itertools import count
+from itertools import count, chain
from weakref import WeakValueDictionary
from errno import errorcode
from six import text_type as _text_type
from six import integer_types as integer_types
+from six import int2byte, byte2int
from OpenSSL._util import (
ffi as _ffi,
@@ -294,9 +295,7 @@
self._tlsext_servername_callback = None
self._app_data = None
self._npn_advertise_callback = None
- self._npn_advertise_callback_args = None
self._npn_select_callback = None
- self._npn_select_callback_args = None
# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
@@ -816,22 +815,35 @@
def set_npn_advertise_callback(self, callback):
"""
- Specify a callback function that will be called when offering Next
- Protocol Negotiation.
+ Specify a callback function that will be called when offering `Next
+ Protocol Negotiation
+ <https://technotes.googlecode.com/git/nextprotoneg.html>`_ as a server.
:param callback: The callback function. It will be invoked with one
- argument, the Connection instance. It should return a Python
- bytestring, like b'\\x08http/1.1\\x06spdy/2'.
+ argument, the Connection instance. It should return a list of
+ bytestrings representing the advertised protocols, like
+ ``[b'http/1.1', b'spdy/2']``.
"""
@wraps(callback)
def wrapper(ssl, out, outlen, arg):
- outstr = callback(Connection._reverse_mapping[ssl])
- self._npn_advertise_callback_args = [
- _ffi.new("unsigned int *", len(outstr)),
- _ffi.new("unsigned char[]", outstr),
+ conn = Connection._reverse_mapping[ssl]
+ protos = callback(conn)
+
+ # Join the protocols into a Python bytestring, length-prefixing
+ # each element.
+ protostr = b''.join(
+ chain.from_iterable((int2byte(len(p)), p) for p in protos)
+ )
+
+ # Save our callback arguments on the connection object. This is
+ # done to make sure that they don't get freed before OpenSSL uses
+ # them. Then, return them appropriately in the output parameters.
+ conn._npn_advertise_callback_args = [
+ _ffi.new("unsigned int *", len(protostr)),
+ _ffi.new("unsigned char[]", protostr),
]
- outlen[0] = self._npn_advertise_callback_args[0][0]
- out[0] = self._npn_advertise_callback_args[1]
+ outlen[0] = conn._npn_advertise_callback_args[0][0]
+ out[0] = conn._npn_advertise_callback_args[1]
return 0
self._npn_advertise_callback = _ffi.callback(
@@ -848,20 +860,38 @@
:param callback: The callback function. It will be invoked with two
arguments: the Connection, and a list of offered protocols as
- length-prefixed strings in a bytestring, e.g.
- b'\\x08http/1.1\\x06spdy/2'. It should return one of those
- bytestrings, the chosen protocol.
+ bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return
+ one of those bytestrings, the chosen protocol.
"""
@wraps(callback)
def wrapper(ssl, out, outlen, in_, inlen, arg):
- outstr = callback(
- Connection._reverse_mapping[ssl], _ffi.string(in_))
- self._npn_select_callback_args = [
+ conn = Connection._reverse_mapping[ssl]
+
+ # The string passed to us is actually made up of multiple
+ # length-prefixed bytestrings. We need to split that into a list.
+ instr = _ffi.buffer(in_, inlen)
+ protolist = []
+ while instr:
+ # This slightly insane syntax is to make sure we get a
+ # bytestring: on Python 3, instr[0] would return an int and
+ # this call would fail.
+ l = byte2int(instr[0:1])
+ proto = instr[1:l+1]
+ protolist.append(proto)
+ instr = instr[l+1:]
+
+ # Call the callback
+ outstr = callback(conn, protolist)
+
+ # Save our callback arguments on the connection object. This is
+ # done to make sure that they don't get freed before OpenSSL uses
+ # them. Then, return them appropriately in the output parameters.
+ conn._npn_select_callback_args = [
_ffi.new("unsigned char *", len(outstr)),
_ffi.new("unsigned char[]", outstr),
]
- outlen[0] = self._npn_select_callback_args[0][0]
- out[0] = self._npn_select_callback_args[1]
+ outlen[0] = conn._npn_select_callback_args[0][0]
+ out[0] = conn._npn_select_callback_args[1]
return 0
self._npn_select_callback = _ffi.callback(
@@ -895,6 +925,13 @@
self._ssl = _ffi.gc(ssl, _lib.SSL_free)
self._context = context
+ # References to strings used for Next Protocol Negotiation. OpenSSL's
+ # header files suggest that these might get copied at some point, but
+ # doesn't specify when, so we store them here to make sure they don't
+ # get freed before OpenSSL uses them.
+ self._npn_advertise_callback_args = None
+ self._npn_select_callback_args = None
+
self._reverse_mapping[self._ssl] = self
if socket is None:
@@ -1622,7 +1659,7 @@
_lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len)
if not data_len[0]:
- return ""
+ return b""
else:
return _ffi.string(data[0])
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 404f8b9..9fc8466 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -1439,14 +1439,19 @@
Test for Next Protocol Negotiation in PyOpenSSL.
"""
def test_npn_success(self):
- advertise_args =[]
+ """
+ Tests that clients and servers that agree on the negotiated next
+ protocol can correct establish a connection, and that the agreed
+ protocol is reported by the connections.
+ """
+ advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
- return b('\x08http/1.1\x06spdy/2')
+ return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
- return b('spdy/2')
+ return b'spdy/2'
server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
@@ -1470,21 +1475,25 @@
self._interactInMemory(server, client)
self.assertEqual([(server,)], advertise_args)
- self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args)
+ self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
- self.assertEqual(server.get_next_proto_negotiated(), b('spdy/2'))
- self.assertEqual(client.get_next_proto_negotiated(), b('spdy/2'))
+ self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
+ self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')
def test_npn_client_fail(self):
- advertise_args =[]
+ """
+ Tests that when clients and servers cannot agree on what protocol to
+ use next that the TLS connection does not get established.
+ """
+ advertise_args = []
select_args = []
def advertise(conn):
advertise_args.append((conn,))
- return b('\x08http/1.1\x06spdy/2')
+ return [b'http/1.1', b'spdy/2']
def select(conn, options):
select_args.append((conn, options))
- return b('')
+ return b''
server_context = Context(TLSv1_METHOD)
server_context.set_npn_advertise_callback(advertise)
@@ -1509,7 +1518,7 @@
self.assertRaises(Error, self._interactInMemory, server, client)
self.assertEqual([(server,)], advertise_args)
- self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args)
+ self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst
index a75af1f..fbee1fe 100644
--- a/doc/api/ssl.rst
+++ b/doc/api/ssl.rst
@@ -472,6 +472,33 @@
.. versionadded:: 0.13
+.. py:method:: Context.set_npn_advertise_callback(callback)
+
+ Specify a callback function that will be called when offering `Next
+ Protocol Negotiation
+ <https://technotes.googlecode.com/git/nextprotoneg.html>`_ as a server.
+
+ *callback* should be the callback function. It will be invoked with one
+ argument, the :py:class:`Connection` instance. It should return a list of
+ bytestrings representing the advertised protocols, like
+ ``[b'http/1.1', b'spdy/2']``.
+
+ .. versionadded:: 0.15
+
+
+.. py:method:: Context.set_npn_select_callback(callback):
+
+ Specify a callback function that will be called when a server offers Next
+ Protocol Negotiation options.
+
+ *callback* should be the callback function. It will be invoked with two
+ arguments: the :py:class:`Connection`, and a list of offered protocols as
+ bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``. It should return one of
+ those bytestrings, the chosen protocol.
+
+ .. versionadded:: 0.15
+
+
.. _openssl-session:
Session objects
@@ -806,6 +833,13 @@
.. versionadded:: 0.15
+.. py:method:: Connection.get_next_proto_negotiated():
+
+ Get the protocol that was negotiated by Next Protocol Negotiation.
+
+ .. versionadded:: 0.15
+
+
.. Rubric:: Footnotes
.. [#connection-context-socket] Actually, all that is required is an object that