Merge pull request #200 from exarkun/npn

Add Lukasa's next protocol negotiation API/implementation.
diff --git a/ChangeLog b/ChangeLog
index 44a6944..e872803 100644
--- a/ChangeLog
+++ b/ChangeLog
@@ -1,3 +1,8 @@
+2015-03-23  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
+
+	* OpenSSL/SSL.py: Add Cory Benfield's next-protocol-negotiation
+	  (NPN) bindings.
+
 2015-03-15  Jean-Paul Calderone  <exarkun@twistedmatrix.com>
 
 	* OpenSSL/SSL.py: Add ``Connection.recv_into``, mirroring the
diff --git a/OpenSSL/SSL.py b/OpenSSL/SSL.py
index e86d855..6d00e13 100644
--- a/OpenSSL/SSL.py
+++ b/OpenSSL/SSL.py
@@ -1,11 +1,12 @@
 from sys import platform
 from functools import wraps, partial
-from itertools import count
+from itertools import count, chain
 from weakref import WeakValueDictionary
 from errno import errorcode
 
 from six import text_type as _text_type
 from six import integer_types as integer_types
+from six import int2byte, indexbytes
 
 from OpenSSL._util import (
     ffi as _ffi,
@@ -164,11 +165,42 @@
     pass
 
 
+class _CallbackExceptionHelper(object):
+    """
+    A base class for wrapper classes that allow for intelligent exception
+    handling in OpenSSL callbacks.
 
-class _VerifyHelper(object):
-    def __init__(self, callback):
+    :ivar list _problems: Any exceptions that occurred while executing in a
+        context where they could not be raised in the normal way.  Typically
+        this is because OpenSSL has called into some Python code and requires a
+        return value.  The exceptions are saved to be raised later when it is
+        possible to do so.
+    """
+    def __init__(self):
         self._problems = []
 
+
+    def raise_if_problem(self):
+        """
+        Raise an exception from the OpenSSL error queue or that was previously
+        captured while running a callback.
+        """
+        if self._problems:
+            try:
+                _raise_current_error()
+            except Error:
+                pass
+            raise self._problems.pop(0)
+
+
+class _VerifyHelper(_CallbackExceptionHelper):
+    """
+    Wrap a callback such that it can be used as a certificate verification
+    callback.
+    """
+    def __init__(self, callback):
+        _CallbackExceptionHelper.__init__(self)
+
         @wraps(callback)
         def wrapper(ok, store_ctx):
             cert = X509.__new__(X509)
@@ -196,14 +228,92 @@
             "int (*)(int, X509_STORE_CTX *)", wrapper)
 
 
-    def raise_if_problem(self):
-        if self._problems:
-            try:
-                _raise_current_error()
-            except Error:
-                pass
-            raise self._problems.pop(0)
+class _NpnAdvertiseHelper(_CallbackExceptionHelper):
+    """
+    Wrap a callback such that it can be used as an NPN advertisement callback.
+    """
+    def __init__(self, callback):
+        _CallbackExceptionHelper.__init__(self)
 
+        @wraps(callback)
+        def wrapper(ssl, out, outlen, arg):
+            try:
+                conn = Connection._reverse_mapping[ssl]
+                protos = callback(conn)
+
+                # Join the protocols into a Python bytestring, length-prefixing
+                # each element.
+                protostr = b''.join(
+                    chain.from_iterable((int2byte(len(p)), p) for p in protos)
+                )
+
+                # Save our callback arguments on the connection object. This is
+                # done to make sure that they don't get freed before OpenSSL
+                # uses them. Then, return them appropriately in the output
+                # parameters.
+                conn._npn_advertise_callback_args = [
+                    _ffi.new("unsigned int *", len(protostr)),
+                    _ffi.new("unsigned char[]", protostr),
+                ]
+                outlen[0] = conn._npn_advertise_callback_args[0][0]
+                out[0] = conn._npn_advertise_callback_args[1]
+                return 0
+            except Exception as e:
+                self._problems.append(e)
+                return 2  # SSL_TLSEXT_ERR_ALERT_FATAL
+
+        self.callback = _ffi.callback(
+            "int (*)(SSL *, const unsigned char **, unsigned int *, void *)",
+            wrapper
+        )
+
+
+class _NpnSelectHelper(_CallbackExceptionHelper):
+    """
+    Wrap a callback such that it can be used as an NPN selection callback.
+    """
+    def __init__(self, callback):
+        _CallbackExceptionHelper.__init__(self)
+
+        @wraps(callback)
+        def wrapper(ssl, out, outlen, in_, inlen, arg):
+            try:
+                conn = Connection._reverse_mapping[ssl]
+
+                # The string passed to us is actually made up of multiple
+                # length-prefixed bytestrings. We need to split that into a
+                # list.
+                instr = _ffi.buffer(in_, inlen)[:]
+                protolist = []
+                while instr:
+                    l = indexbytes(instr, 0)
+                    proto = instr[1:l+1]
+                    protolist.append(proto)
+                    instr = instr[l+1:]
+
+                # Call the callback
+                outstr = callback(conn, protolist)
+
+                # Save our callback arguments on the connection object. This is
+                # done to make sure that they don't get freed before OpenSSL
+                # uses them. Then, return them appropriately in the output
+                # parameters.
+                conn._npn_select_callback_args = [
+                    _ffi.new("unsigned char *", len(outstr)),
+                    _ffi.new("unsigned char[]", outstr),
+                ]
+                outlen[0] = conn._npn_select_callback_args[0][0]
+                out[0] = conn._npn_select_callback_args[1]
+                return 0
+            except Exception as e:
+                self._problems.append(e)
+                return 2  # SSL_TLSEXT_ERR_ALERT_FATAL
+
+        self.callback = _ffi.callback(
+            "int (*)(SSL *, unsigned char **, unsigned char *, "
+                    "const unsigned char *, unsigned int, void *)",
+            wrapper
+        )
 
 
 def _asFileDescriptor(obj):
@@ -293,6 +403,10 @@
         self._info_callback = None
         self._tlsext_servername_callback = None
         self._app_data = None
+        self._npn_advertise_helper = None
+        self._npn_advertise_callback = None
+        self._npn_select_helper = None
+        self._npn_select_callback = None
 
         # SSL_CTX_set_app_data(self->ctx, self);
         # SSL_CTX_set_mode(self->ctx, SSL_MODE_ENABLE_PARTIAL_WRITE |
@@ -812,6 +926,39 @@
         _lib.SSL_CTX_set_tlsext_servername_callback(
             self._context, self._tlsext_servername_callback)
 
+
+    def set_npn_advertise_callback(self, callback):
+        """
+        Specify a callback function that will be called when offering `Next
+        Protocol Negotiation
+        <https://technotes.googlecode.com/git/nextprotoneg.html>`_ as a server.
+
+        :param callback: The callback function.  It will be invoked with one
+            argument, the Connection instance.  It should return a list of
+            bytestrings representing the advertised protocols, like
+            ``[b'http/1.1', b'spdy/2']``.
+        """
+        self._npn_advertise_helper = _NpnAdvertiseHelper(callback)
+        self._npn_advertise_callback = self._npn_advertise_helper.callback
+        _lib.SSL_CTX_set_next_protos_advertised_cb(
+            self._context, self._npn_advertise_callback, _ffi.NULL)
+
+
+    def set_npn_select_callback(self, callback):
+        """
+        Specify a callback function that will be called when a server offers
+        Next Protocol Negotiation options.
+
+        :param callback: The callback function.  It will be invoked with two
+            arguments: the Connection, and a list of offered protocols as
+            bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``.  It should return
+            one of those bytestrings, the chosen protocol.
+        """
+        self._npn_select_helper = _NpnSelectHelper(callback)
+        self._npn_select_callback = self._npn_select_helper.callback
+        _lib.SSL_CTX_set_next_proto_select_cb(
+            self._context, self._npn_select_callback, _ffi.NULL)
+
 ContextType = Context
 
 
@@ -836,6 +983,13 @@
         self._ssl = _ffi.gc(ssl, _lib.SSL_free)
         self._context = context
 
+        # References to strings used for Next Protocol Negotiation. OpenSSL's
+        # header files suggest that these might get copied at some point, but
+        # doesn't specify when, so we store them here to make sure they don't
+        # get freed before OpenSSL uses them.
+        self._npn_advertise_callback_args = None
+        self._npn_select_callback_args = None
+
         self._reverse_mapping[self._ssl] = self
 
         if socket is None:
@@ -870,6 +1024,10 @@
     def _raise_ssl_error(self, ssl, result):
         if self._context._verify_helper is not None:
             self._context._verify_helper.raise_if_problem()
+        if self._context._npn_advertise_helper is not None:
+            self._context._npn_advertise_helper.raise_if_problem()
+        if self._context._npn_select_helper is not None:
+            self._context._npn_select_helper.raise_if_problem()
 
         error = _lib.SSL_get_error(ssl, result)
         if error == _lib.SSL_ERROR_WANT_READ:
@@ -1591,6 +1749,16 @@
             version =_ffi.string(_lib.SSL_CIPHER_get_version(cipher))
             return version.decode("utf-8")
 
+    def get_next_proto_negotiated(self):
+        """
+        Get the protocol that was negotiated by NPN.
+        """
+        data = _ffi.new("unsigned char **")
+        data_len = _ffi.new("unsigned int *")
+
+        _lib.SSL_get0_next_proto_negotiated(self._ssl, data, data_len)
+
+        return _ffi.buffer(data[0], data_len[0])[:]
 
 
 ConnectionType = Connection
diff --git a/OpenSSL/test/test_ssl.py b/OpenSSL/test/test_ssl.py
index aa07e1b..188d559 100644
--- a/OpenSSL/test/test_ssl.py
+++ b/OpenSSL/test/test_ssl.py
@@ -1471,6 +1471,165 @@
         self.assertEqual([(server, b("foo1.example.com"))], args)
 
 
+class NextProtoNegotiationTests(TestCase, _LoopbackMixin):
+    """
+    Test for Next Protocol Negotiation in PyOpenSSL.
+    """
+    def test_npn_success(self):
+        """
+        Tests that clients and servers that agree on the negotiated next
+        protocol can correctly establish a connection, and that the agreed
+        protocol is reported by the connections.
+        """
+        advertise_args = []
+        select_args = []
+        def advertise(conn):
+            advertise_args.append((conn,))
+            return [b'http/1.1', b'spdy/2']
+        def select(conn, options):
+            select_args.append((conn, options))
+            return b'spdy/2'
+
+        server_context = Context(TLSv1_METHOD)
+        server_context.set_npn_advertise_callback(advertise)
+
+        client_context = Context(TLSv1_METHOD)
+        client_context.set_npn_select_callback(select)
+
+        # Necessary to actually accept the connection
+        server_context.use_privatekey(
+            load_privatekey(FILETYPE_PEM, server_key_pem))
+        server_context.use_certificate(
+            load_certificate(FILETYPE_PEM, server_cert_pem))
+
+        # Do a little connection to trigger the logic
+        server = Connection(server_context, None)
+        server.set_accept_state()
+
+        client = Connection(client_context, None)
+        client.set_connect_state()
+
+        self._interactInMemory(server, client)
+
+        self.assertEqual([(server,)], advertise_args)
+        self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
+
+        self.assertEqual(server.get_next_proto_negotiated(), b'spdy/2')
+        self.assertEqual(client.get_next_proto_negotiated(), b'spdy/2')
+
+
+    def test_npn_client_fail(self):
+        """
+        Tests that when clients and servers cannot agree on what protocol to
+        use next that the TLS connection does not get established.
+        """
+        advertise_args = []
+        select_args = []
+        def advertise(conn):
+            advertise_args.append((conn,))
+            return [b'http/1.1', b'spdy/2']
+        def select(conn, options):
+            select_args.append((conn, options))
+            return b''
+
+        server_context = Context(TLSv1_METHOD)
+        server_context.set_npn_advertise_callback(advertise)
+
+        client_context = Context(TLSv1_METHOD)
+        client_context.set_npn_select_callback(select)
+
+        # Necessary to actually accept the connection
+        server_context.use_privatekey(
+            load_privatekey(FILETYPE_PEM, server_key_pem))
+        server_context.use_certificate(
+            load_certificate(FILETYPE_PEM, server_cert_pem))
+
+        # Do a little connection to trigger the logic
+        server = Connection(server_context, None)
+        server.set_accept_state()
+
+        client = Connection(client_context, None)
+        client.set_connect_state()
+
+        # If the client doesn't return anything, the connection will fail.
+        self.assertRaises(Error, self._interactInMemory, server, client)
+
+        self.assertEqual([(server,)], advertise_args)
+        self.assertEqual([(client, [b'http/1.1', b'spdy/2'])], select_args)
+
+
+    def test_npn_select_error(self):
+        """
+        Test that we can handle exceptions in the select callback. If select
+        fails it should be fatal to the connection.
+        """
+        advertise_args = []
+        def advertise(conn):
+            advertise_args.append((conn,))
+            return [b'http/1.1', b'spdy/2']
+        def select(conn, options):
+            raise TypeError
+
+        server_context = Context(TLSv1_METHOD)
+        server_context.set_npn_advertise_callback(advertise)
+
+        client_context = Context(TLSv1_METHOD)
+        client_context.set_npn_select_callback(select)
+
+        # Necessary to actually accept the connection
+        server_context.use_privatekey(
+            load_privatekey(FILETYPE_PEM, server_key_pem))
+        server_context.use_certificate(
+            load_certificate(FILETYPE_PEM, server_cert_pem))
+
+        # Do a little connection to trigger the logic
+        server = Connection(server_context, None)
+        server.set_accept_state()
+
+        client = Connection(client_context, None)
+        client.set_connect_state()
+
+        # If the callback throws an exception it should be raised here.
+        self.assertRaises(TypeError, self._interactInMemory, server, client)
+        self.assertEqual([(server,)], advertise_args)
+
+
+    def test_npn_advertise_error(self):
+        """
+        Test that we can handle exceptions in the advertise callback. If
+        advertise fails no NPN is advertised to the client.
+        """
+        select_args = []
+        def advertise(conn):
+            raise TypeError
+        def select(conn, options):
+            select_args.append((conn, options))
+            return b''
+
+        server_context = Context(TLSv1_METHOD)
+        server_context.set_npn_advertise_callback(advertise)
+
+        client_context = Context(TLSv1_METHOD)
+        client_context.set_npn_select_callback(select)
+
+        # Necessary to actually accept the connection
+        server_context.use_privatekey(
+            load_privatekey(FILETYPE_PEM, server_key_pem))
+        server_context.use_certificate(
+            load_certificate(FILETYPE_PEM, server_cert_pem))
+
+        # Do a little connection to trigger the logic
+        server = Connection(server_context, None)
+        server.set_accept_state()
+
+        client = Connection(client_context, None)
+        client.set_connect_state()
+
+        # If the callback throws an exception it should be raised here.
+        self.assertRaises(TypeError, self._interactInMemory, server, client)
+        self.assertEqual([], select_args)
+
+
 
 class SessionTests(TestCase):
     """
diff --git a/doc/api/ssl.rst b/doc/api/ssl.rst
index a3265b0..e6a0775 100644
--- a/doc/api/ssl.rst
+++ b/doc/api/ssl.rst
@@ -472,6 +472,33 @@
     .. versionadded:: 0.13
 
 
+.. py:method:: Context.set_npn_advertise_callback(callback)
+
+    Specify a callback function that will be called when offering `Next
+    Protocol Negotiation
+    <https://technotes.googlecode.com/git/nextprotoneg.html>`_ as a server.
+
+    *callback* should be the callback function.  It will be invoked with one
+    argument, the :py:class:`Connection` instance.  It should return a list of
+    bytestrings representing the advertised protocols, like
+    ``[b'http/1.1', b'spdy/2']``.
+
+    .. versionadded:: 0.15
+
+
+.. py:method:: Context.set_npn_select_callback(callback)
+
+    Specify a callback function that will be called when a server offers Next
+    Protocol Negotiation options.
+
+    *callback* should be the callback function.  It will be invoked with two
+    arguments: the :py:class:`Connection`, and a list of offered protocols as
+    bytestrings, e.g. ``[b'http/1.1', b'spdy/2']``.  It should return one of
+    those bytestrings, the chosen protocol.
+
+    .. versionadded:: 0.15
+
+
 .. _openssl-session:
 
 Session objects
@@ -814,6 +841,15 @@
     .. versionadded:: 0.15
 
 
+.. py:method:: Connection.get_next_proto_negotiated()
+
+    Get the protocol that was negotiated by Next Protocol Negotiation. Returns
+    a bytestring of the protocol name. If no protocol has been negotiated yet,
+    returns an empty string.
+
+    .. versionadded:: 0.15
+
+
 .. Rubric:: Footnotes
 
 .. [#connection-context-socket] Actually, all that is required is an object that