[3.7] bpo-24334: Cleanup SSLSocket (GH-5252) (#5857)

* The SSLSocket is no longer implemented on top of SSLObject to
  avoid an extra level of indirection.
* Owner and session are now handled in the internal constructor.
* _ssl._SSLSocket now uses the same method names as SSLSocket and
  SSLObject.
* The channel binding type check is now handled in C code. Channel
  binding ('tls-unique') is always available.

The patch also changes the signature of SSLObject.__init__(). In my
opinion this is fine: an SSLObject is not a user-constructible object,
and SSLContext.wrap_bio() is the only valid factory.
(cherry picked from commit 141c5e8c2437a9fed95a04c81e400ef725592a17)

Co-authored-by: Christian Heimes <christian@python.org>
diff --git a/Lib/ssl.py b/Lib/ssl.py
index ecdbb70..94ea35e 100644
--- a/Lib/ssl.py
+++ b/Lib/ssl.py
@@ -166,10 +166,7 @@
 
 socket_error = OSError  # keep that public name in module namespace
 
-if _ssl.HAS_TLS_UNIQUE:
-    CHANNEL_BINDING_TYPES = ['tls-unique']
-else:
-    CHANNEL_BINDING_TYPES = []
+CHANNEL_BINDING_TYPES = ['tls-unique']
 
 HAS_NEVER_CHECK_COMMON_NAME = hasattr(_ssl, 'HOSTFLAG_NEVER_CHECK_SUBJECT')
 
@@ -407,11 +404,11 @@
                  server_hostname=None, session=None):
         # Need to encode server_hostname here because _wrap_bio() can only
         # handle ASCII str.
-        sslobj = self._wrap_bio(
+        return self.sslobject_class(
             incoming, outgoing, server_side=server_side,
-            server_hostname=self._encode_hostname(server_hostname)
+            server_hostname=self._encode_hostname(server_hostname),
+            session=session, _context=self,
         )
-        return self.sslobject_class(sslobj, session=session)
 
     def set_npn_protocols(self, npn_protocols):
         protos = bytearray()
@@ -616,12 +613,13 @@
      * The ``do_handshake_on_connect`` and ``suppress_ragged_eofs`` machinery.
     """
 
-    def __init__(self, sslobj, owner=None, session=None):
-        self._sslobj = sslobj
-        # Note: _sslobj takes a weak reference to owner
-        self._sslobj.owner = owner or self
-        if session is not None:
-            self._sslobj.session = session
+    def __init__(self, incoming, outgoing, server_side=False,
+                 server_hostname=None, session=None, _context=None):
+        self._sslobj = _context._wrap_bio(
+            incoming, outgoing, server_side=server_side,
+            server_hostname=server_hostname,
+            owner=self, session=session
+        )
 
     @property
     def context(self):
@@ -684,7 +682,7 @@
         Return None if no certificate was provided, {} if a certificate was
         provided, but not validated.
         """
-        return self._sslobj.peer_certificate(binary_form)
+        return self._sslobj.getpeercert(binary_form)
 
     def selected_npn_protocol(self):
         """Return the currently selected NPN protocol as a string, or ``None``
@@ -732,13 +730,7 @@
         """Get channel binding data for current connection.  Raise ValueError
         if the requested `cb_type` is not supported.  Return bytes of the data
         or None if the data is not available (e.g. before the handshake)."""
-        if cb_type not in CHANNEL_BINDING_TYPES:
-            raise ValueError("Unsupported channel binding type")
-        if cb_type != "tls-unique":
-            raise NotImplementedError(
-                            "{0} channel binding type not implemented"
-                            .format(cb_type))
-        return self._sslobj.tls_unique_cb()
+        return self._sslobj.get_channel_binding(cb_type)
 
     def version(self):
         """Return a string identifying the protocol version used by the
@@ -832,10 +824,10 @@
         if connected:
             # create the SSL object
             try:
-                sslobj = self._context._wrap_socket(self, server_side,
-                                                    self.server_hostname)
-                self._sslobj = SSLObject(sslobj, owner=self,
-                                         session=self._session)
+                self._sslobj = self._context._wrap_socket(
+                    self, server_side, self.server_hostname,
+                    owner=self, session=self._session,
+                )
                 if do_handshake_on_connect:
                     timeout = self.gettimeout()
                     if timeout == 0.0:
@@ -895,10 +887,13 @@
         Return zero-length string on EOF."""
 
         self._checkClosed()
-        if not self._sslobj:
+        if self._sslobj is None:
             raise ValueError("Read on closed or unwrapped SSL socket.")
         try:
-            return self._sslobj.read(len, buffer)
+            if buffer is not None:
+                return self._sslobj.read(len, buffer)
+            else:
+                return self._sslobj.read(len)
         except SSLError as x:
             if x.args[0] == SSL_ERROR_EOF and self.suppress_ragged_eofs:
                 if buffer is not None:
@@ -913,7 +908,7 @@
         number of bytes of DATA actually transmitted."""
 
         self._checkClosed()
-        if not self._sslobj:
+        if self._sslobj is None:
             raise ValueError("Write on closed or unwrapped SSL socket.")
         return self._sslobj.write(data)
 
@@ -929,41 +924,42 @@
 
     def selected_npn_protocol(self):
         self._checkClosed()
-        if not self._sslobj or not _ssl.HAS_NPN:
+        if self._sslobj is None or not _ssl.HAS_NPN:
             return None
         else:
             return self._sslobj.selected_npn_protocol()
 
     def selected_alpn_protocol(self):
         self._checkClosed()
-        if not self._sslobj or not _ssl.HAS_ALPN:
+        if self._sslobj is None or not _ssl.HAS_ALPN:
             return None
         else:
             return self._sslobj.selected_alpn_protocol()
 
     def cipher(self):
         self._checkClosed()
-        if not self._sslobj:
+        if self._sslobj is None:
             return None
         else:
             return self._sslobj.cipher()
 
     def shared_ciphers(self):
         self._checkClosed()
-        if not self._sslobj:
+        if self._sslobj is None:
             return None
-        return self._sslobj.shared_ciphers()
+        else:
+            return self._sslobj.shared_ciphers()
 
     def compression(self):
         self._checkClosed()
-        if not self._sslobj:
+        if self._sslobj is None:
             return None
         else:
             return self._sslobj.compression()
 
     def send(self, data, flags=0):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             if flags != 0:
                 raise ValueError(
                     "non-zero flags not allowed in calls to send() on %s" %
@@ -974,7 +970,7 @@
 
     def sendto(self, data, flags_or_addr, addr=None):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             raise ValueError("sendto not allowed on instances of %s" %
                              self.__class__)
         elif addr is None:
@@ -990,7 +986,7 @@
 
     def sendall(self, data, flags=0):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             if flags != 0:
                 raise ValueError(
                     "non-zero flags not allowed in calls to sendall() on %s" %
@@ -1008,15 +1004,15 @@
         """Send a file, possibly by using os.sendfile() if this is a
         clear-text socket.  Return the total number of bytes sent.
         """
-        if self._sslobj is None:
+        if self._sslobj is not None:
+            return self._sendfile_use_send(file, offset, count)
+        else:
             # os.sendfile() works with plain sockets only
             return super().sendfile(file, offset, count)
-        else:
-            return self._sendfile_use_send(file, offset, count)
 
     def recv(self, buflen=1024, flags=0):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             if flags != 0:
                 raise ValueError(
                     "non-zero flags not allowed in calls to recv() on %s" %
@@ -1031,7 +1027,7 @@
             nbytes = len(buffer)
         elif nbytes is None:
             nbytes = 1024
-        if self._sslobj:
+        if self._sslobj is not None:
             if flags != 0:
                 raise ValueError(
                   "non-zero flags not allowed in calls to recv_into() on %s" %
@@ -1042,7 +1038,7 @@
 
     def recvfrom(self, buflen=1024, flags=0):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             raise ValueError("recvfrom not allowed on instances of %s" %
                              self.__class__)
         else:
@@ -1050,7 +1046,7 @@
 
     def recvfrom_into(self, buffer, nbytes=None, flags=0):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             raise ValueError("recvfrom_into not allowed on instances of %s" %
                              self.__class__)
         else:
@@ -1066,7 +1062,7 @@
 
     def pending(self):
         self._checkClosed()
-        if self._sslobj:
+        if self._sslobj is not None:
             return self._sslobj.pending()
         else:
             return 0
@@ -1078,7 +1074,7 @@
 
     def unwrap(self):
         if self._sslobj:
-            s = self._sslobj.unwrap()
+            s = self._sslobj.shutdown()
             self._sslobj = None
             return s
         else:
@@ -1096,6 +1092,11 @@
             if timeout == 0.0 and block:
                 self.settimeout(None)
             self._sslobj.do_handshake()
+            if self.context.check_hostname:
+                if not self.server_hostname:
+                    raise ValueError("check_hostname needs server_hostname "
+                                     "argument")
+                match_hostname(self.getpeercert(), self.server_hostname)
         finally:
             self.settimeout(timeout)
 
@@ -1104,11 +1105,12 @@
             raise ValueError("can't connect in server-side mode")
         # Here we assume that the socket is client-side, and not
         # connected at the time of the call.  We connect it, then wrap it.
-        if self._connected:
+        if self._connected or self._sslobj is not None:
             raise ValueError("attempt to connect already-connected SSLSocket!")
-        sslobj = self.context._wrap_socket(self, False, self.server_hostname)
-        self._sslobj = SSLObject(sslobj, owner=self,
-                                 session=self._session)
+        self._sslobj = self.context._wrap_socket(
+            self, False, self.server_hostname,
+            owner=self, session=self._session
+        )
         try:
             if connect_ex:
                 rc = super().connect_ex(addr)
@@ -1151,18 +1153,24 @@
         if the requested `cb_type` is not supported.  Return bytes of the data
         or None if the data is not available (e.g. before the handshake).
         """
-        if self._sslobj is None:
+        if self._sslobj is not None:
+            return self._sslobj.get_channel_binding(cb_type)
+        else:
+            if cb_type not in CHANNEL_BINDING_TYPES:
+                raise ValueError(
+                    "{0} channel binding type not implemented".format(cb_type)
+                )
             return None
-        return self._sslobj.get_channel_binding(cb_type)
 
     def version(self):
         """
         Return a string identifying the protocol version used by the
         current SSL channel, or None if there is no established channel.
         """
-        if self._sslobj is None:
+        if self._sslobj is not None:
+            return self._sslobj.version()
+        else:
             return None
-        return self._sslobj.version()
 
 
 # Python does not support forward declaration of types.
diff --git a/Lib/test/test_ssl.py b/Lib/test/test_ssl.py
index 7aa1123..3f2c50b 100644
--- a/Lib/test/test_ssl.py
+++ b/Lib/test/test_ssl.py
@@ -455,6 +455,8 @@
             self.assertRaises(OSError, ss.recvfrom_into, bytearray(b'x'), 1)
             self.assertRaises(OSError, ss.send, b'x')
             self.assertRaises(OSError, ss.sendto, b'x', ('0.0.0.0', 0))
+            self.assertRaises(NotImplementedError, ss.sendmsg,
+                              [b'x'], (), 0, ('0.0.0.0', 0))
 
     def test_timeout(self):
         # Issue #8524: when creating an SSL socket, the timeout of the
@@ -3381,11 +3383,13 @@
                                 chatty=False) as server:
             with context.wrap_socket(socket.socket()) as s:
                 self.assertIs(s.version(), None)
+                self.assertIs(s._sslobj, None)
                 s.connect((HOST, server.port))
                 if ssl.OPENSSL_VERSION_INFO >= (1, 0, 2):
                     self.assertEqual(s.version(), 'TLSv1.2')
                 else:  # 0.9.8 to 1.0.1
                     self.assertIn(s.version(), ('TLSv1', 'TLSv1.2'))
+            self.assertIs(s._sslobj, None)
             self.assertIs(s.version(), None)
 
     @unittest.skipUnless(ssl.HAS_TLSv1_3,