bpo-44011: New asyncio ssl implementation (#17975)

diff --git a/Lib/asyncio/sslproto.py b/Lib/asyncio/sslproto.py
index cad25b2..e71875b 100644
--- a/Lib/asyncio/sslproto.py
+++ b/Lib/asyncio/sslproto.py
@@ -1,4 +1,5 @@
 import collections
+import enum
 import warnings
 try:
     import ssl
@@ -6,10 +7,46 @@
     ssl = None
 
 from . import constants
+from . import exceptions
 from . import protocols
 from . import transports
 from .log import logger
 
+# Only defined when the ssl module is importable: evaluating it with
+# ssl = None would raise AttributeError and break importing this module.
+if ssl is not None:
+    SSLAgainErrors = (ssl.SSLWantReadError, ssl.SSLSyscallError)
+
+
+class SSLProtocolState(enum.Enum):
+    # State of the SSL layer.  SSLProtocol._set_state() only allows:
+    #
+    #     UNWRAPPED -> DO_HANDSHAKE -> WRAPPED -> FLUSHING -> SHUTDOWN
+    #
+    # plus a transition from any state back to UNWRAPPED.
+
+    UNWRAPPED = "UNWRAPPED"
+    DO_HANDSHAKE = "DO_HANDSHAKE"
+    WRAPPED = "WRAPPED"
+    FLUSHING = "FLUSHING"
+    SHUTDOWN = "SHUTDOWN"
+
+
+class AppProtocolState(enum.Enum):
+    # This tracks the state of app protocol (https://git.io/fj59P):
+    #
+    #     INIT -cm-> CON_MADE [-dr*->] [-er-> EOF?] -cl-> CON_LOST
+    #
+    # * cm: connection_made()
+    # * dr: data_received()
+    # * er: eof_received()
+    # * cl: connection_lost()
+
+    STATE_INIT = "STATE_INIT"
+    STATE_CON_MADE = "STATE_CON_MADE"
+    STATE_EOF = "STATE_EOF"
+    STATE_CON_LOST = "STATE_CON_LOST"
+
 
 def _create_transport_context(server_side, server_hostname):
     if server_side:
@@ -25,269 +53,35 @@ def _create_transport_context(server_side, server_hostname):
     return sslcontext
 
 
-# States of an _SSLPipe.
-_UNWRAPPED = "UNWRAPPED"
-_DO_HANDSHAKE = "DO_HANDSHAKE"
-_WRAPPED = "WRAPPED"
-_SHUTDOWN = "SHUTDOWN"
+def add_flowcontrol_defaults(high, low, kb):
+    if high is None:
+        if low is None:
+            hi = kb * 1024
+        else:
+            lo = low
+            hi = 4 * lo
+    else:
+        hi = high
+    if low is None:
+        lo = hi // 4
+    else:
+        lo = low
 
+    if not hi >= lo >= 0:
+        raise ValueError('high (%r) must be >= low (%r) must be >= 0' %
+                         (hi, lo))
 
-class _SSLPipe(object):
-    """An SSL "Pipe".
-
-    An SSL pipe allows you to communicate with an SSL/TLS protocol instance
-    through memory buffers. It can be used to implement a security layer for an
-    existing connection where you don't have access to the connection's file
-    descriptor, or for some reason you don't want to use it.
-
-    An SSL pipe can be in "wrapped" and "unwrapped" mode. In unwrapped mode,
-    data is passed through untransformed. In wrapped mode, application level
-    data is encrypted to SSL record level data and vice versa. The SSL record
-    level is the lowest level in the SSL protocol suite and is what travels
-    as-is over the wire.
-
-    An SslPipe initially is in "unwrapped" mode. To start SSL, call
-    do_handshake(). To shutdown SSL again, call unwrap().
-    """
-
-    max_size = 256 * 1024   # Buffer size passed to read()
-
-    def __init__(self, context, server_side, server_hostname=None):
-        """
-        The *context* argument specifies the ssl.SSLContext to use.
-
-        The *server_side* argument indicates whether this is a server side or
-        client side transport.
-
-        The optional *server_hostname* argument can be used to specify the
-        hostname you are connecting to. You may only specify this parameter if
-        the _ssl module supports Server Name Indication (SNI).
-        """
-        self._context = context
-        self._server_side = server_side
-        self._server_hostname = server_hostname
-        self._state = _UNWRAPPED
-        self._incoming = ssl.MemoryBIO()
-        self._outgoing = ssl.MemoryBIO()
-        self._sslobj = None
-        self._need_ssldata = False
-        self._handshake_cb = None
-        self._shutdown_cb = None
-
-    @property
-    def context(self):
-        """The SSL context passed to the constructor."""
-        return self._context
-
-    @property
-    def ssl_object(self):
-        """The internal ssl.SSLObject instance.
-
-        Return None if the pipe is not wrapped.
-        """
-        return self._sslobj
-
-    @property
-    def need_ssldata(self):
-        """Whether more record level data is needed to complete a handshake
-        that is currently in progress."""
-        return self._need_ssldata
-
-    @property
-    def wrapped(self):
-        """
-        Whether a security layer is currently in effect.
-
-        Return False during handshake.
-        """
-        return self._state == _WRAPPED
-
-    def do_handshake(self, callback=None):
-        """Start the SSL handshake.
-
-        Return a list of ssldata. A ssldata element is a list of buffers
-
-        The optional *callback* argument can be used to install a callback that
-        will be called when the handshake is complete. The callback will be
-        called with None if successful, else an exception instance.
-        """
-        if self._state != _UNWRAPPED:
-            raise RuntimeError('handshake in progress or completed')
-        self._sslobj = self._context.wrap_bio(
-            self._incoming, self._outgoing,
-            server_side=self._server_side,
-            server_hostname=self._server_hostname)
-        self._state = _DO_HANDSHAKE
-        self._handshake_cb = callback
-        ssldata, appdata = self.feed_ssldata(b'', only_handshake=True)
-        assert len(appdata) == 0
-        return ssldata
-
-    def shutdown(self, callback=None):
-        """Start the SSL shutdown sequence.
-
-        Return a list of ssldata. A ssldata element is a list of buffers
-
-        The optional *callback* argument can be used to install a callback that
-        will be called when the shutdown is complete. The callback will be
-        called without arguments.
-        """
-        if self._state == _UNWRAPPED:
-            raise RuntimeError('no security layer present')
-        if self._state == _SHUTDOWN:
-            raise RuntimeError('shutdown in progress')
-        assert self._state in (_WRAPPED, _DO_HANDSHAKE)
-        self._state = _SHUTDOWN
-        self._shutdown_cb = callback
-        ssldata, appdata = self.feed_ssldata(b'')
-        assert appdata == [] or appdata == [b'']
-        return ssldata
-
-    def feed_eof(self):
-        """Send a potentially "ragged" EOF.
-
-        This method will raise an SSL_ERROR_EOF exception if the EOF is
-        unexpected.
-        """
-        self._incoming.write_eof()
-        ssldata, appdata = self.feed_ssldata(b'')
-        assert appdata == [] or appdata == [b'']
-
-    def feed_ssldata(self, data, only_handshake=False):
-        """Feed SSL record level data into the pipe.
-
-        The data must be a bytes instance. It is OK to send an empty bytes
-        instance. This can be used to get ssldata for a handshake initiated by
-        this endpoint.
-
-        Return a (ssldata, appdata) tuple. The ssldata element is a list of
-        buffers containing SSL data that needs to be sent to the remote SSL.
-
-        The appdata element is a list of buffers containing plaintext data that
-        needs to be forwarded to the application. The appdata list may contain
-        an empty buffer indicating an SSL "close_notify" alert. This alert must
-        be acknowledged by calling shutdown().
-        """
-        if self._state == _UNWRAPPED:
-            # If unwrapped, pass plaintext data straight through.
-            if data:
-                appdata = [data]
-            else:
-                appdata = []
-            return ([], appdata)
-
-        self._need_ssldata = False
-        if data:
-            self._incoming.write(data)
-
-        ssldata = []
-        appdata = []
-        try:
-            if self._state == _DO_HANDSHAKE:
-                # Call do_handshake() until it doesn't raise anymore.
-                self._sslobj.do_handshake()
-                self._state = _WRAPPED
-                if self._handshake_cb:
-                    self._handshake_cb(None)
-                if only_handshake:
-                    return (ssldata, appdata)
-                # Handshake done: execute the wrapped block
-
-            if self._state == _WRAPPED:
-                # Main state: read data from SSL until close_notify
-                while True:
-                    chunk = self._sslobj.read(self.max_size)
-                    appdata.append(chunk)
-                    if not chunk:  # close_notify
-                        break
-
-            elif self._state == _SHUTDOWN:
-                # Call shutdown() until it doesn't raise anymore.
-                self._sslobj.unwrap()
-                self._sslobj = None
-                self._state = _UNWRAPPED
-                if self._shutdown_cb:
-                    self._shutdown_cb()
-
-            elif self._state == _UNWRAPPED:
-                # Drain possible plaintext data after close_notify.
-                appdata.append(self._incoming.read())
-        except (ssl.SSLError, ssl.CertificateError) as exc:
-            exc_errno = getattr(exc, 'errno', None)
-            if exc_errno not in (
-                    ssl.SSL_ERROR_WANT_READ, ssl.SSL_ERROR_WANT_WRITE,
-                    ssl.SSL_ERROR_SYSCALL):
-                if self._state == _DO_HANDSHAKE and self._handshake_cb:
-                    self._handshake_cb(exc)
-                raise
-            self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
-
-        # Check for record level data that needs to be sent back.
-        # Happens for the initial handshake and renegotiations.
-        if self._outgoing.pending:
-            ssldata.append(self._outgoing.read())
-        return (ssldata, appdata)
-
-    def feed_appdata(self, data, offset=0):
-        """Feed plaintext data into the pipe.
-
-        Return an (ssldata, offset) tuple. The ssldata element is a list of
-        buffers containing record level data that needs to be sent to the
-        remote SSL instance. The offset is the number of plaintext bytes that
-        were processed, which may be less than the length of data.
-
-        NOTE: In case of short writes, this call MUST be retried with the SAME
-        buffer passed into the *data* argument (i.e. the id() must be the
-        same). This is an OpenSSL requirement. A further particularity is that
-        a short write will always have offset == 0, because the _ssl module
-        does not enable partial writes. And even though the offset is zero,
-        there will still be encrypted data in ssldata.
-        """
-        assert 0 <= offset <= len(data)
-        if self._state == _UNWRAPPED:
-            # pass through data in unwrapped mode
-            if offset < len(data):
-                ssldata = [data[offset:]]
-            else:
-                ssldata = []
-            return (ssldata, len(data))
-
-        ssldata = []
-        view = memoryview(data)
-        while True:
-            self._need_ssldata = False
-            try:
-                if offset < len(view):
-                    offset += self._sslobj.write(view[offset:])
-            except ssl.SSLError as exc:
-                # It is not allowed to call write() after unwrap() until the
-                # close_notify is acknowledged. We return the condition to the
-                # caller as a short write.
-                exc_errno = getattr(exc, 'errno', None)
-                if exc.reason == 'PROTOCOL_IS_SHUTDOWN':
-                    exc_errno = exc.errno = ssl.SSL_ERROR_WANT_READ
-                if exc_errno not in (ssl.SSL_ERROR_WANT_READ,
-                                     ssl.SSL_ERROR_WANT_WRITE,
-                                     ssl.SSL_ERROR_SYSCALL):
-                    raise
-                self._need_ssldata = (exc_errno == ssl.SSL_ERROR_WANT_READ)
-
-            # See if there's any record level data back for us.
-            if self._outgoing.pending:
-                ssldata.append(self._outgoing.read())
-            if offset == len(view) or self._need_ssldata:
-                break
-        return (ssldata, offset)
+    return hi, lo
 
 
 class _SSLProtocolTransport(transports._FlowControlMixin,
                             transports.Transport):
 
+    _start_tls_compatible = True
     _sendfile_compatible = constants._SendfileMode.FALLBACK
 
     def __init__(self, loop, ssl_protocol):
         self._loop = loop
-        # SSLProtocol instance
         self._ssl_protocol = ssl_protocol
         self._closed = False
 
@@ -315,16 +109,15 @@ def close(self):
         self._closed = True
         self._ssl_protocol._start_shutdown()
 
-    def __del__(self, _warn=warnings.warn):
+    def __del__(self, _warnings=warnings):
         if not self._closed:
-            _warn(f"unclosed transport {self!r}", ResourceWarning, source=self)
-            self.close()
+            self._closed = True
+            _warnings.warn(
+                "unclosed transport <asyncio._SSLProtocolTransport "
+                "object>", ResourceWarning)
 
     def is_reading(self):
-        tr = self._ssl_protocol._transport
-        if tr is None:
-            raise RuntimeError('SSL transport has not been initialized yet')
-        return tr.is_reading()
+        return not self._ssl_protocol._app_reading_paused
 
     def pause_reading(self):
         """Pause the receiving end.
@@ -332,7 +125,7 @@ def pause_reading(self):
         No data will be passed to the protocol's data_received()
         method until resume_reading() is called.
         """
-        self._ssl_protocol._transport.pause_reading()
+        self._ssl_protocol._pause_reading()
 
     def resume_reading(self):
         """Resume the receiving end.
@@ -340,7 +133,7 @@ def resume_reading(self):
         Data received will once again be passed to the protocol's
         data_received() method.
         """
-        self._ssl_protocol._transport.resume_reading()
+        self._ssl_protocol._resume_reading()
 
     def set_write_buffer_limits(self, high=None, low=None):
         """Set the high- and low-water limits for write flow control.
@@ -361,16 +154,51 @@ def set_write_buffer_limits(self, high=None, low=None):
         reduces opportunities for doing I/O and computation
         concurrently.
         """
-        self._ssl_protocol._transport.set_write_buffer_limits(high, low)
+        self._ssl_protocol._set_write_buffer_limits(high, low)
+        self._ssl_protocol._control_app_writing()
+
+    def get_write_buffer_limits(self):
+        return (self._ssl_protocol._outgoing_low_water,
+                self._ssl_protocol._outgoing_high_water)
 
     def get_write_buffer_size(self):
-        """Return the current size of the write buffer."""
-        return self._ssl_protocol._transport.get_write_buffer_size()
+        """Return the current size of the write buffers."""
+        return self._ssl_protocol._get_write_buffer_size()
+
+    def set_read_buffer_limits(self, high=None, low=None):
+        """Set the high- and low-water limits for read flow control.
+
+        These two values control when to call the upstream transport's
+        pause_reading() and resume_reading() methods.  If specified,
+        the low-water limit must be less than or equal to the
+        high-water limit.  Neither value can be negative.
+
+        The defaults are implementation-specific.  If only the
+        high-water limit is given, the low-water limit defaults to an
+        implementation-specific value less than or equal to the
+        high-water limit.  Setting high to zero forces low to zero as
+        well, and causes pause_reading() to be called whenever the
+        buffer becomes non-empty.  Setting low to zero causes
+        resume_reading() to be called only once the buffer is empty.
+        Use of zero for either limit is generally sub-optimal as it
+        reduces opportunities for doing I/O and computation
+        concurrently.
+        """
+        self._ssl_protocol._set_read_buffer_limits(high, low)
+        self._ssl_protocol._control_ssl_reading()
+
+    def get_read_buffer_limits(self):
+        return (self._ssl_protocol._incoming_low_water,
+                self._ssl_protocol._incoming_high_water)
+
+    def get_read_buffer_size(self):
+        """Return the current size of the read buffer."""
+        return self._ssl_protocol._get_read_buffer_size()
 
     @property
     def _protocol_paused(self):
         # Required for sendfile fallback pause_writing/resume_writing logic
-        return self._ssl_protocol._transport._protocol_paused
+        return self._ssl_protocol._app_writing_paused
 
     def write(self, data):
         """Write some data bytes to the transport.
@@ -383,7 +211,22 @@ def write(self, data):
                             f"got {type(data).__name__}")
         if not data:
             return
-        self._ssl_protocol._write_appdata(data)
+        self._ssl_protocol._write_appdata((data,))
+
+    def writelines(self, list_of_data):
+        """Write a list (or any iterable) of data bytes to the transport.
+
+        The iterable is passed as-is to SSLProtocol._write_appdata(),
+        without concatenating or type-checking the items here.
+        """
+        self._ssl_protocol._write_appdata(list_of_data)
+
+    def write_eof(self):
+        """Close the write end after flushing buffered data.
+
+        This raises :exc:`NotImplementedError` right now.
+        """
+        raise NotImplementedError
 
     def can_write_eof(self):
         """Return True if this transport supports write_eof(), False if not."""
@@ -396,23 +239,36 @@ def abort(self):
         The protocol's connection_lost() method will (eventually) be
         called with None as its argument.
         """
-        self._ssl_protocol._abort()
         self._closed = True
+        self._ssl_protocol._abort()
+
+    def _force_close(self, exc):
+        self._closed = True
+        self._ssl_protocol._abort(exc)
+
+    def _test__append_write_backlog(self, data):
+        # for test only
+        self._ssl_protocol._write_backlog.append(data)
+        self._ssl_protocol._write_buffer_size += len(data)
 
 
-class SSLProtocol(protocols.Protocol):
-    """SSL protocol.
+class SSLProtocol(protocols.BufferedProtocol):
+    max_size = 256 * 1024   # Buffer size passed to read()
 
-    Implementation of SSL on top of a socket using incoming and outgoing
-    buffers which are ssl.MemoryBIO objects.
-    """
+    _handshake_start_time = None
+    _handshake_timeout_handle = None
+    _shutdown_timeout_handle = None
 
     def __init__(self, loop, app_protocol, sslcontext, waiter,
                  server_side=False, server_hostname=None,
                  call_connection_made=True,
-                 ssl_handshake_timeout=None):
+                 ssl_handshake_timeout=None,
+                 ssl_shutdown_timeout=None):
         if ssl is None:
-            raise RuntimeError('stdlib ssl module not available')
+            raise RuntimeError("stdlib ssl module not available")
+
+        self._ssl_buffer = bytearray(self.max_size)
+        self._ssl_buffer_view = memoryview(self._ssl_buffer)
 
         if ssl_handshake_timeout is None:
             ssl_handshake_timeout = constants.SSL_HANDSHAKE_TIMEOUT
@@ -420,6 +276,12 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
             raise ValueError(
                 f"ssl_handshake_timeout should be a positive number, "
                 f"got {ssl_handshake_timeout}")
+        if ssl_shutdown_timeout is None:
+            ssl_shutdown_timeout = constants.SSL_SHUTDOWN_TIMEOUT
+        elif ssl_shutdown_timeout <= 0:
+            raise ValueError(
+                f"ssl_shutdown_timeout should be a positive number, "
+                f"got {ssl_shutdown_timeout}")
 
         if not sslcontext:
             sslcontext = _create_transport_context(
@@ -442,21 +304,54 @@ def __init__(self, loop, app_protocol, sslcontext, waiter,
         self._waiter = waiter
         self._loop = loop
         self._set_app_protocol(app_protocol)
-        self._app_transport = _SSLProtocolTransport(self._loop, self)
-        # _SSLPipe instance (None until the connection is made)
-        self._sslpipe = None
-        self._session_established = False
-        self._in_handshake = False
-        self._in_shutdown = False
+        self._app_transport = None
+        self._app_transport_created = False
         # transport, ex: SelectorSocketTransport
         self._transport = None
-        self._call_connection_made = call_connection_made
         self._ssl_handshake_timeout = ssl_handshake_timeout
+        self._ssl_shutdown_timeout = ssl_shutdown_timeout
+        # SSL and state machine
+        self._incoming = ssl.MemoryBIO()
+        self._outgoing = ssl.MemoryBIO()
+        self._state = SSLProtocolState.UNWRAPPED
+        self._conn_lost = 0  # Set when connection_lost called
+        if call_connection_made:
+            self._app_state = AppProtocolState.STATE_INIT
+        else:
+            self._app_state = AppProtocolState.STATE_CON_MADE
+        self._sslobj = self._sslcontext.wrap_bio(
+            self._incoming, self._outgoing,
+            server_side=self._server_side,
+            server_hostname=self._server_hostname)
+
+        # Flow Control
+
+        self._ssl_writing_paused = False
+
+        self._app_reading_paused = False
+
+        self._ssl_reading_paused = False
+        self._incoming_high_water = 0
+        self._incoming_low_water = 0
+        self._set_read_buffer_limits()
+        self._eof_received = False
+
+        self._app_writing_paused = False
+        self._outgoing_high_water = 0
+        self._outgoing_low_water = 0
+        self._set_write_buffer_limits()
+        self._get_app_transport()
 
     def _set_app_protocol(self, app_protocol):
         self._app_protocol = app_protocol
-        self._app_protocol_is_buffer = \
-            isinstance(app_protocol, protocols.BufferedProtocol)
+        # Make fast hasattr check first
+        if (hasattr(app_protocol, 'get_buffer') and
+                isinstance(app_protocol, protocols.BufferedProtocol)):
+            self._app_protocol_get_buffer = app_protocol.get_buffer
+            self._app_protocol_buffer_updated = app_protocol.buffer_updated
+            self._app_protocol_is_buffer = True
+        else:
+            self._app_protocol_is_buffer = False
 
     def _wakeup_waiter(self, exc=None):
         if self._waiter is None:
@@ -468,15 +363,20 @@ def _wakeup_waiter(self, exc=None):
                 self._waiter.set_result(None)
         self._waiter = None
 
+    def _get_app_transport(self):
+        if self._app_transport is None:
+            if self._app_transport_created:
+                raise RuntimeError('Creating _SSLProtocolTransport twice')
+            self._app_transport = _SSLProtocolTransport(self._loop, self)
+            self._app_transport_created = True
+        return self._app_transport
+
     def connection_made(self, transport):
         """Called when the low-level connection is made.
 
         Start the SSL handshake.
         """
         self._transport = transport
-        self._sslpipe = _SSLPipe(self._sslcontext,
-                                 self._server_side,
-                                 self._server_hostname)
         self._start_handshake()
 
     def connection_lost(self, exc):
@@ -486,72 +386,58 @@ def connection_lost(self, exc):
         meaning a regular EOF is received or the connection was
         aborted or closed).
         """
-        if self._session_established:
-            self._session_established = False
-            self._loop.call_soon(self._app_protocol.connection_lost, exc)
-        else:
-            # Most likely an exception occurred while in SSL handshake.
-            # Just mark the app transport as closed so that its __del__
-            # doesn't complain.
-            if self._app_transport is not None:
-                self._app_transport._closed = True
+        self._write_backlog.clear()
+        self._outgoing.read()
+        self._conn_lost += 1
+
+        # Just mark the app transport as closed so that its __del__
+        # doesn't emit an "unclosed transport" ResourceWarning.
+        if self._app_transport is not None:
+            self._app_transport._closed = True
+
+        if self._state != SSLProtocolState.DO_HANDSHAKE:
+            if (
+                self._app_state == AppProtocolState.STATE_CON_MADE or
+                self._app_state == AppProtocolState.STATE_EOF
+            ):
+                self._app_state = AppProtocolState.STATE_CON_LOST
+                self._loop.call_soon(self._app_protocol.connection_lost, exc)
+        self._set_state(SSLProtocolState.UNWRAPPED)
         self._transport = None
         self._app_transport = None
-        if getattr(self, '_handshake_timeout_handle', None):
-            self._handshake_timeout_handle.cancel()
-        self._wakeup_waiter(exc)
         self._app_protocol = None
-        self._sslpipe = None
+        self._wakeup_waiter(exc)
 
-    def pause_writing(self):
-        """Called when the low-level transport's buffer goes over
-        the high-water mark.
-        """
-        self._app_protocol.pause_writing()
+        if self._shutdown_timeout_handle:
+            self._shutdown_timeout_handle.cancel()
+            self._shutdown_timeout_handle = None
+        if self._handshake_timeout_handle:
+            self._handshake_timeout_handle.cancel()
+            self._handshake_timeout_handle = None
 
-    def resume_writing(self):
-        """Called when the low-level transport's buffer drains below
-        the low-water mark.
-        """
-        self._app_protocol.resume_writing()
+    def get_buffer(self, n):
+        want = n
+        if want <= 0 or want > self.max_size:
+            want = self.max_size
+        if len(self._ssl_buffer) < want:
+            self._ssl_buffer = bytearray(want)
+            self._ssl_buffer_view = memoryview(self._ssl_buffer)
+        return self._ssl_buffer_view
 
-    def data_received(self, data):
-        """Called when some SSL data is received.
+    def buffer_updated(self, nbytes):
+        self._incoming.write(self._ssl_buffer_view[:nbytes])
 
-        The argument is a bytes object.
-        """
-        if self._sslpipe is None:
-            # transport closing, sslpipe is destroyed
-            return
+        if self._state == SSLProtocolState.DO_HANDSHAKE:
+            self._do_handshake()
 
-        try:
-            ssldata, appdata = self._sslpipe.feed_ssldata(data)
-        except (SystemExit, KeyboardInterrupt):
-            raise
-        except BaseException as e:
-            self._fatal_error(e, 'SSL error in data received')
-            return
+        elif self._state == SSLProtocolState.WRAPPED:
+            self._do_read()
 
-        for chunk in ssldata:
-            self._transport.write(chunk)
+        elif self._state == SSLProtocolState.FLUSHING:
+            self._do_flush()
 
-        for chunk in appdata:
-            if chunk:
-                try:
-                    if self._app_protocol_is_buffer:
-                        protocols._feed_data_to_buffered_proto(
-                            self._app_protocol, chunk)
-                    else:
-                        self._app_protocol.data_received(chunk)
-                except (SystemExit, KeyboardInterrupt):
-                    raise
-                except BaseException as ex:
-                    self._fatal_error(
-                        ex, 'application protocol failed to receive SSL data')
-                    return
-            else:
-                self._start_shutdown()
-                break
+        elif self._state == SSLProtocolState.SHUTDOWN:
+            self._do_shutdown()
 
     def eof_received(self):
         """Called when the other end of the low-level stream
@@ -561,19 +447,32 @@ def eof_received(self):
         will close itself.  If it returns a true value, closing the
         transport is up to the protocol.
         """
+        self._eof_received = True
         try:
             if self._loop.get_debug():
                 logger.debug("%r received EOF", self)
 
-            self._wakeup_waiter(ConnectionResetError)
+            if self._state == SSLProtocolState.DO_HANDSHAKE:
+                self._on_handshake_complete(ConnectionResetError)
 
-            if not self._in_handshake:
-                keep_open = self._app_protocol.eof_received()
-                if keep_open:
-                    logger.warning('returning true from eof_received() '
-                                   'has no effect when using ssl')
-        finally:
+            elif self._state == SSLProtocolState.WRAPPED:
+                self._set_state(SSLProtocolState.FLUSHING)
+                if self._app_reading_paused:
+                    return True
+                else:
+                    self._do_flush()
+
+            elif self._state == SSLProtocolState.FLUSHING:
+                self._do_write()
+                self._set_state(SSLProtocolState.SHUTDOWN)
+                self._do_shutdown()
+
+            elif self._state == SSLProtocolState.SHUTDOWN:
+                self._do_shutdown()
+
+        except Exception:
             self._transport.close()
+            raise
 
     def _get_extra_info(self, name, default=None):
         if name in self._extra:
@@ -583,19 +482,45 @@ def _get_extra_info(self, name, default=None):
         else:
             return default
 
-    def _start_shutdown(self):
-        if self._in_shutdown:
-            return
-        if self._in_handshake:
-            self._abort()
-        else:
-            self._in_shutdown = True
-            self._write_appdata(b'')
+    def _set_state(self, new_state):
+        allowed = False
 
-    def _write_appdata(self, data):
-        self._write_backlog.append((data, 0))
-        self._write_buffer_size += len(data)
-        self._process_write_backlog()
+        if new_state == SSLProtocolState.UNWRAPPED:
+            allowed = True
+
+        elif (
+            self._state == SSLProtocolState.UNWRAPPED and
+            new_state == SSLProtocolState.DO_HANDSHAKE
+        ):
+            allowed = True
+
+        elif (
+            self._state == SSLProtocolState.DO_HANDSHAKE and
+            new_state == SSLProtocolState.WRAPPED
+        ):
+            allowed = True
+
+        elif (
+            self._state == SSLProtocolState.WRAPPED and
+            new_state == SSLProtocolState.FLUSHING
+        ):
+            allowed = True
+
+        elif (
+            self._state == SSLProtocolState.FLUSHING and
+            new_state == SSLProtocolState.SHUTDOWN
+        ):
+            allowed = True
+
+        if allowed:
+            self._state = new_state
+
+        else:
+            raise RuntimeError(
+                'cannot switch state from {} to {}'.format(
+                    self._state, new_state))
+
+    # Handshake flow
 
     def _start_handshake(self):
         if self._loop.get_debug():
@@ -603,17 +528,18 @@ def _start_handshake(self):
             self._handshake_start_time = self._loop.time()
         else:
             self._handshake_start_time = None
-        self._in_handshake = True
-        # (b'', 1) is a special value in _process_write_backlog() to do
-        # the SSL handshake
-        self._write_backlog.append((b'', 1))
+
+        self._set_state(SSLProtocolState.DO_HANDSHAKE)
+
+        # start handshake timeout count down
         self._handshake_timeout_handle = \
             self._loop.call_later(self._ssl_handshake_timeout,
-                                  self._check_handshake_timeout)
-        self._process_write_backlog()
+                                  lambda: self._check_handshake_timeout())
+
+        self._do_handshake()
 
     def _check_handshake_timeout(self):
-        if self._in_handshake is True:
+        if self._state == SSLProtocolState.DO_HANDSHAKE:
             msg = (
                 f"SSL handshake is taking longer than "
                 f"{self._ssl_handshake_timeout} seconds: "
@@ -621,24 +547,37 @@ def _check_handshake_timeout(self):
             )
             self._fatal_error(ConnectionAbortedError(msg))
 
-    def _on_handshake_complete(self, handshake_exc):
-        self._in_handshake = False
-        self._handshake_timeout_handle.cancel()
-
-        sslobj = self._sslpipe.ssl_object
+    def _do_handshake(self):
         try:
-            if handshake_exc is not None:
+            self._sslobj.do_handshake()
+        except SSLAgainErrors:
+            self._process_outgoing()
+        except ssl.SSLError as exc:
+            self._on_handshake_complete(exc)
+        else:
+            self._on_handshake_complete(None)
+
+    def _on_handshake_complete(self, handshake_exc):
+        if self._handshake_timeout_handle is not None:
+            self._handshake_timeout_handle.cancel()
+            self._handshake_timeout_handle = None
+        # Success moves to WRAPPED; failure to UNWRAPPED plus error reporting.
+        sslobj = self._sslobj
+        try:
+            if handshake_exc is None:
+                self._set_state(SSLProtocolState.WRAPPED)
+            else:
                 raise handshake_exc
 
             peercert = sslobj.getpeercert()
-        except (SystemExit, KeyboardInterrupt):
-            raise
-        except BaseException as exc:
+        except Exception as exc:
+            self._set_state(SSLProtocolState.UNWRAPPED)
             if isinstance(exc, ssl.CertificateError):
                 msg = 'SSL handshake failed on verifying the certificate'
             else:
                 msg = 'SSL handshake failed'
             self._fatal_error(exc, msg)
+            self._wakeup_waiter(exc)
             return
 
         if self._loop.get_debug():
@@ -649,85 +588,330 @@ def _on_handshake_complete(self, handshake_exc):
         self._extra.update(peercert=peercert,
                            cipher=sslobj.cipher(),
                            compression=sslobj.compression(),
-                           ssl_object=sslobj,
-                           )
-        if self._call_connection_made:
-            self._app_protocol.connection_made(self._app_transport)
+                           ssl_object=sslobj)
+        if self._app_state == AppProtocolState.STATE_INIT:
+            self._app_state = AppProtocolState.STATE_CON_MADE
+            self._app_protocol.connection_made(self._get_app_transport())
         self._wakeup_waiter()
-        self._session_established = True
-        # In case transport.write() was already called. Don't call
-        # immediately _process_write_backlog(), but schedule it:
-        # _on_handshake_complete() can be called indirectly from
-        # _process_write_backlog(), and _process_write_backlog() is not
-        # reentrant.
-        self._loop.call_soon(self._process_write_backlog)
+        self._do_read()
 
-    def _process_write_backlog(self):
-        # Try to make progress on the write backlog.
-        if self._transport is None or self._sslpipe is None:
+    # Shutdown flow
+    # Normal path: WRAPPED -> FLUSHING -> SHUTDOWN, then close the transport.
+    def _start_shutdown(self):
+        if (
+            self._state in (
+                SSLProtocolState.FLUSHING,
+                SSLProtocolState.SHUTDOWN,
+                SSLProtocolState.UNWRAPPED
+            )
+        ):
+            return
+        if self._app_transport is not None:
+            self._app_transport._closed = True
+        if self._state == SSLProtocolState.DO_HANDSHAKE:
+            self._abort()
+        else:
+            self._set_state(SSLProtocolState.FLUSHING)
+            self._shutdown_timeout_handle = self._loop.call_later(
+                self._ssl_shutdown_timeout,
+                lambda: self._check_shutdown_timeout()
+            )
+            self._do_flush()
+
+    def _check_shutdown_timeout(self):
+        if (
+            self._state in (
+                SSLProtocolState.FLUSHING,
+                SSLProtocolState.SHUTDOWN
+            )
+        ):
+            self._transport._force_close(
+                exceptions.TimeoutError('SSL shutdown timed out'))
+
+    def _do_flush(self):
+        self._do_read()
+        self._set_state(SSLProtocolState.SHUTDOWN)
+        self._do_shutdown()
+
+    def _do_shutdown(self):
+        try:
+            if not self._eof_received:
+                self._sslobj.unwrap()
+        except SSLAgainErrors:
+            self._process_outgoing()
+        except ssl.SSLError as exc:
+            self._on_shutdown_complete(exc)
+        else:
+            self._process_outgoing()
+            self._call_eof_received()
+            self._on_shutdown_complete(None)
+
+    def _on_shutdown_complete(self, shutdown_exc):
+        if self._shutdown_timeout_handle is not None:
+            self._shutdown_timeout_handle.cancel()
+            self._shutdown_timeout_handle = None
+
+        if shutdown_exc:
+            self._fatal_error(shutdown_exc)
+        else:
+            self._loop.call_soon(self._transport.close)
+
+    def _abort(self):
+        self._set_state(SSLProtocolState.UNWRAPPED)
+        if self._transport is not None:
+            self._transport.abort()
+
+    # Outgoing flow
+
+    def _write_appdata(self, list_of_data):
+        if (
+            self._state in (
+                SSLProtocolState.FLUSHING,
+                SSLProtocolState.SHUTDOWN,
+                SSLProtocolState.UNWRAPPED
+            )
+        ):
+            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+                logger.warning('SSL connection is closed')
+            self._conn_lost += 1
             return
 
+        for data in list_of_data:
+            self._write_backlog.append(data)
+            self._write_buffer_size += len(data)
+        # Encrypt immediately only when WRAPPED; else _do_read() drains later.
         try:
-            for i in range(len(self._write_backlog)):
-                data, offset = self._write_backlog[0]
-                if data:
-                    ssldata, offset = self._sslpipe.feed_appdata(data, offset)
-                elif offset:
-                    ssldata = self._sslpipe.do_handshake(
-                        self._on_handshake_complete)
-                    offset = 1
+            if self._state == SSLProtocolState.WRAPPED:
+                self._do_write()
+
+        except Exception as ex:
+            self._fatal_error(ex, 'Fatal error on SSL protocol')
+
+    def _do_write(self):
+        try:
+            while self._write_backlog:
+                data = self._write_backlog[0]
+                count = self._sslobj.write(data)
+                data_len = len(data)
+                if count < data_len:
+                    self._write_backlog[0] = data[count:]
+                    self._write_buffer_size -= count
                 else:
-                    ssldata = self._sslpipe.shutdown(self._finalize)
-                    offset = 1
+                    del self._write_backlog[0]
+                    self._write_buffer_size -= data_len
+        except SSLAgainErrors:
+            pass
+        self._process_outgoing()
 
-                for chunk in ssldata:
-                    self._transport.write(chunk)
+    def _process_outgoing(self):
+        if not self._ssl_writing_paused:
+            data = self._outgoing.read()
+            if len(data):
+                self._transport.write(data)
+        self._control_app_writing()
 
-                if offset < len(data):
-                    self._write_backlog[0] = (data, offset)
-                    # A short write means that a write is blocked on a read
-                    # We need to enable reading if it is paused!
-                    assert self._sslpipe.need_ssldata
-                    if self._transport._paused:
-                        self._transport.resume_reading()
+    # Incoming flow
+
+    def _do_read(self):
+        if (
+            self._state not in (
+                SSLProtocolState.WRAPPED,
+                SSLProtocolState.FLUSHING,
+            )
+        ):
+            return
+        try:
+            if not self._app_reading_paused:
+                if self._app_protocol_is_buffer:
+                    self._do_read__buffered()
+                else:
+                    self._do_read__copied()
+                if self._write_backlog:
+                    self._do_write()
+                else:
+                    self._process_outgoing()
+            self._control_ssl_reading()
+        except Exception as ex:
+            self._fatal_error(ex, 'Fatal error on SSL protocol')
+
+    def _do_read__buffered(self):
+        offset = 0
+        count = 1
+        # Decrypt straight into the buffer returned by the app's get_buffer().
+        buf = self._app_protocol_get_buffer(self._get_read_buffer_size())
+        wants = len(buf)
+        # Slice via memoryview below: slicing a bytearray would copy it.
+        try:
+            count = self._sslobj.read(wants, buf)
+            if count > 0:
+                offset = count
+                while offset < wants:
+                    count = self._sslobj.read(
+                        wants - offset, memoryview(buf)[offset:])
+                    if count > 0:
+                        offset += count
+                    else:
+                        break
+                else:
+                    self._loop.call_soon(lambda: self._do_read())
+        except SSLAgainErrors:
+            pass
+        if offset > 0:
+            self._app_protocol_buffer_updated(offset)
+        if not count:
+            # read() returned 0 bytes: the peer sent close_notify.
+            self._call_eof_received()
+            self._start_shutdown()
+
+    def _do_read__copied(self):
+        chunk = b'1'
+        zero = True
+        one = False
+        # zero: nothing read yet; one: exactly one chunk (skips b''.join()).
+        try:
+            while True:
+                chunk = self._sslobj.read(self.max_size)
+                if not chunk:
                     break
+                if zero:
+                    zero = False
+                    one = True
+                    first = chunk
+                elif one:
+                    one = False
+                    data = [first, chunk]
+                else:
+                    data.append(chunk)
+        except SSLAgainErrors:
+            pass
+        if one:
+            self._app_protocol.data_received(first)
+        elif not zero:
+            self._app_protocol.data_received(b''.join(data))
+        if not chunk:
+            # read() returned b'': the peer sent close_notify.
+            self._call_eof_received()
+            self._start_shutdown()
 
-                # An entire chunk from the backlog was processed. We can
-                # delete it and reduce the outstanding buffer size.
-                del self._write_backlog[0]
-                self._write_buffer_size -= len(data)
-        except (SystemExit, KeyboardInterrupt):
+    def _call_eof_received(self):
+        try:
+            if self._app_state == AppProtocolState.STATE_CON_MADE:
+                self._app_state = AppProtocolState.STATE_EOF
+                keep_open = self._app_protocol.eof_received()
+                if keep_open:
+                    logger.warning('returning true from eof_received() '
+                                   'has no effect when using ssl')
+        except (KeyboardInterrupt, SystemExit):
             raise
-        except BaseException as exc:
-            if self._in_handshake:
-                # Exceptions will be re-raised in _on_handshake_complete.
-                self._on_handshake_complete(exc)
-            else:
-                self._fatal_error(exc, 'Fatal error on SSL transport')
+        except BaseException as ex:
+            self._fatal_error(ex, 'Error calling eof_received()')
+
+    # Flow control for writes from APP socket
+
+    def _control_app_writing(self):
+        size = self._get_write_buffer_size()
+        if size >= self._outgoing_high_water and not self._app_writing_paused:
+            self._app_writing_paused = True
+            try:
+                self._app_protocol.pause_writing()
+            except (KeyboardInterrupt, SystemExit):
+                raise
+            except BaseException as exc:
+                self._loop.call_exception_handler({
+                    'message': 'protocol.pause_writing() failed',
+                    'exception': exc,
+                    'transport': self._app_transport,
+                    'protocol': self,
+                })
+        elif size <= self._outgoing_low_water and self._app_writing_paused:
+            self._app_writing_paused = False
+            try:
+                self._app_protocol.resume_writing()
+            except (KeyboardInterrupt, SystemExit):
+                raise
+            except BaseException as exc:
+                self._loop.call_exception_handler({
+                    'message': 'protocol.resume_writing() failed',
+                    'exception': exc,
+                    'transport': self._app_transport,
+                    'protocol': self,
+                })
+
+    def _get_write_buffer_size(self):
+        return self._outgoing.pending + self._write_buffer_size
+
+    def _set_write_buffer_limits(self, high=None, low=None):
+        high, low = add_flowcontrol_defaults(
+            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_WRITE)
+        self._outgoing_high_water = high
+        self._outgoing_low_water = low
+
+    # Flow control for reads to APP socket
+
+    def _pause_reading(self):
+        self._app_reading_paused = True
+
+    def _resume_reading(self):
+        if self._app_reading_paused:
+            self._app_reading_paused = False
+
+            def resume():
+                if self._state == SSLProtocolState.WRAPPED:
+                    self._do_read()
+                elif self._state == SSLProtocolState.FLUSHING:
+                    self._do_flush()
+                elif self._state == SSLProtocolState.SHUTDOWN:
+                    self._do_shutdown()
+            self._loop.call_soon(resume)
+
+    # Flow control for reads from SSL socket
+
+    def _control_ssl_reading(self):
+        size = self._get_read_buffer_size()
+        if size >= self._incoming_high_water and not self._ssl_reading_paused:
+            self._ssl_reading_paused = True
+            self._transport.pause_reading()
+        elif size <= self._incoming_low_water and self._ssl_reading_paused:
+            self._ssl_reading_paused = False
+            self._transport.resume_reading()
+
+    def _set_read_buffer_limits(self, high=None, low=None):
+        high, low = add_flowcontrol_defaults(
+            high, low, constants.FLOW_CONTROL_HIGH_WATER_SSL_READ)
+        self._incoming_high_water = high
+        self._incoming_low_water = low
+
+    def _get_read_buffer_size(self):
+        return self._incoming.pending
+
+    # Flow control for writes to SSL socket
+
+    def pause_writing(self):
+        """Called when the low-level transport's buffer goes over
+        the high-water mark.
+        """
+        assert not self._ssl_writing_paused
+        self._ssl_writing_paused = True
+
+    def resume_writing(self):
+        """Called when the low-level transport's buffer drains below
+        the low-water mark.
+        """
+        assert self._ssl_writing_paused
+        self._ssl_writing_paused = False
+        self._process_outgoing()
 
     def _fatal_error(self, exc, message='Fatal error on transport'):
+        if self._transport:
+            self._transport._force_close(exc)
+
         if isinstance(exc, OSError):
             if self._loop.get_debug():
                 logger.debug("%r: %s", self, message, exc_info=True)
-        else:
+        elif not isinstance(exc, exceptions.CancelledError):
             self._loop.call_exception_handler({
                 'message': message,
                 'exception': exc,
                 'transport': self._transport,
                 'protocol': self,
             })
-        if self._transport:
-            self._transport._force_close(exc)
-
-    def _finalize(self):
-        self._sslpipe = None
-
-        if self._transport is not None:
-            self._transport.close()
-
-    def _abort(self):
-        try:
-            if self._transport is not None:
-                self._transport.abort()
-        finally:
-            self._finalize()