bpo-29883: Add UDP support to the asyncio proactor event loop (GH-13440)



Follow-up to #1067.


https://bugs.python.org/issue29883
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 6a53b2e..9b8ae06 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -11,6 +11,7 @@
 import socket
 import warnings
 import signal
+import collections
 
 from . import base_events
 from . import constants
@@ -23,6 +24,33 @@
 from .log import logger


+def _set_socket_extra(transport, sock):
+    """Fill in the socket-related extra info of a proactor *transport*.
+
+    Stores 'socket' (wrapped in trsock.TransportSocket), 'sockname' when
+    getsockname() succeeds, and 'peername' unless already present in
+    _extra.  A failed getpeername() records None, since an unconnected
+    UDP socket has no peer.  AttributeError is tolerated alongside
+    socket.error so that socket-like objects lacking these methods do not
+    abort transport construction.
+    """
+    transport._extra['socket'] = trsock.TransportSocket(sock)
+
+    try:
+        transport._extra['sockname'] = sock.getsockname()
+    except (socket.error, AttributeError):
+        if transport._loop.get_debug():
+            logger.warning(
+                "getsockname() failed on %r", sock, exc_info=True)
+
+    if 'peername' not in transport._extra:
+        try:
+            transport._extra['peername'] = sock.getpeername()
+        except (socket.error, AttributeError):
+            # UDP sockets may not have a peer name
+            transport._extra['peername'] = None
+
+
 class _ProactorBasePipeTransport(transports._FlowControlMixin,
                                  transports.BaseTransport):
     """Base class for pipe and socket transports."""
@@ -430,6 +449,187 @@
             self.close()


