bpo-32947: OpenSSL 1.1.1-pre1 / TLS 1.3 fixes (GH-5663)


* bpo-32947: OpenSSL 1.1.1-pre1 / TLS 1.3 fixes

Misc fixes and workarounds for compatibility with OpenSSL 1.1.1-pre1 and
TLS 1.3 support. With OpenSSL 1.1.1, Python negotiates TLS 1.3 by
default. Some test cases only apply to TLS 1.2. Other tests currently
fail because the threaded or async test servers stop after failure.

I'm going to address these issues when OpenSSL 1.1.1 reaches beta.

OpenSSL 1.1.1 has added a new option OP_ENABLE_MIDDLEBOX_COMPAT for TLS
1.3. The feature is enabled by default for maximum compatibility with
broken middle boxes. Users should be able to disable the hack and CPython's test suite needs
it to verify default options.

Signed-off-by: Christian Heimes <christian@python.org>
(cherry picked from commit 05d9fe32a1245b9a798e49e0c1eb91f110935b69)

Co-authored-by: Christian Heimes <christian@python.org>
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 3b34fc0..d978a9e 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -30,7 +30,8 @@
 PROTOCOLS = sorted(ssl._PROTOCOL_NAMES)
 HOST = support.HOST
 IS_LIBRESSL = ssl.OPENSSL_VERSION.startswith('LibreSSL')
-IS_OPENSSL_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
+IS_OPENSSL_1_1_0 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 0)
+IS_OPENSSL_1_1_1 = not IS_LIBRESSL and ssl.OPENSSL_VERSION_INFO >= (1, 1, 1)
 PY_SSL_DEFAULT_CIPHERS = sysconfig.get_config_var('PY_SSL_DEFAULT_CIPHERS')
 
 def data_file(*name):
@@ -54,6 +55,7 @@
 BYTES_CAPATH = os.fsencode(CAPATH)
 CAFILE_NEURONIO = data_file("capath", "4e1295a3.0")
 CAFILE_CACERT = data_file("capath", "5ed36f99.0")
+WRONG_CERT = data_file("wrongcert.pem")
 
 CERTFILE_INFO = {
     'issuer': ((('countryName', 'XY'),),
@@ -124,6 +126,7 @@
 OP_SINGLE_DH_USE = getattr(ssl, "OP_SINGLE_DH_USE", 0)
 OP_SINGLE_ECDH_USE = getattr(ssl, "OP_SINGLE_ECDH_USE", 0)
 OP_CIPHER_SERVER_PREFERENCE = getattr(ssl, "OP_CIPHER_SERVER_PREFERENCE", 0)
+OP_ENABLE_MIDDLEBOX_COMPAT = getattr(ssl, "OP_ENABLE_MIDDLEBOX_COMPAT", 0)
 
 
 def handle_error(prefix):
@@ -232,6 +235,7 @@
 
     server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
     server_context.load_cert_chain(server_cert)
+    client_context.load_verify_locations(SIGNING_CA)
 
     return client_context, server_context, hostname
 
@@ -1016,7 +1020,8 @@
         default = (ssl.OP_ALL | ssl.OP_NO_SSLv2 | ssl.OP_NO_SSLv3)
         # SSLContext also enables these by default
         default |= (OP_NO_COMPRESSION | OP_CIPHER_SERVER_PREFERENCE |
-                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE)
+                    OP_SINGLE_DH_USE | OP_SINGLE_ECDH_USE |
+                    OP_ENABLE_MIDDLEBOX_COMPAT)
         self.assertEqual(default, ctx.options)
         ctx.options |= ssl.OP_NO_TLSv1
         self.assertEqual(default | ssl.OP_NO_TLSv1, ctx.options)
@@ -1778,6 +1783,8 @@
         der = ssl.PEM_cert_to_DER_cert(pem)
         ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
         ctx.verify_mode = ssl.CERT_REQUIRED
+        # TODO: fix TLSv1.3 support
+        ctx.options |= ssl.OP_NO_TLSv1_3
         ctx.load_verify_locations(cadata=pem)
         with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
             s.connect(self.server_addr)
@@ -1787,6 +1794,8 @@
         # same with DER
         ctx = ssl.SSLContext(ssl.PROTOCOL_TLS)
         ctx.verify_mode = ssl.CERT_REQUIRED
+        # TODO: fix TLSv1.3 support
+        ctx.options |= ssl.OP_NO_TLSv1_3
         ctx.load_verify_locations(cadata=der)
         with ctx.wrap_socket(socket.socket(socket.AF_INET)) as s:
             s.connect(self.server_addr)
@@ -2629,7 +2638,10 @@
     def test_ecc_cert(self):
         client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         client_context.load_verify_locations(SIGNING_CA)
-        client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
+        client_context.set_ciphers(
+            'TLS13-AES-128-GCM-SHA256:TLS13-CHACHA20-POLY1305-SHA256:'
+            'ECDHE:ECDSA:!NULL:!aRSA'
+        )
         hostname = SIGNED_CERTFILE_ECC_HOSTNAME
 
         server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
@@ -2650,6 +2662,9 @@
     def test_dual_rsa_ecc(self):
         client_context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         client_context.load_verify_locations(SIGNING_CA)
+        # TODO: fix TLSv1.3 once SSLContext can restrict signature
+        #       algorithms.
+        client_context.options |= ssl.OP_NO_TLSv1_3
         # only ECDSA certs
         client_context.set_ciphers('ECDHE:ECDSA:!NULL:!aRSA')
         hostname = SIGNED_CERTFILE_ECC_HOSTNAME
@@ -2676,6 +2691,8 @@
 
         server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
         server_context.load_cert_chain(IDNSANSFILE)
+        # TODO: fix TLSv1.3 support
+        server_context.options |= ssl.OP_NO_TLSv1_3
 
         context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         context.verify_mode = ssl.CERT_REQUIRED
@@ -2738,15 +2755,22 @@
         Launch a server with CERT_REQUIRED, and check that trying to
         connect to it with a wrong client certificate fails.
         """
-        certfile = os.path.join(os.path.dirname(__file__) or os.curdir,
-                                   "wrongcert.pem")
-        server = ThreadedEchoServer(CERTFILE,
-                                    certreqs=ssl.CERT_REQUIRED,
-                                    cacerts=CERTFILE, chatty=False,
-                                    connectionchatty=False)
+        client_context, server_context, hostname = testing_context()
+        # load client cert
+        client_context.load_cert_chain(WRONG_CERT)
+        # require TLS client authentication
+        server_context.verify_mode = ssl.CERT_REQUIRED
+        # TODO: fix TLSv1.3 support
+        # With TLS 1.3, test fails with exception in server thread
+        server_context.options |= ssl.OP_NO_TLSv1_3
+
+        server = ThreadedEchoServer(
+            context=server_context, chatty=True, connectionchatty=True,
+        )
+
         with server, \
-                socket.socket() as sock, \
-                test_wrap_socket(sock, certfile=certfile) as s:
+                client_context.wrap_socket(socket.socket(),
+                                           server_hostname=hostname) as s:
             try:
                 # Expect either an SSL error about the server rejecting
                 # the connection, or a low-level connection reset (which
@@ -3400,7 +3424,9 @@
                 self.assertIs(s.version(), None)
                 self.assertIs(s._sslobj, None)
                 s.connect((HOST, server.port))
-                if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
+                if ssl.OPENSSL_VERSION_INFO >= (1, 1, 1):
+                    self.assertEqual(s.version(), 'TLSv1.3')
+                elif ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
                     self.assertEqual(s.version(), 'TLSv1.2')
                 else:  # 0.9.8 to 1.0.1
                     self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
@@ -3412,18 +3438,18 @@
     def test_tls1_3(self):
         context = ssl.SSLContext(ssl.PROTOCOL_TLS)
         context.load_cert_chain(CERTFILE)
-        # disable all but TLS 1.3
         context.options |= (
             ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1 | ssl.OP_NO_TLSv1_2
         )
         with ThreadedEchoServer(context=context) as server:
             with context.wrap_socket(socket.socket()) as s:
                 s.connect((HOST, server.port))
-                self.assertIn(s.cipher()[0], [
+                self.assertIn(s.cipher()[0], {
                     'TLS13-AES-256-GCM-SHA384',
                     'TLS13-CHACHA20-POLY1305-SHA256',
                     'TLS13-AES-128-GCM-SHA256',
-                ])
+                })
+                self.assertEqual(s.version(), 'TLSv1.3')
 
     @unittest.skipUnless(ssl.HAS_ECDH, "test requires ECDH-enabled OpenSSL")
     def test_default_ecdh_curve(self):
@@ -3452,58 +3478,54 @@
         if support.verbose:
             sys.stdout.write("\n")
 
-        server = ThreadedEchoServer(CERTFILE,
-                                    certreqs=ssl.CERT_NONE,
-                                    ssl_version=ssl.PROTOCOL_TLS_SERVER,
-                                    cacerts=CERTFILE,
+        client_context, server_context, hostname = testing_context()
+        # TODO: fix TLSv1.3 support
+        client_context.options |= ssl.OP_NO_TLSv1_3
+
+        server = ThreadedEchoServer(context=server_context,
                                     chatty=True,
                                     connectionchatty=False)
+
         with server:
-            s = test_wrap_socket(socket.socket(),
-                                server_side=False,
-                                certfile=CERTFILE,
-                                ca_certs=CERTFILE,
-                                cert_reqs=ssl.CERT_NONE,
-                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
-            s.connect((HOST, server.port))
-            # get the data
-            cb_data = s.get_channel_binding("tls-unique")
-            if support.verbose:
-                sys.stdout.write(" got channel binding data: {0!r}\n"
-                                 .format(cb_data))
+            with client_context.wrap_socket(
+                    socket.socket(),
+                    server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                # get the data
+                cb_data = s.get_channel_binding("tls-unique")
+                if support.verbose:
+                    sys.stdout.write(
+                        " got channel binding data: {0!r}\n".format(cb_data))
 
-            # check if it is sane
-            self.assertIsNotNone(cb_data)
-            self.assertEqual(len(cb_data), 12) # True for TLSv1
+                # check if it is sane
+                self.assertIsNotNone(cb_data)
+                self.assertEqual(len(cb_data), 12) # True for TLSv1
 
-            # and compare with the peers version
-            s.write(b"CB tls-unique\n")
-            peer_data_repr = s.read().strip()
-            self.assertEqual(peer_data_repr,
-                             repr(cb_data).encode("us-ascii"))
-            s.close()
+                # and compare with the peers version
+                s.write(b"CB tls-unique\n")
+                peer_data_repr = s.read().strip()
+                self.assertEqual(peer_data_repr,
+                                 repr(cb_data).encode("us-ascii"))
 
             # now, again
-            s = test_wrap_socket(socket.socket(),
-                                server_side=False,
-                                certfile=CERTFILE,
-                                ca_certs=CERTFILE,
-                                cert_reqs=ssl.CERT_NONE,
-                                ssl_version=ssl.PROTOCOL_TLS_CLIENT)
-            s.connect((HOST, server.port))
-            new_cb_data = s.get_channel_binding("tls-unique")
-            if support.verbose:
-                sys.stdout.write(" got another channel binding data: {0!r}\n"
-                                 .format(new_cb_data))
-            # is it really unique
-            self.assertNotEqual(cb_data, new_cb_data)
-            self.assertIsNotNone(cb_data)
-            self.assertEqual(len(cb_data), 12) # True for TLSv1
-            s.write(b"CB tls-unique\n")
-            peer_data_repr = s.read().strip()
-            self.assertEqual(peer_data_repr,
-                             repr(new_cb_data).encode("us-ascii"))
-            s.close()
+            with client_context.wrap_socket(
+                    socket.socket(),
+                    server_hostname=hostname) as s:
+                s.connect((HOST, server.port))
+                new_cb_data = s.get_channel_binding("tls-unique")
+                if support.verbose:
+                    sys.stdout.write(
+                        "got another channel binding data: {0!r}\n".format(
+                            new_cb_data)
+                    )
+                # is it really unique
+                self.assertNotEqual(cb_data, new_cb_data)
+                self.assertIsNotNone(cb_data)
+                self.assertEqual(len(cb_data), 12) # True for TLSv1
+                s.write(b"CB tls-unique\n")
+                peer_data_repr = s.read().strip()
+                self.assertEqual(peer_data_repr,
+                                 repr(new_cb_data).encode("us-ascii"))
 
     def test_compression(self):
         client_context, server_context, hostname = testing_context()
@@ -3528,8 +3550,11 @@
     def test_dh_params(self):
         # Check we can get a connection with ephemeral Diffie-Hellman
         client_context, server_context, hostname = testing_context()
+        # test scenario needs TLS <= 1.2
+        client_context.options |= ssl.OP_NO_TLSv1_3
         server_context.load_dh_params(DHFILE)
         server_context.set_ciphers("kEDH")
+        server_context.options |= ssl.OP_NO_TLSv1_3
         stats = server_params_test(client_context, server_context,
                                    chatty=True, connectionchatty=True,
                                    sni_name=hostname)
@@ -3539,9 +3564,11 @@
             self.fail("Non-DH cipher: " + cipher[0])
 
     @unittest.skipUnless(HAVE_SECP_CURVES, "needs secp384r1 curve support")
+    @unittest.skipIf(IS_OPENSSL_1_1_1, "TODO: Test doesn't work on 1.1.1")
     def test_ecdh_curve(self):
         # server secp384r1, client auto
         client_context, server_context, hostname = testing_context()
+
         server_context.set_ecdh_curve("secp384r1")
         server_context.set_ciphers("ECDHE:!eNULL:!aNULL")
         server_context.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
@@ -3572,7 +3599,7 @@
             pass
         else:
             # OpenSSL 1.0.2 does not fail although it should.
-            if IS_OPENSSL_1_1:
+            if IS_OPENSSL_1_1_0:
                 self.fail("mismatch curve did not fail")
 
     def test_selected_alpn_protocol(self):
@@ -3616,7 +3643,7 @@
             except ssl.SSLError as e:
                 stats = e
 
-            if (expected is None and IS_OPENSSL_1_1
+            if (expected is None and IS_OPENSSL_1_1_0
                     and ssl.OPENSSL_VERSION_INFO < (1, 1, 0, 6)):
                 # OpenSSL 1.1.0 to 1.1.0e raises handshake error
                 self.assertIsInstance(stats, ssl.SSLError)
@@ -3823,6 +3850,8 @@
 
     def test_session(self):
         client_context, server_context, hostname = testing_context()
+        # TODO: sessions aren't compatible with TLSv1.3 yet
+        client_context.options |= ssl.OP_NO_TLSv1_3
 
         # first connection without session
         stats = server_params_test(client_context, server_context,
@@ -3881,7 +3910,7 @@
         client_context, server_context, hostname = testing_context()
         client_context2, _, _ = testing_context()
 
-        # TODO: session reuse does not work with TLS 1.3
+        # TODO: session reuse does not work with TLSv1.3
         client_context.options |= ssl.OP_NO_TLSv1_3
         client_context2.options |= ssl.OP_NO_TLSv1_3