bpo-29883: Add UDP support to the asyncio proactor event loop (GH-13440)



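Follow-up for GH-1067: implement a datagram transport for the proactor event loop so that `create_datagram_endpoint()` works with the Windows `ProactorEventLoop`.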


https://bugs.python.org/issue29883
diff --git a/Lib/asyncio/proactor_events.py b/Lib/asyncio/proactor_events.py
index 6a53b2e..9b8ae06 100644
--- a/Lib/asyncio/proactor_events.py
+++ b/Lib/asyncio/proactor_events.py
@@ -11,6 +11,7 @@
 import socket
 import warnings
 import signal
+import collections
 
 from . import base_events
 from . import constants
@@ -23,6 +24,31 @@
 from .log import logger
 
 
+def _set_socket_extra(transport, sock):
+    """Fill in the standard extra info of *transport* from *sock*.
+
+    Sets 'socket' (a trsock.TransportSocket wrapper), 'sockname' and, when
+    not already present, 'peername'.  Socket-like objects lacking
+    getsockname()/getpeername() are tolerated, and a socket without a peer
+    (such as an unconnected UDP socket) gets a peername of None.
+    """
+    transport._extra['socket'] = trsock.TransportSocket(sock)
+
+    try:
+        transport._extra['sockname'] = sock.getsockname()
+    except (socket.error, AttributeError):
+        if transport._loop.get_debug():
+            logger.warning(
+                "getsockname() failed on %r", sock, exc_info=True)
+
+    if 'peername' not in transport._extra:
+        try:
+            transport._extra['peername'] = sock.getpeername()
+        except (socket.error, AttributeError):
+            # UDP sockets may not have a peer name
+            transport._extra['peername'] = None
+
+
 class _ProactorBasePipeTransport(transports._FlowControlMixin,
                                  transports.BaseTransport):
     """Base class for pipe and socket transports."""
@@ -430,6 +449,178 @@
             self.close()
 
 
+class _ProactorDatagramTransport(_ProactorBasePipeTransport):
+    """Datagram transport on top of the proactor's overlapped socket I/O.
+
+    If *address* is given the transport is treated as connected: data is
+    written with proactor.send() and read with proactor.recv(), and
+    sendto() only accepts that address (or None).  Otherwise
+    proactor.sendto() and proactor.recvfrom() are used.
+    """
+
+    # Maximum number of bytes requested by each receive operation.
+    max_size = 256 * 1024
+
+    def __init__(self, loop, sock, protocol, address=None,
+                 waiter=None, extra=None):
+        self._address = address
+        self._empty_waiter = None
+        # We don't need to call _protocol.connection_made() since our base
+        # constructor does it for us.
+        super().__init__(loop, sock, protocol, waiter=waiter, extra=extra)
+
+        # The base constructor sets _buffer = None, so we set it here
+        self._buffer = collections.deque()
+        self._loop.call_soon(self._loop_reading)
+
+    def _set_extra(self, sock):
+        _set_socket_extra(self, sock)
+
+    def get_write_buffer_size(self):
+        # The base class uses None for "no buffer"; count that as empty.
+        if self._buffer is None:
+            return 0
+        return sum(len(data) for data, _ in self._buffer)
+
+    def abort(self):
+        self._force_close(None)
+
+    def sendto(self, data, addr=None):
+        """Queue *data* to be sent to *addr* (or to the connected address).
+
+        Raises TypeError if *data* is not bytes-like and ValueError if a
+        connected transport is given a different address.  Empty data is
+        ignored.  Once the transport is closing or closed the datagram is
+        dropped, and a warning is logged after repeated attempts.
+        """
+        if not isinstance(data, (bytes, bytearray, memoryview)):
+            raise TypeError('data argument must be a bytes-like object, '
+                            f'not {type(data).__name__!r}')
+
+        if not data:
+            return
+
+        if self._address is not None and addr not in (None, self._address):
+            raise ValueError(
+                f'Invalid address: must be None or {self._address}')
+
+        if self._conn_lost:
+            # Never queue data on a closing/closed transport, connected or
+            # not: nothing would ever send it.
+            if self._conn_lost >= constants.LOG_THRESHOLD_FOR_CONNLOST_WRITES:
+                logger.warning('socket.sendto() raised exception.')
+            self._conn_lost += 1
+            return
+
+        # Ensure that what we buffer is immutable.
+        self._buffer.append((bytes(data), addr))
+
+        if self._write_fut is None:
+            # No current write operations are active, kick one off
+            self._loop_writing()
+        # else: A write operation is already kicked off
+
+        self._maybe_pause_protocol()
+
+    def _loop_writing(self, fut=None):
+        """Send the next queued datagram, keeping one write in flight.
+
+        Called by sendto() to start writing and as the done callback of
+        each write future.  A closing transport keeps draining its queue
+        and schedules _call_connection_lost() once the queue is empty.
+        """
+        try:
+            if fut is not None and self._write_fut is None and self._closing:
+                # The pending write was cancelled and cleared (presumably by
+                # _force_close()); there is nothing left to flush.
+                return
+
+            assert fut is self._write_fut
+            self._write_fut = None
+            if fut:
+                # We are in a _loop_writing() done callback, get the result
+                fut.result()
+
+            if not self._buffer:
+                # Everything has been written.
+                if self._closing:
+                    self._loop.call_soon(self._call_connection_lost, None)
+                return
+
+            data, addr = self._buffer.popleft()
+            if self._address is not None:
+                self._write_fut = self._loop._proactor.send(self._sock,
+                                                            data)
+            else:
+                self._write_fut = self._loop._proactor.sendto(self._sock,
+                                                              data,
+                                                              addr=addr)
+        except OSError as exc:
+            # NOTE(review): nothing re-arms the write loop after an OSError;
+            # queued datagrams wait for a later sendto(), and a close()
+            # that is waiting for the queue to drain does not complete.
+            self._protocol.error_received(exc)
+        except Exception as exc:
+            self._fatal_error(exc, 'Fatal write error on datagram transport')
+        else:
+            self._write_fut.add_done_callback(self._loop_writing)
+            self._maybe_resume_protocol()
+
+    def _loop_reading(self, fut=None):
+        """Deliver a completed receive and start the next one.
+
+        Scheduled once by __init__() and then used as the done callback of
+        each receive future.  A receive that completes with an OSError is
+        reported via protocol.error_received() and reading continues; an
+        OSError raised while starting a receive is reported and stops it.
+        """
+        data = None
+        try:
+            if self._conn_lost:
+                return
+
+            assert self._read_fut is fut or (self._read_fut is None and
+                                             self._closing)
+
+            self._read_fut = None
+            if fut is not None:
+                try:
+                    res = fut.result()
+                except OSError as exc:
+                    # Report the failed receive but keep reading: without a
+                    # new receive the transport would never receive again.
+                    self._protocol.error_received(exc)
+                else:
+                    if self._closing:
+                        # since close() has been called we ignore any read data
+                        return
+
+                    if self._address is not None:
+                        data, addr = res, self._address
+                    else:
+                        data, addr = res
+
+            if self._conn_lost:
+                return
+            if self._address is not None:
+                self._read_fut = self._loop._proactor.recv(self._sock,
+                                                           self.max_size)
+            else:
+                self._read_fut = self._loop._proactor.recvfrom(self._sock,
+                                                               self.max_size)
+        except OSError as exc:
+            self._protocol.error_received(exc)
+        except exceptions.CancelledError:
+            if not self._closing:
+                raise
+        else:
+            if self._read_fut is not None:
+                self._read_fut.add_done_callback(self._loop_reading)
+        finally:
+            if data:
+                self._protocol.datagram_received(data, addr)
+
+
 class _ProactorDuplexPipeTransport(_ProactorReadPipeTransport,
                                    _ProactorBaseWritePipeTransport,
                                    transports.Transport):
@@ -455,22 +602,7 @@
         base_events._set_nodelay(sock)
 
     def _set_extra(self, sock):
-        self._extra['socket'] = trsock.TransportSocket(sock)
-
-        try:
-            self._extra['sockname'] = sock.getsockname()
-        except (socket.error, AttributeError):
-            if self._loop.get_debug():
-                logger.warning(
-                    "getsockname() failed on %r", sock, exc_info=True)
-
-        if 'peername' not in self._extra:
-            try:
-                self._extra['peername'] = sock.getpeername()
-            except (socket.error, AttributeError):
-                if self._loop.get_debug():
-                    logger.warning("getpeername() failed on %r",
-                                   sock, exc_info=True)
+        _set_socket_extra(self, sock)  # same helper as the datagram transport
 
     def can_write_eof(self):
         return True
@@ -515,6 +647,11 @@
                                  extra=extra, server=server)
         return ssl_protocol._app_transport
 
+    def _make_datagram_transport(self, sock, protocol,
+                                 address=None, waiter=None, extra=None):
+        return _ProactorDatagramTransport(
+            self, sock, protocol, address=address, waiter=waiter, extra=extra)
+
     def _make_duplex_pipe_transport(self, sock, protocol, waiter=None,
                                     extra=None):
         return _ProactorDuplexPipeTransport(self,
diff --git a/Lib/asyncio/windows_events.py b/Lib/asyncio/windows_events.py
index 61b40ba..ac51109 100644
--- a/Lib/asyncio/windows_events.py
+++ b/Lib/asyncio/windows_events.py
@@ -483,6 +483,41 @@
 
         return self._register(ov, conn, finish_recv)
 
+    @staticmethod
+    def _finish_datagram_op(trans, key, ov):
+        # Completion callback shared by recvfrom() and sendto(): return the
+        # overlapped result, reporting ERROR_NETNAME_DELETED and
+        # ERROR_OPERATION_ABORTED as ConnectionResetError.
+        try:
+            return ov.getresult()
+        except OSError as exc:
+            if exc.winerror not in (_overlapped.ERROR_NETNAME_DELETED,
+                                    _overlapped.ERROR_OPERATION_ABORTED):
+                raise
+            raise ConnectionResetError(*exc.args)
+
+    def recvfrom(self, conn, nbytes, flags=0):
+        """Start an overlapped WSARecvFrom() of up to *nbytes* on *conn*.
+
+        Returns the future registered for the operation; if issuing it
+        raises BrokenPipeError, self._result((b'', None)) is returned
+        instead.
+        """
+        self._register_with_iocp(conn)
+        overlapped = _overlapped.Overlapped(NULL)
+        try:
+            overlapped.WSARecvFrom(conn.fileno(), nbytes, flags)
+        except BrokenPipeError:
+            return self._result((b'', None))
+        return self._register(overlapped, conn, self._finish_datagram_op)
+
+    def sendto(self, conn, buf, flags=0, addr=None):
+        """Start an overlapped WSASendTo() of *buf* to *addr* on *conn*."""
+        self._register_with_iocp(conn)
+        overlapped = _overlapped.Overlapped(NULL)
+        overlapped.WSASendTo(conn.fileno(), buf, flags, addr)
+        return self._register(overlapped, conn, self._finish_datagram_op)
+
     def send(self, conn, buf, flags=0):
         self._register_with_iocp(conn)
         ov = _overlapped.Overlapped(NULL)
@@ -532,6 +570,14 @@
         return future
 
     def connect(self, conn, address):
+        if conn.type == socket.SOCK_DGRAM:
+            # WSAConnect() completes immediately for UDP sockets, so no IOCP
+            # operation is registered; return an already-completed future.
+            _overlapped.WSAConnect(conn.fileno(), address)
+            fut = self._loop.create_future()
+            fut.set_result(None)
+            return fut
+
         self._register_with_iocp(conn)
         # The socket needs to be locally bound before we call ConnectEx().
         try:
diff --git a/Lib/test/test_asyncio/test_events.py b/Lib/test/test_asyncio/test_events.py
index e89db99..045654e 100644
--- a/Lib/test/test_asyncio/test_events.py
+++ b/Lib/test/test_asyncio/test_events.py
@@ -1249,11 +1249,6 @@
         server.transport.close()
 
     def test_create_datagram_endpoint_sock(self):
-        if (sys.platform == 'win32' and
-                isinstance(self.loop, proactor_events.BaseProactorEventLoop)):
-            raise unittest.SkipTest(
-                'UDP is not supported with proactor event loops')
-
         sock = None
         local_address = ('127.0.0.1', 0)
         infos = self.loop.run_until_complete(
@@ -2004,10 +1999,6 @@
         def test_writer_callback_cancel(self):
             raise unittest.SkipTest("IocpEventLoop does not have add_writer()")
 
-        def test_create_datagram_endpoint(self):
-            raise unittest.SkipTest(
-                "IocpEventLoop does not have create_datagram_endpoint()")
-
         def test_remove_fds_after_closing(self):
             raise unittest.SkipTest("IocpEventLoop does not have add_reader()")
 else:
diff --git a/Lib/test/test_asyncio/test_proactor_events.py b/Lib/test/test_asyncio/test_proactor_events.py
index 5952ccc..2e9995d 100644
--- a/Lib/test/test_asyncio/test_proactor_events.py
+++ b/Lib/test/test_asyncio/test_proactor_events.py
@@ -4,6 +4,7 @@
 import socket
 import unittest
 import sys
+from collections import deque
 from unittest import mock
 
 import asyncio
@@ -12,6 +13,7 @@
 from asyncio.proactor_events import _ProactorSocketTransport
 from asyncio.proactor_events import _ProactorWritePipeTransport
 from asyncio.proactor_events import _ProactorDuplexPipeTransport
+from asyncio.proactor_events import _ProactorDatagramTransport
 from test import support
 from test.test_asyncio import utils as test_utils
 
@@ -725,6 +727,212 @@
         self.assertFalse(tr.is_reading())
 
 
+class ProactorDatagramTransportTests(test_utils.TestCase):
+    """Unit tests for _ProactorDatagramTransport with a mocked proactor."""
+
+    def setUp(self):
+        super().setUp()
+        self.loop = self.new_test_loop()
+        self.proactor = mock.Mock()
+        self.loop._proactor = self.proactor
+        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+        self.sock = mock.Mock(spec_set=socket.socket)
+        self.sock.fileno.return_value = 7
+
+    def datagram_transport(self, address=None):
+        self.sock.getpeername.side_effect = None if address else OSError
+        transport = _ProactorDatagramTransport(self.loop, self.sock,
+                                               self.protocol,
+                                               address=address)
+        self.addCleanup(close_transport, transport)
+        return transport
+
+    def test_sendto(self):
+        data = b'data'
+        transport = self.datagram_transport()
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.proactor.sendto.called)
+        self.proactor.sendto.assert_called_with(
+            self.sock, data, addr=('0.0.0.0', 1234))
+
+    def test_sendto_bytearray(self):
+        data = bytearray(b'data')
+        transport = self.datagram_transport()
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.proactor.sendto.called)
+        self.proactor.sendto.assert_called_with(
+            self.sock, b'data', addr=('0.0.0.0', 1234))
+
+    def test_sendto_memoryview(self):
+        data = memoryview(b'data')
+        transport = self.datagram_transport()
+        transport.sendto(data, ('0.0.0.0', 1234))
+        self.assertTrue(self.proactor.sendto.called)
+        self.proactor.sendto.assert_called_with(
+            self.sock, b'data', addr=('0.0.0.0', 1234))
+
+    def test_sendto_no_data(self):
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data', ('0.0.0.0', 12345)))
+        transport.sendto(b'', ())
+        # Writes go through the proactor, not through the socket itself.
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data', ('0.0.0.0', 12345))], list(transport._buffer))
+
+    def test_sendto_buffer(self):
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport._write_fut = object()
+        transport.sendto(b'data2', ('0.0.0.0', 12345))
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+
+    def test_sendto_buffer_bytearray(self):
+        data2 = bytearray(b'data2')
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport._write_fut = object()
+        transport.sendto(data2, ('0.0.0.0', 12345))
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+        self.assertIsInstance(transport._buffer[1][0], bytes)
+
+    def test_sendto_buffer_memoryview(self):
+        data2 = memoryview(b'data2')
+        transport = self.datagram_transport()
+        transport._buffer.append((b'data1', ('0.0.0.0', 12345)))
+        transport._write_fut = object()
+        transport.sendto(data2, ('0.0.0.0', 12345))
+        self.assertFalse(self.proactor.sendto.called)
+        self.assertEqual(
+            [(b'data1', ('0.0.0.0', 12345)),
+             (b'data2', ('0.0.0.0', 12345))],
+            list(transport._buffer))
+        self.assertIsInstance(transport._buffer[1][0], bytes)
+
+    @mock.patch('asyncio.proactor_events.logger')
+    def test_sendto_exception(self, m_log):
+        data = b'data'
+        err = self.proactor.sendto.side_effect = RuntimeError()
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport.sendto(data, ())
+
+        self.assertTrue(transport._fatal_error.called)
+        transport._fatal_error.assert_called_with(
+                                   err,
+                                   'Fatal write error on datagram transport')
+        transport._conn_lost = 1
+
+        transport._address = ('123',)
+        transport.sendto(data)
+        transport.sendto(data)
+        transport.sendto(data)
+        transport.sendto(data)
+        transport.sendto(data)
+        m_log.warning.assert_called_with('socket.sendto() raised exception.')
+
+    def test_sendto_error_received(self):
+        data = b'data'
+
+        self.proactor.sendto.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport.sendto(data, ())
+
+        self.assertEqual(transport._conn_lost, 0)
+        self.assertFalse(transport._fatal_error.called)
+        self.assertTrue(self.protocol.error_received.called)
+
+    def test_sendto_error_received_connected(self):
+        data = b'data'
+
+        self.proactor.send.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        transport._fatal_error = mock.Mock()
+        transport.sendto(data)
+
+        self.assertFalse(transport._fatal_error.called)
+        self.assertTrue(self.protocol.error_received.called)
+
+    def test_sendto_str(self):
+        transport = self.datagram_transport()
+        self.assertRaises(TypeError, transport.sendto, 'str', ())
+
+    def test_sendto_connected_addr(self):
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        self.assertRaises(
+            ValueError, transport.sendto, b'str', ('0.0.0.0', 2))
+
+    def test_sendto_closing(self):
+        transport = self.datagram_transport(address=(1,))
+        transport.close()
+        self.assertEqual(transport._conn_lost, 1)
+        transport.sendto(b'data', (1,))
+        self.assertEqual(transport._conn_lost, 2)
+
+    def test__loop_writing_closing(self):
+        transport = self.datagram_transport()
+        transport._closing = True
+        transport._loop_writing()
+        self.assertIsNone(transport._write_fut)
+        test_utils.run_briefly(self.loop)
+        self.sock.close.assert_called_with()
+        self.protocol.connection_lost.assert_called_with(None)
+
+    def test__loop_writing_exception(self):
+        err = self.proactor.sendto.side_effect = RuntimeError()
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append((b'data', ()))
+        transport._loop_writing()
+
+        transport._fatal_error.assert_called_with(
+                                   err,
+                                   'Fatal write error on datagram transport')
+
+    def test__loop_writing_error_received(self):
+        self.proactor.sendto.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport()
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append((b'data', ()))
+        transport._loop_writing()
+
+        self.assertFalse(transport._fatal_error.called)
+        self.assertTrue(self.protocol.error_received.called)
+
+    def test__loop_writing_error_received_connection(self):
+        self.proactor.send.side_effect = ConnectionRefusedError
+
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        transport._fatal_error = mock.Mock()
+        transport._buffer.append((b'data', ()))
+        transport._loop_writing()
+
+        self.assertFalse(transport._fatal_error.called)
+        self.assertTrue(self.protocol.error_received.called)
+
+    @mock.patch('asyncio.base_events.logger.error')
+    def test_fatal_error_connected(self, m_exc):
+        transport = self.datagram_transport(address=('0.0.0.0', 1))
+        err = ConnectionRefusedError()
+        transport._fatal_error(err)
+        self.assertFalse(self.protocol.error_received.called)
+        m_exc.assert_not_called()
+
+
 class BaseProactorEventLoopTests(test_utils.TestCase):
 
     def setUp(self):
@@ -864,6 +1068,87 @@
         self.assertFalse(sock2.close.called)
         self.assertFalse(future2.cancel.called)
 
+    def datagram_transport(self):
+        self.protocol = test_utils.make_test_protocol(asyncio.DatagramProtocol)
+        return self.loop._make_datagram_transport(self.sock, self.protocol)
+
+    def test_make_datagram_transport(self):
+        tr = self.datagram_transport()
+        self.assertIsInstance(tr, _ProactorDatagramTransport)
+        close_transport(tr)
+
+    def test_datagram_loop_writing(self):
+        tr = self.datagram_transport()
+        tr._buffer.appendleft((b'data', ('127.0.0.1', 12068)))
+        tr._loop_writing()
+        self.loop._proactor.sendto.assert_called_with(
+            self.sock, b'data', addr=('127.0.0.1', 12068))
+        self.loop._proactor.sendto.return_value.add_done_callback.\
+            assert_called_with(tr._loop_writing)
+
+        close_transport(tr)
+
+    def test_datagram_loop_reading(self):
+        tr = self.datagram_transport()
+        tr._loop_reading()
+        self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
+        self.assertFalse(self.protocol.datagram_received.called)
+        self.assertFalse(self.protocol.error_received.called)
+        close_transport(tr)
+
+    def test_datagram_loop_reading_data(self):
+        res = asyncio.Future(loop=self.loop)
+        res.set_result((b'data', ('127.0.0.1', 12068)))
+
+        tr = self.datagram_transport()
+        tr._read_fut = res
+        tr._loop_reading(res)
+        self.loop._proactor.recvfrom.assert_called_with(self.sock, 256 * 1024)
+        self.protocol.datagram_received.assert_called_with(
+            b'data', ('127.0.0.1', 12068))
+        close_transport(tr)
+
+    def test_datagram_loop_reading_no_data(self):
+        res = asyncio.Future(loop=self.loop)
+        res.set_result((b'', ('127.0.0.1', 12068)))
+
+        tr = self.datagram_transport()
+        if __debug__:
+            # _loop_reading() checks its future with an assert statement,
+            # which is stripped when Python runs with -O.
+            self.assertRaises(AssertionError, tr._loop_reading, res)
+
+        tr.close = mock.Mock()
+        tr._read_fut = res
+        tr._loop_reading(res)
+        self.assertTrue(self.loop._proactor.recvfrom.called)
+        self.assertFalse(self.protocol.error_received.called)
+        self.assertFalse(tr.close.called)
+        close_transport(tr)
+
+    def test_datagram_loop_reading_aborted(self):
+        err = ConnectionAbortedError()
+        self.loop._proactor.recvfrom.side_effect = err
+
+        tr = self.datagram_transport()
+        tr._fatal_error = mock.Mock()
+        tr._protocol.error_received = mock.Mock()
+        tr._loop_reading()
+        tr._protocol.error_received.assert_called_with(err)
+        close_transport(tr)
+
+    def test_datagram_loop_writing_aborted(self):
+        err = ConnectionAbortedError()
+        self.loop._proactor.sendto.side_effect = err
+
+        tr = self.datagram_transport()
+        tr._fatal_error = mock.Mock()
+        tr._protocol.error_received = mock.Mock()
+        tr._buffer.appendleft((b'Hello', ('127.0.0.1', 12068)))
+        tr._loop_writing()
+        tr._protocol.error_received.assert_called_with(err)
+        close_transport(tr)
+
 
 @unittest.skipIf(sys.platform != 'win32',
                  'Proactor is supported on Windows only')