[bpo-28414] Make all hostnames in SSL module IDN A-labels (GH-5128) (GH-5843)

Previously, the ssl module stored internationalized domain names (IDNs)
as U-labels. This was problematic for a number of reasons -- for
example, it made it impossible for users to use a different version
of IDNA than the one built into Python.

After this change, we always convert to A-labels as soon as possible,
and use them for all internal processing. In particular, the server_hostname
attribute is now an A-label, and on the server side there's a new
sni_callback that receives the SNI servername as an A-label rather than
a U-label.
(cherry picked from commit 11a1493bc4198f1def5e572049485779cf54dc57)

Co-authored-by: Christian Heimes <christian@python.org>
diff --git a/Lib/ssl.py b/Lib/ssl.py
index b6161d0..f253769 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -355,13 +355,20 @@
         self = _SSLContext.__new__(cls, protocol)
         return self

-    def __init__(self, protocol=PROTOCOL_TLS):
-        self.protocol = protocol
+    def _encode_hostname(self, hostname):
+        if hostname is None:
+            return None
+        elif isinstance(hostname, str):
+            return hostname.encode('idna').decode('ascii')  # str -> A-label
+        else:
+            return hostname.decode('ascii')  # ASCII only; not IDNA-validated

     def wrap_socket(self, sock, server_side=False,
                     do_handshake_on_connect=True,
                     suppress_ragged_eofs=True,
                     server_hostname=None, session=None):
+        # SSLSocket class handles server_hostname encoding before it calls
+        # ctx._wrap_socket()
         return self.sslsocket_class(
             sock=sock,
             server_side=server_side,
@@ -374,8 +381,12 @@

     def wrap_bio(self, incoming, outgoing, server_side=False,
                  server_hostname=None, session=None):
-        sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
-                                server_hostname=server_hostname)
+        # Need to encode server_hostname here because _wrap_bio() can only
+        # handle ASCII str.
+        sslobj = self._wrap_bio(
+            incoming, outgoing, server_side=server_side,
+            server_hostname=self._encode_hostname(server_hostname)
+        )
         return self.sslobject_class(sslobj, session=session)

     def set_npn_protocols(self, npn_protocols):
@@ -389,6 +400,19 @@

         self._set_npn_protocols(protos)

+    def set_servername_callback(self, server_name_callback):
+        if server_name_callback is None:
+            self.sni_callback = None
+        elif not callable(server_name_callback):
+            raise TypeError("not a callable object")
+        else:
+            # sni_callback gets an A-label; the legacy API promises a U-label.
+            def shim_cb(sslobj, servername, sslctx):
+                if servername is not None:
+                    servername = servername.encode('ascii').decode('idna')
+                return server_name_callback(sslobj, servername, sslctx)
+            self.sni_callback = shim_cb
+
     def set_alpn_protocols(self, alpn_protocols):
         protos = bytearray()
         for protocol in alpn_protocols:
@@ -448,6 +472,10 @@
             return True

     @property
+    def protocol(self):
+        return _SSLMethod(super().protocol)  # mirrors verify_flags below
+
+    @property
     def verify_flags(self):
         return VerifyFlags(super().verify_flags)

@@ -749,7 +777,7 @@
             raise ValueError("check_hostname requires server_hostname")
         self._session = _session
         self.server_side = server_side
-        self.server_hostname = server_hostname
+        self.server_hostname = self._context._encode_hostname(server_hostname)
         self.do_handshake_on_connect = do_handshake_on_connect
         self.suppress_ragged_eofs = suppress_ragged_eofs
         if sock is not None:
@@ -781,7 +809,7 @@
             # create the SSL object
             try:
                 sslobj = self._context._wrap_socket(self, server_side,
-                                                    server_hostname)
+                                                    self.server_hostname)
                 self._sslobj = SSLObject(sslobj, owner=self,
                                          session=self._session)
                 if do_handshake_on_connect:
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index a253f51..a48eb89 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -1528,16 +1528,5 @@
                 # For compatibility
                 self.assertEqual(cm.exception.errno, ssl.SSL_ERROR_WANT_READ)

-    def test_bad_idna_in_server_hostname(self):
-        # Note: this test is testing some code that probably shouldn't exist
-        # in the first place, so if it starts failing at some point because
-        # you made the ssl module stop doing IDNA decoding then please feel
-        # free to remove it. The test was mainly added because this case used
-        # to cause memory corruption (see bpo-30594).
-        ctx = ssl.create_default_context()
-        with self.assertRaises(UnicodeError):
-            ctx.wrap_bio(ssl.MemoryBIO(), ssl.MemoryBIO(),
-                         server_hostname="xn--.com")
-
     def test_bad_server_hostname(self):
         ctx = ssl.create_default_context()
@@ -2634,10 +2623,10 @@
         if support.verbose:
             sys.stdout.write("\n")

-        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+        server_context = ssl.SSLContext(ssl.PROTOCOL_TLS_SERVER)
         server_context.load_cert_chain(IDNSANSFILE)

-        context = ssl.SSLContext(ssl.PROTOCOL_TLS)
+        context = ssl.SSLContext(ssl.PROTOCOL_TLS_CLIENT)
         context.verify_mode = ssl.CERT_REQUIRED
         context.check_hostname = True
         context.load_verify_locations(SIGNING_CA)
@@ -2646,18 +2635,26 @@
         # different ways
         idn_hostnames = [
             ('könig.idn.pythontest.net',
-             'könig.idn.pythontest.net',),
+             'xn--knig-5qa.idn.pythontest.net'),
             ('xn--knig-5qa.idn.pythontest.net',
              'xn--knig-5qa.idn.pythontest.net'),
             (b'xn--knig-5qa.idn.pythontest.net',
-             b'xn--knig-5qa.idn.pythontest.net'),
+             'xn--knig-5qa.idn.pythontest.net'),

             ('königsgäßchen.idna2003.pythontest.net',
-             'königsgäßchen.idna2003.pythontest.net'),
+             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
             ('xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
              'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
             (b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net',
-             b'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
+             'xn--knigsgsschen-lcb0w.idna2003.pythontest.net'),
+
+            # The U-label 'königsgäßchen.idna2008.pythontest.net' is not
+            # tested: the 'idna' codec (IDNA 2003) maps 'ß' to 'ss' and so
+            # cannot produce the IDNA 2008 A-label below.
+            ('xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
+             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
+            (b'xn--knigsgchen-b4a3dun.idna2008.pythontest.net',
+             'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'),
         ]
         for server_hostname, expected_hostname in idn_hostnames:
             server = ThreadedEchoServer(context=server_context, chatty=True)
@@ -2676,16 +2673,6 @@
                     s.getpeercert()
                     self.assertEqual(s.server_hostname, expected_hostname)

-        # bug https://bugs.python.org/issue28414
-        # IDNA 2008 deviations are broken
-        idna2008 = 'xn--knigsgchen-b4a3dun.idna2008.pythontest.net'
-        server = ThreadedEchoServer(context=server_context, chatty=True)
-        with server:
-            with self.assertRaises(UnicodeError):
-                with context.wrap_socket(socket.socket(),
-                                         server_hostname=idna2008) as s:
-                    s.connect((HOST, server.port))
-
         # incorrect hostname should raise an exception
         server = ThreadedEchoServer(context=server_context, chatty=True)
         with server: