Allow accessing a connection's verfied certificate chain (#894)

* Allow accessing a connection's verfied certificate chain

Add X509StoreContext.get_verified_chain using X509_STORE_CTX_get1_chain.
Add Connection.get_verified_chain using SSL_get0_verified_chain if
available (ie OpenSSL 1.1+) and X509StoreContext.get_verified_chain
otherwise.
Fixes #740.

* TLSv1_METHOD -> SSLv23_METHOD

* Use X509_up_ref instead of X509_dup

* Add _openssl_assert where appropriate

* SSL_get_peer_cert_chain should not be null

* Reformat with black

* Fix <OpenSSL.crypto.X509 object at 0x7fdbb59e8050> != <OpenSSL.crypto.X509 object at 0x7fdbb59daad0>

* Add Changelog entry

* Remove _add_chain
diff --git a/CHANGELOG.rst b/CHANGELOG.rst
index d2c92e3..9ceedd0 100644
--- a/CHANGELOG.rst
+++ b/CHANGELOG.rst
@@ -21,14 +21,14 @@
 
 - Deprecated ``OpenSSL.crypto.loads_pkcs7`` and ``OpenSSL.crypto.loads_pkcs12``.
 
-*none*
-
-
 Changes:
 ^^^^^^^^
 
 - Added ``Context.set_keylog_callback`` to log key material.
   `#910 <https://github.com/pyca/pyopenssl/pull/910>`_
+- Added ``OpenSSL.SSL.Connection.get_verified_chain`` to retrieve the
+  verified certificate chain of the peer.
+  `#894 <https://github.com/pyca/pyopenssl/pull/894>`_.
 
 
 19.1.0 (2019-11-18)
diff --git a/src/OpenSSL/SSL.py b/src/OpenSSL/SSL.py
index 8a54994..29e489a 100644
--- a/src/OpenSSL/SSL.py
+++ b/src/OpenSSL/SSL.py
@@ -28,6 +28,7 @@
     X509Name,
     X509,
     X509Store,