+class _ProactorDatagramTransport(_ProactorBasePipeTransport):
+    """Datagram (UDP) transport driven by the event loop's proactor.
+
+    Receives are chained: each completed recv()/recvfrom() future queues
+    the next one.  Outgoing datagrams are buffered as (data, addr) pairs
+    and handed to the proactor one at a time.  When *address* is given
+    the socket is treated as connected and send()/recv() are used instead
+    of sendto()/recvfrom().
+    """
+
+    max_size = 256 * 1024
+
+    def __init__(self, loop, sock, protocol, address=None,
+                 waiter=None, extra=None):
+        self._address = address
+        self._empty_waiter = None
+        # We don't need to call _protocol.connection_made() since our base
+        # constructor does it for us.
+        super().__init__(loop, sock, protocol, waiter=waiter, extra=extra)
+
+        # The base constructor sets _buffer = None, so we set it here
+        self._buffer = collections.deque()
+        self._loop.call_soon(self._loop_reading)
+
+    def _set_extra(self, sock):
+        _set_socket_extra(self, sock)
+
+    def get_write_buffer_size(self):
+        """Return the total payload size of the queued datagrams."""
+        if self._buffer is None:
+            # A forced close (abort()) may have discarded the buffer.
+            return 0
+        return sum(len(data) for data, _ in self._buffer)
+
+    def abort(self):
+        self._force_close(None)
+
+    def sendto(self, data, addr=None):
+        """Queue *data* for *addr* and start a send if none is pending.
+
+        Empty payloads are ignored.  On a connected transport *addr* must
+        be None or the connected address.  Once the transport is closing or
+        closed the datagram is dropped; a warning is logged once the
+        _conn_lost counter reaches LOG_THRESHOLD_FOR_CONNLOST_WRITES.
+        """
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError(
+                f'data argument must be bytes-like object ({type(data)!r})')
+
+        if not data:
+            return
+
+        if self._address is not None and addr not in (None, self._address):
+            raise ValueError(
+                f'Invalid address: must be None or {self._address}')
+
+        if self._conn_lost:
+            # Closing or closed, connected or not: refuse new datagrams.
+            # Queuing one could start a send after connection_lost() was
+            # scheduled, and a forced close may have set _buffer to None.
+            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+                logger.warning('socket.sendto() raised exception.')
+            self._conn_lost += 1
+            return
+
+        # Ensure that what we buffer is immutable.
+        self._buffer.append((bytes(data), addr))
+
+        if self._write_fut is None:
+            # No current write operations are active, kick one off
+            self._loop_writing()
+        # else: A write operation is already kicked off
+
+        self._maybe_pause_protocol()
+
+    def _loop_writing(self, fut=None):
+        """Send the next buffered datagram.
+
+        Called by sendto() when no send is pending and then as the done
+        callback of each send future, so at most one send is outstanding.
+        Once the buffer drains on a closing transport, connection_lost()
+        is scheduled from here.
+        """
+        try:
+            if fut is not None and self._write_fut is None and self._closing:
+                # The pending send was cancelled and _write_fut cleared,
+                # presumably by _force_close(): nothing to resume.  Bailing
+                # out on any _conn_lost instead would also drop the
+                # completion of a send pending when close() was called.
+                return
+
+            assert fut is self._write_fut
+            self._write_fut = None
+            if fut:
+                # We are in a _loop_writing() done callback, get the result
+                fut.result()
+
+            if not self._buffer or (self._conn_lost and self._address):
+                # Nothing left to send, or a connected transport is closing
+                # and drops the rest: complete a pending close().
+                if self._closing:
+                    self._loop.call_soon(self._call_connection_lost, None)
+                return
+
+            data, addr = self._buffer.popleft()
+            if self._address is not None:
+                self._write_fut = self._loop._proactor.send(self._sock,
+                                                            data)
+            else:
+                self._write_fut = self._loop._proactor.sendto(self._sock,
+                                                              data,
+                                                              addr=addr)
+        except OSError as exc:
+            if self._closing:
+                # The loop stops here, so finish a pending close() now
+                # instead of leaving connection_lost() unscheduled.
+                self._loop.call_soon(self._call_connection_lost, None)
+            self._protocol.error_received(exc)
+        except Exception as exc:
+            self._fatal_error(exc, 'Fatal write error on datagram transport')
+        else:
+            self._write_fut.add_done_callback(self._loop_writing)
+            self._maybe_resume_protocol()
+
+    def _loop_reading(self, fut=None):
+        """Handle a completed receive and queue the next one.
+
+        Called once from __init__() and then as the done callback of each
+        receive future.  A received datagram is handed to
+        datagram_received() only after the next receive has been queued;
+        empty payloads are not delivered.
+        """
+        data = None
+        try:
+            if self._conn_lost:
+                return
+
+            assert self._read_fut is fut or (self._read_fut is None and
+                                             self._closing)
+
+            self._read_fut = None
+            if fut is not None:
+                try:
+                    res = fut.result()
+                except OSError as exc:
+                    # Report a failed receive but keep the loop alive:
+                    # nothing else would ever queue another receive.
+                    self._protocol.error_received(exc)
+                else:
+                    if self._closing:
+                        # close() has been called: ignore any read data
+                        return
+
+                    if self._address is not None:
+                        data, addr = res, self._address
+                    else:
+                        data, addr = res
+
+            if self._conn_lost:
+                # error_received() may have closed the transport.
+                return
+            if self._address is not None:
+                self._read_fut = self._loop._proactor.recv(self._sock,
+                                                           self.max_size)
+            else:
+                self._read_fut = self._loop._proactor.recvfrom(self._sock,
+                                                               self.max_size)
+        except OSError as exc:
+            # Queuing the next receive failed; the read loop stops here.
+            self._protocol.error_received(exc)
+        except exceptions.CancelledError:
+            if not self._closing:
+                raise
+        else:
+            if self._read_fut is not None:
+                self._read_fut.add_done_callback(self._loop_reading)
+        finally:
+            if data:
+                self._protocol.datagram_received(data, addr)
+
+
 class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport,
                                    _ProactorBaseWritePipeTransport,
                                    transports.Transport):
@@ -455,22 +602,7 @@
         base_events._set_nodelay(sock)
 
         def _set_extra(self, sock):
