[3.7] bpo-34670: Add TLS 1.3 post handshake auth (GH-9460) (GH-9505)



Add SSLContext.post_handshake_auth and
SSLSocket.verify_client_post_handshake for TLS 1.3 post-handshake
authentication.

Signed-off-by: Christian Heimes <christian@python.org>

https://bugs.python.org/issue34670
(cherry picked from commit 9fb051f032c36b9f6086b79086b4d6b7755a3d70)

Co-authored-by: Christian Heimes <christian@python.org>



https://bugs.python.org/issue34670
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 4d5925a..d4132c5 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -234,7 +234,7 @@
 
     server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     server_context.load_cert_chain(server_cert)
-    client_context.load_verify_locations(SIGNING_CA)
+    server_context.load_verify_locations(SIGNING_CA)
 
     return client_context, server_context, hostname
 
@@ -2282,6 +2282,23 @@
                             sys.stdout.write(" server: read CB tls-unique from client, sending our CB data...\n")
                         data = self.sslconn.get_channel_binding("tls-unique")
                         self.write(repr(data).encode("us-ascii") + b"\n")
+                    elif stripped == b'PHA':
+                        if support.verbose and self.server.connectionchatty:
+                            sys.stdout.write(" server: initiating post handshake auth\n")
+                        try:
+                            self.sslconn.verify_client_post_handshake()
+                        except ssl.SSLError as e:
+                            self.write(repr(e).encode("us-ascii") + b"\n")
+                        else:
+                            self.write(b"OK\n")
+                    elif stripped == b'HASCERT':
+                        if self.sslconn.getpeercert() is not None:
+                            self.write(b'TRUE\n')
+                        else:
+                            self.write(b'FALSE\n')
+                    elif stripped == b'GETCERT':
+                        cert = self.sslconn.getpeercert()
+                        self.write(repr(cert).encode("us-ascii") + b"\n")
                     else:
                         if (support.verbose and
                             self.server.connectionchatty):
@@ -4175,6 +4192,179 @@
                                  'Session refers to a different SSLContext.')
 
 
+@unittest.skipUnless(ssl.HAS_TLSv1_3, "Test needs TLS 1.3")
+class TestPostHandshakeAuth(unittest.TestCase):
+    def test_pha_setter(self):
+        protocols = [
+            ssl.PROTOCOL_TLS, ssl.PROTOCOL_TLS_SERVER, ssl.PROTOCOL_TLS_CLIENT
+        ]
+        for protocol in protocols:
+            ctx = ssl.SSLContext(protocol)
+            self.assertEqual(ctx.post_handshake_auth, False)
+
+            ctx.post_handshake_auth = True
+            self.assertEqual(ctx.post_handshake_auth, True)
+
+            ctx.verify_mode = ssl.CERT_REQUIRED
+            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
+            self.assertEqual(ctx.post_handshake_auth, True)
+
+            ctx.post_handshake_auth = False
+            self.assertEqual(ctx.verify_mode, ssl.CERT_REQUIRED)
+            self.assertEqual(ctx.post_handshake_auth, False)
+
+            ctx.verify_mode = ssl.CERT_OPTIONAL
+            ctx.post_handshake_auth = True
+            self.assertEqual(ctx.verify_mode, ssl.CERT_OPTIONAL)
+            self.assertEqual(ctx.post_handshake_auth, True)
+
+    def test_pha_required(self):
+        client_context, server_context, hostname = testing_context()
+        server_context.post_handshake_auth = True
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        client_context.post_handshake_auth = True
+        client_context.load_cert_chain(SIGNED_CERTFILE)
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'FALSE\n')
+                s.write(b'PHA')
+                self.assertEqual(s.recv(1024), b'OK\n')
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'TRUE\n')
+                # PHA method just returns true when cert is already available
+                s.write(b'PHA')
+                self.assertEqual(s.recv(1024), b'OK\n')
+                s.write(b'GETCERT')
+                cert_text = s.recv(4096).decode('us-ascii')
+                self.assertIn('Python Software Foundation CA', cert_text)
+
+    def test_pha_required_nocert(self):
+        client_context, server_context, hostname = testing_context()
+        server_context.post_handshake_auth = True
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        client_context.post_handshake_auth = True
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                s.write(b'PHA')
+                # receive CertificateRequest
+                self.assertEqual(s.recv(1024), b'OK\n')
+                # send empty Certificate + Finished
+                s.write(b'HASCERT')
+                # receive alert
+                with self.assertRaisesRegex(
+                        ssl.SSLError,
+                        'tlsv13 alert certificate required'):
+                    s.recv(1024)
+
+    def test_pha_optional(self):
+        if support.verbose:
+            sys.stdout.write("\n")
+
+        client_context, server_context, hostname = testing_context()
+        server_context.post_handshake_auth = True
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        client_context.post_handshake_auth = True
+        client_context.load_cert_chain(SIGNED_CERTFILE)
+
+        # check CERT_OPTIONAL
+        server_context.verify_mode = ssl.CERT_OPTIONAL
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'FALSE\n')
+                s.write(b'PHA')
+                self.assertEqual(s.recv(1024), b'OK\n')
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'TRUE\n')
+
+    def test_pha_optional_nocert(self):
+        if support.verbose:
+            sys.stdout.write("\n")
+
+        client_context, server_context, hostname = testing_context()
+        server_context.post_handshake_auth = True
+        server_context.verify_mode = ssl.CERT_OPTIONAL
+        client_context.post_handshake_auth = True
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'FALSE\n')
+                s.write(b'PHA')
+                self.assertEqual(s.recv(1024), b'OK\n')
+                # optional doesn't fail when client does not have a cert
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'FALSE\n')
+
+    def test_pha_no_pha_client(self):
+        client_context, server_context, hostname = testing_context()
+        server_context.post_handshake_auth = True
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        client_context.load_cert_chain(SIGNED_CERTFILE)
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                with self.assertRaisesRegex(ssl.SSLError, 'not server'):
+                    s.verify_client_post_handshake()
+                s.write(b'PHA')
+                self.assertIn(b'extension not received', s.recv(1024))
+
+    def test_pha_no_pha_server(self):
+        # server doesn't have PHA enabled, cert is requested in handshake
+        client_context, server_context, hostname = testing_context()
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        client_context.post_handshake_auth = True
+        client_context.load_cert_chain(SIGNED_CERTFILE)
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'TRUE\n')
+                # PHA doesn't fail if there is already a cert
+                s.write(b'PHA')
+                self.assertEqual(s.recv(1024), b'OK\n')
+                s.write(b'HASCERT')
+                self.assertEqual(s.recv(1024), b'TRUE\n')
+
+    def test_pha_not_tls13(self):
+        # TLS 1.2
+        client_context, server_context, hostname = testing_context()
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        client_context.maximum_version = ssl.TLSVersion.TLSv1_2
+        client_context.post_handshake_auth = True
+        client_context.load_cert_chain(SIGNED_CERTFILE)
+
+        server = ThreadedEchoServer(context=server_context, chatty=False)
+        with server:
+            with client_context.wrap_socket(socket.socket(),
+                                            server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                # PHA fails for TLS != 1.3
+                s.write(b'PHA')
+                self.assertIn(b'WRONG_SSL_VERSION', s.recv(1024))
+
+
 def test_main(verbose=False):
     if support.verbose:
         import warnings
@@ -4218,6 +4408,7 @@
     tests = [
         ContextTests, BasicSocketTests, SSLErrorTests, MemoryBIOTests,
         SSLObjectTests, SimpleBackgroundTests, ThreadedTests,
+        TestPostHandshakeAuth
     ]
 
     if support.is_resource_enabled('network'):