+    X509StoreContext,
 )
 
 __all__ = [
@@ -2126,6 +2127,22 @@
             return X509._from_raw_x509_ptr(cert)
         return None
 
+    @staticmethod
+    def _cert_stack_to_list(cert_stack):
+        """
+        Internal helper to convert a STACK_OF(X509) to a list of X509
+        instances.
+        """
+        result = []
+        for i in range(_lib.sk_X509_num(cert_stack)):
+            cert = _lib.sk_X509_value(cert_stack, i)
+            _openssl_assert(cert != _ffi.NULL)
+            res = _lib.X509_up_ref(cert)
+            _openssl_assert(res >= 1)
+            pycert = X509._from_raw_x509_ptr(cert)
+            result.append(pycert)
+        return result
+
     def get_peer_cert_chain(self):
         """
         Retrieve the other side's certificate (if any)
@@ -2137,13 +2154,43 @@
         if cert_stack == _ffi.NULL:
             return None
 
-        result = []
-        for i in range(_lib.sk_X509_num(cert_stack)):
-            # TODO could incref instead of dup here
-            cert = _lib.X509_dup(_lib.sk_X509_value(cert_stack, i))
-            pycert = X509._from_raw_x509_ptr(cert)
-            result.append(pycert)
-        return result
+        return self._cert_stack_to_list(cert_stack)
+
+    def get_verified_chain(self):
+        """
+        Retrieve the verified certificate chain of the peer including the
+        peer's end entity certificate. It must be called after a session has
+        been successfully established. If peer verification was not successful
+        the chain may be incomplete, invalid, or None.
+
+        :return: A list of X509 instances giving the peer's verified
+                 certificate chain, or None if it does not have one.
+
+        .. versionadded:: 20.0
+        """
+        if hasattr(_lib, "SSL_get0_verified_chain"):
+            # OpenSSL 1.1+
+            cert_stack = _lib.SSL_get0_verified_chain(self._ssl)
+            if cert_stack == _ffi.NULL:
+                return None
+
+            return self._cert_stack_to_list(cert_stack)
+
+        pycert = self.get_peer_certificate()
+        if pycert is None:
+            return None
+
+        # Should never be NULL because the peer presented a certificate.
+        cert_stack = _lib.SSL_get_peer_cert_chain(self._ssl)
+        _openssl_assert(cert_stack != _ffi.NULL)
+
+        pystore = self._context.get_cert_store()
+        if pystore is None:
+            return None
+
+        pystorectx = X509StoreContext(pystore, pycert)
+        pystorectx._chain = cert_stack
+        return pystorectx.get_verified_chain()
 
     def want_read(self):
         """
diff --git a/src/OpenSSL/crypto.py b/src/OpenSSL/crypto.py
index 1b1e93e..79307b8 100644
--- a/src/OpenSSL/crypto.py
+++ b/src/OpenSSL/crypto.py
@@ -1712,6 +1712,7 @@
         self._store_ctx = _ffi.gc(store_ctx, _lib.X509_STORE_CTX_free)
         self._store = store
         self._cert = certificate
+        self._chain = _ffi.NULL
         # Make the store context available for use after instantiating this
         # class by initializing it now. Per testing, subsequent calls to
         # :meth:`_init` have no adverse affect.
@@ -1725,7 +1726,7 @@
         :meth:`_cleanup` will leak memory.
         """
         ret = _lib.X509_STORE_CTX_init(
-            self._store_ctx, self._store._store, self._cert._x509, _ffi.NULL
+            self._store_ctx, self._store._store, self._cert._x509, self._chain
         )
         if ret <= 0:
             _raise_current_error()
@@ -1797,6 +1798,45 @@
         if ret <= 0:
             raise self._exception_from_context()
 
+    def get_verified_chain(self):
+        """
+        Verify a certificate in a context and return the complete validated
+        chain.
+
+        :raises X509StoreContextError: If an error occurred when validating a
+          certificate in the context. Sets ``certificate`` attribute to
+          indicate which certificate caused the error.
+
+        .. versionadded:: 20.0
+        """
+        # Always re-initialize the store context in case
+        # :meth:`verify_certificate` is called multiple times.
+        #
+        # :meth:`_init` is called in :meth:`__init__` so _cleanup is called
+        # before _init to ensure memory is not leaked.
+        self._cleanup()
+        self._init()
+        ret = _lib.X509_verify_cert(self._store_ctx)
+        if ret <= 0:
+            self._cleanup()
+            raise self._exception_from_context()
+
+        # Note: X509_STORE_CTX_get1_chain returns a deep copy of the chain.
+        cert_stack = _lib.X509_STORE_CTX_get1_chain(self._store_ctx)
+        _openssl_assert(cert_stack != _ffi.NULL)
+
+        result = []
+        for i in range(_lib.sk_X509_num(cert_stack)):
+            cert = _lib.sk_X509_value(cert_stack, i)
+            _openssl_assert(cert != _ffi.NULL)
+            pycert = X509._from_raw_x509_ptr(cert)
+            result.append(pycert)
+
+        # Free the stack but not the members which are freed by the X509 class.
+        _lib.sk_X509_free(cert_stack)
+        self._cleanup()
+        return result
+
 
 def load_certificate(type, buffer):
     """
diff --git a/tests/test_crypto.py b/tests/test_crypto.py
index 3802d9a..ac4e729 100644
--- a/tests/test_crypto.py
+++ b/tests/test_crypto.py
@@ -3849,6 +3849,41 @@
 
         assert exc.value.args[0][2] == "certificate has expired"
 
+    def test_get_verified_chain(self):
+        """
+        `get_verified_chain` returns the verified chain.
+        """
+        store = X509Store()
+        store.add_cert(self.root_cert)
+        store.add_cert(self.intermediate_cert)
+        store_ctx = X509StoreContext(store, self.intermediate_server_cert)
+        chain = store_ctx.get_verified_chain()
+        assert len(chain) == 3
+        intermediate_subject = self.intermediate_server_cert.get_subject()
+        assert chain[0].get_subject() == intermediate_subject
+        assert chain[1].get_subject() == self.intermediate_cert.get_subject()
+        assert chain[2].get_subject() == self.root_cert.get_subject()
+        # Test reuse
+        chain = store_ctx.get_verified_chain()
+        assert len(chain) == 3
+        assert chain[0].get_subject() == intermediate_subject
+        assert chain[1].get_subject() == self.intermediate_cert.get_subject()
+        assert chain[2].get_subject() == self.root_cert.get_subject()
+
+    def test_get_verified_chain_invalid_chain_no_root(self):
+        """
+        `get_verified_chain` raises error when cert verification fails.
+        """
+        store = X509Store()
+        store.add_cert(self.intermediate_cert)
+        store_ctx = X509StoreContext(store, self.intermediate_server_cert)
+
+        with pytest.raises(X509StoreContextError) as exc:
+            store_ctx.get_verified_chain()
+
+        assert exc.value.args[0][2] == "unable to get issuer certificate"
+        assert exc.value.certificate.get_subject().CN == "intermediate"
+
 
 class TestSignVerify(object):
     """
diff --git a/tests/test_ssl.py b/tests/test_ssl.py
index 7e28ab7..9f134b4 100644
--- a/tests/test_ssl.py
+++ b/tests/test_ssl.py
@@ -2445,6 +2445,63 @@
         interact_in_memory(client, server)
         assert None is server.get_peer_cert_chain()
 
+    def test_get_verified_chain(self):
+        """
+        `Connection.get_verified_chain` returns a list of certificates
+        which the connected server returned for the certification verification.
+        """
+        chain = _create_certificate_chain()
+        [(cakey, cacert), (ikey, icert), (skey, scert)] = chain
+
+        serverContext = Context(SSLv23_METHOD)
+        serverContext.use_privatekey(skey)
+        serverContext.use_certificate(scert)
+        serverContext.add_extra_chain_cert(icert)
+        serverContext.add_extra_chain_cert(cacert)
+        server = Connection(serverContext, None)
+        server.set_accept_state()
+
+        # Create the client
+        clientContext = Context(SSLv23_METHOD)
+        # cacert is self-signed so the client must trust it for verification
+        # to succeed.
+        clientContext.get_cert_store().add_cert(cacert)
+        clientContext.set_verify(VERIFY_PEER, verify_cb)
+        client = Connection(clientContext, None)
+        client.set_connect_state()
+
+        interact_in_memory(client, server)
+
+        chain = client.get_verified_chain()
+        assert len(chain) == 3
+        assert "Server Certificate" == chain[0].get_subject().CN
+        assert "Intermediate Certificate" == chain[1].get_subject().CN
+        assert "Authority Certificate" == chain[2].get_subject().CN
+
+    def test_get_verified_chain_none(self):
+        """
+        `Connection.get_verified_chain` returns `None` if the peer sends
+        no certificate chain.
+        """
+        ctx = Context(SSLv23_METHOD)
+        ctx.use_privatekey(load_privatekey(FILETYPE_PEM, server_key_pem))
+        ctx.use_certificate(load_certificate(FILETYPE_PEM, server_cert_pem))
+        server = Connection(ctx, None)
+        server.set_accept_state()
+        client = Connection(Context(SSLv23_METHOD), None)
+        client.set_connect_state()
+        interact_in_memory(client, server)
+        assert None is server.get_verified_chain()
+
+    def test_get_verified_chain_unconnected(self):
+        """
+        `Connection.get_verified_chain` returns `None` when used with an object
+        which has not been connected.
+        """
+        ctx = Context(SSLv23_METHOD)
+        server = Connection(ctx, None)
+        assert None is server.get_verified_chain()
+
     def test_get_session_unconnected(self):
         """
         `Connection.get_session` returns `None` when used with an object