-        self._extra['socket'] = trsock.TransportSocket(sock)
-
-        try:
-            self._extra['sockname'] = sock.getsockname()
-        except (socket.error, AttributeError):
-            if self._loop.get_debug():
-                logger.warning(
-                    "getsockname() failed on %r", sock, exc_info=True)
-
-        if 'peername' not in self._extra:
-            try:
-                self._extra['peername'] = sock.getpeername()
-            except (socket.error, AttributeError):
-                if self._loop.get_debug():
-                    logger.warning("getpeername() failed on %r",
-                                   sock, exc_info=True)
+        _set_socket_extra(self, sock)  # shared with the datagram transport
 
     def can_write_eof(self):
         return True
@@ -515,6 +647,12 @@
                                  extra=extra, server=server)
         return ssl_protocol._app_transport

+    def _make_datagram_transport(self, sock, protocol,
+                                 address=None, waiter=None, extra=None):
+        """Return a _ProactorDatagramTransport wrapping *sock*."""
+        return _ProactorDatagramTransport(
+            self, sock, protocol, address=address, waiter=waiter, extra=extra)
+
     def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
                                     extra=None):
         return _ProactorDuplexPipeTransport(self,
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 61b40ba..ac51109 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -483,6 +483,58 @@

         return self._register(ov, conn, finish_recv)

+    def recvfrom(self, conn, nbytes, flags=0):
+        """Receive one datagram of at most *nbytes* bytes on *conn*.
+
+        Returns a future whose result is a (data, address) pair.  If
+        starting the receive raises BrokenPipeError, (b'', None) is wrapped
+        with self._result() and returned instead.  A completion failing with
+        ERROR_NETNAME_DELETED or ERROR_OPERATION_ABORTED is re-raised as
+        ConnectionResetError.
+        """
+        self._register_with_iocp(conn)
+        overlapped = _overlapped.Overlapped(NULL)
+        try:
+            overlapped.WSARecvFrom(conn.fileno(), nbytes, flags)
+        except BrokenPipeError:
+            return self._result((b'', None))
+
+        def finish_recvfrom(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                reset_codes = (_overlapped.ERROR_NETNAME_DELETED,
+                               _overlapped.ERROR_OPERATION_ABORTED)
+                if exc.winerror not in reset_codes:
+                    raise
+                raise ConnectionResetError(*exc.args)
+
+        return self._register(overlapped, conn, finish_recvfrom)
+
+    def sendto(self, conn, buf, flags=0, addr=None):
+        """Send *buf* from *conn* to *addr* as a single datagram.
+
+        Returns the future registered for the overlapped WSASendTo call.
+        A completion failing with ERROR_NETNAME_DELETED or
+        ERROR_OPERATION_ABORTED is re-raised as ConnectionResetError; other
+        OSErrors, including ones raised while starting the send, propagate.
+        """
+        self._register_with_iocp(conn)
+        overlapped = _overlapped.Overlapped(NULL)
+        overlapped.WSASendTo(conn.fileno(), buf, flags, addr)
+
+        def finish_sendto(trans, key, ov):
+            try:
+                return ov.getresult()
+            except OSError as exc:
+                reset_codes = (_overlapped.ERROR_NETNAME_DELETED,
+                               _overlapped.ERROR_OPERATION_ABORTED)
+                if exc.winerror not in reset_codes:
+                    raise
+                raise ConnectionResetError(*exc.args)
+
+        return self._register(overlapped, conn, finish_sendto)
+
     def send(self, conn, buf, flags=0):
         self._register_with_iocp(conn)
         ov = _overlapped.Overlapped(NULL)
@@ -532,6 +570,14 @@
         return future
 
     def connect(self, conn, address):
+        if conn.type == socket.SOCK_DGRAM:
+            # WSAConnect will complete immediately for UDP sockets so we don't
+            # need to register any IOCP operation
+            _overlapped.WSAConnect(conn.fileno(), address)
+            fut = self._loop.create_future()
+            fut.set_result(None)
+            return fut
+
         self._register_with_iocp(conn)
         # The socket needs to be locally bound before we call ConnectEx().
         try: