Add support for Next Protocol Negotiation.
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index 7b1cbc1..e754a7e 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -293,6 +293,10 @@
self._info_callback = None
self._tlsext_servername_callback = None
self._app_data = None
+ self._npn_advertise_callback = None
+ self._npn_advertise_callback_args = None
+ self._npn_select_callback = None
+ self._npn_select_callback_args = None
# SSL_CTX_set_app_data(self->ctx, self);
# SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
@@ -809,6 +813,64 @@
_lib.SSL_CTX_set_tlsext_servername_callback(
self._context, self._tlsext_servername_callback)
+
+    def set_npn_advertise_callback(self, callback):
+        """
+        Specify a callback function that will be called when offering Next
+        Protocol Negotiation.
+
+        :param callback: The callback function. It will be invoked with one
+            argument, the Connection instance. It should return a bytestring
+            of length-prefixed names, like b'\\x08http/1.1\\x06spdy/2'.
+        """
+        @wraps(callback)
+        def wrapper(ssl, out, outlen, arg):
+            outstr = callback(Connection._reverse_mapping[ssl])
+            buf = _ffi.new("unsigned char[]", outstr)
+            # Keep buf referenced so OpenSSL's *out stays valid after return.
+            self._npn_advertise_callback_args = [buf]
+            out[0], outlen[0] = buf, len(outstr)
+            return 0  # SSL_TLSEXT_ERR_OK
+
+        # If callback raises, cffi returns error=3 (SSL_TLSEXT_ERR_NOACK) so
+        # OpenSSL skips NPN rather than reading the never-set *out.
+        self._npn_advertise_callback = _ffi.callback(
+            "int (*)(SSL *, const unsigned char **, unsigned int *, void *)",
+            wrapper, error=3)
+        _lib.SSL_CTX_set_next_protos_advertised_cb(
+            self._context, self._npn_advertise_callback, _ffi.NULL)
+
+
+    def set_npn_select_callback(self, callback):
+        """
+        Specify a callback function that will be called when a server offers
+        Next Protocol Negotiation options.
+
+        :param callback: The callback function. It will be invoked with two
+            arguments: the Connection, and the offered protocols as one
+            bytestring of length-prefixed names (b'\\x08http/1.1\\x06spdy/2').
+            It should return the chosen name without its prefix (b'spdy/2').
+        """
+        @wraps(callback)
+        def wrapper(ssl, out, outlen, in_, inlen, arg):
+            # in_ may lack a NUL terminator, so copy exactly inlen bytes.
+            offered = _ffi.buffer(in_, inlen)[:]
+            outstr = callback(Connection._reverse_mapping[ssl], offered)
+            buf = _ffi.new("unsigned char[]", outstr)
+            # Keep buf referenced so OpenSSL's *out stays valid after return.
+            self._npn_select_callback_args = [buf]
+            out[0], outlen[0] = buf, len(outstr)
+            return 0  # SSL_TLSEXT_ERR_OK
+
+        # If callback raises, cffi returns error=2 (SSL_TLSEXT_ERR_ALERT_FATAL)
+        # so the handshake fails rather than OpenSSL reading an unset *out.
+        self._npn_select_callback = _ffi.callback(
+            "int (*)(SSL *, unsigned char **, unsigned char *, "
+            "const unsigned char *, unsigned int, void *)",
+            wrapper, error=2)
+        _lib.SSL_CTX_set_next_proto_select_cb(
+            self._context, self._npn_select_callback, _ffi.NULL)
+
ContextType = Context
@@ -1550,6 +1612,19 @@
version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher))
return version.decode("utf-8")
+    def get_next_proto_negotiated(self):
+        """
+        Get the protocol that was negotiated by NPN.
+
+        :return: The negotiated protocol as bytes, or ``b""`` if none.
+        """
+        data = _ffi.new("unsigned char **")
+        data_len = _ffi.new("unsigned int *")
+        _lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len)
+        if not data_len[0]:
+            return b""
+        # May lack a NUL terminator, so copy exactly data_len bytes.
+        return _ffi.buffer(data[0], data_len[0])[:]
ConnectionType = Connection
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index 6409b8e..404f8b9 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -1434,6 +1434,84 @@
self.assertEqual([(server, b("foo1.example.com"))], args)
+class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
+    """
+    Tests for Next Protocol Negotiation in PyOpenSSL.
+    """
+    def _createConnections(self, advertise, select):
+        """
+        Create a server and a client L{Connection} whose contexts use the
+        given NPN C{advertise} and C{select} callbacks, respectively.
+
+        @return: A C{(server, client)} tuple in accept/connect state.
+        """
+        server_context = Context(TLSv1_METHOD)
+        server_context.set_npn_advertise_callback(advertise)
+
+        client_context = Context(TLSv1_METHOD)
+        client_context.set_npn_select_callback(select)
+
+        # Necessary to actually accept the connection
+        server_context.use_privatekey(
+            load_privatekey(FILETYPE_PEM, server_key_pem))
+        server_context.use_certificate(
+            load_certificate(FILETYPE_PEM, server_cert_pem))
+
+        server = Connection(server_context, None)
+        server.set_accept_state()
+
+        client = Connection(client_context, None)
+        client.set_connect_state()
+        return server, client
+
+
+    def test_npn_success(self):
+        """
+        When the client's select callback returns one of the protocols the
+        server advertised, both sides report it as the negotiated protocol.
+        """
+        advertise_args = []
+        select_args = []
+        def advertise(conn):
+            advertise_args.append((conn,))
+            return b('\x08http/1.1\x06spdy/2')
+        def select(conn, options):
+            select_args.append((conn, options))
+            return b('spdy/2')
+
+        server, client = self._createConnections(advertise, select)
+        # Do a little connection to trigger the logic
+        self._interactInMemory(server, client)
+
+        self.assertEqual([(server,)], advertise_args)
+        self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args)
+
+        self.assertEqual(server.get_next_proto_negotiated(), b('spdy/2'))
+        self.assertEqual(client.get_next_proto_negotiated(), b('spdy/2'))
+
+
+    def test_npn_client_fail(self):
+        """
+        If the client's select callback returns an empty protocol name, the
+        handshake fails with L{Error}.
+        """
+        advertise_args = []
+        select_args = []
+        def advertise(conn):
+            advertise_args.append((conn,))
+            return b('\x08http/1.1\x06spdy/2')
+        def select(conn, options):
+            select_args.append((conn, options))
+            return b('')
+
+        server, client = self._createConnections(advertise, select)
+        # If the client doesn't return anything, the connection will fail.
+        self.assertRaises(Error, self._interactInMemory, server, client)
+
+        self.assertEqual([(server,)], advertise_args)
+        self.assertEqual([(client, b('\x08http/1.1\x06spdy/2'))], select_args)
+
+
class SessionTests(TestCase):
"""