[bpo-28414] Make all hostnames in SSL module IDN A-labels (GH-5128) (GH-5843)

Previously, the ssl module stored international domain names (IDNs)
as U-labels. This is problematic for a number of reasons -- for
example, it made it impossible for users to use a different version
of IDNA than the one built into Python.

After this change, we always convert to A-labels as soon as possible,
and use them for all internal processing. In particular, the
server_hostname attribute is now an A-label, and on the server side
there's a new sni_callback that receives the SNI servername as an
A-label rather than a U-label.
(cherry picked from commit 11a1493bc4198f1def5e572049485779cf54dc57)

Co-authored-by: Christian Heimes <christian@python.org>
diff --git a/Lib/ssl.py b/Lib/ssl.py
index b6161d0..f253769 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -355,13 +355,20 @@
         self = _SSLContext.__new__(cls, protocol)
         return self
 
-    def __init__(self, protocol=PROTOCOL_TLS):
-        self.protocol = protocol
+    def _encode_hostname(self, hostname):
+        if hostname is None:
+            return None
+        elif isinstance(hostname, str):
+            return hostname.encode('idna').decode('ascii')
+        else:
+            return hostname.decode('ascii')
 
     def wrap_socket(self, sock, server_side=False,
                     do_handshake_on_connect=True,
                     suppress_ragged_eofs=True,
                     server_hostname=None, session=None):
+        # SSLSocket class handles server_hostname encoding before it calls
+        # ctx._wrap_socket()
         return self.sslsocket_class(
             sock=sock,
             server_side=server_side,
@@ -374,8 +381,12 @@
 
     def wrap_bio(self, incoming, outgoing, server_side=False,
                  server_hostname=None, session=None):
-        sslobj = self._wrap_bio(incoming, outgoing, server_side=server_side,
-                                server_hostname=server_hostname)
+        # Need to encode server_hostname here because _wrap_bio() can only
+        # handle ASCII str.
+        sslobj = self._wrap_bio(
+            incoming, outgoing, server_side=server_side,
+            server_hostname=self._encode_hostname(server_hostname)
+        )
         return self.sslobject_class(sslobj, session=session)
 
     def set_npn_protocols(self, npn_protocols):
@@ -389,6 +400,19 @@
 
         self._set_npn_protocols(protos)
 
+    def set_servername_callback(self, server_name_callback):
+        if server_name_callback is None:
+            self.sni_callback = None
+        else:
+            if not callable(server_name_callback):
+                raise TypeError("not a callable object")
+
+            def shim_cb(sslobj, servername, sslctx):
+                servername = self._encode_hostname(servername)
+                return server_name_callback(sslobj, servername, sslctx)
+
+            self.sni_callback = shim_cb
+
     def set_alpn_protocols(self, alpn_protocols):
         protos = bytearray()
         for protocol in alpn_protocols:
@@ -448,6 +472,10 @@
             return True
 
     @property
+    def protocol(self):
+        return _SSLMethod(super().protocol)
+
+    @property
     def verify_flags(self):
         return VerifyFlags(super().verify_flags)
 
@@ -749,7 +777,7 @@
             raise ValueError("check_hostname requires server_hostname")
         self._session = _session
         self.server_side = server_side
-        self.server_hostname = server_hostname
+        self.server_hostname = self._context._encode_hostname(server_hostname)
         self.do_handshake_on_connect = do_handshake_on_connect
         self.suppress_ragged_eofs = suppress_ragged_eofs
         if sock is not None:
@@ -781,7 +809,7 @@
             # create the SSL object
             try:
                 sslobj = self._context._wrap_socket(self, server_side,
-                                                    server_hostname)
+                                                    self.server_hostname)
                 self._sslobj = SSLObject(sslobj, owner=self,
                                          session=self._session)
                 if do_handshake_on_connect: