bpo-33530: Implement Happy Eyeballs in asyncio, v2 (GH-7237)



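Added two keyword-only arguments, `happy_eyeballs_delay` and `interleave`,
to `BaseEventLoop.create_connection`. Happy Eyeballs is activated when
`happy_eyeballs_delay` is specified; in that case `interleave` defaults
to 1 unless it is given explicitly.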

The new arguments are now documented. `staggered_race()` lives in its own module, `asyncio.staggered`, but is not exported from the top-level `asyncio` package.


https://bugs.python.org/issue33530
diff --git a/Lib/asyncio/base_events.py b/Lib/asyncio/base_events.py
index 9b4b846..c58906f 100644
--- a/Lib/asyncio/base_events.py
+++ b/Lib/asyncio/base_events.py
@@ -16,6 +16,7 @@
 import collections
 import collections.abc
 import concurrent.futures
+import functools
 import heapq
 import itertools
 import os
@@ -41,6 +42,7 @@
 from . import futures
 from . import protocols
 from . import sslproto
+from . import staggered
 from . import tasks
 from . import transports
 from .log import logger
@@ -159,6 +161,28 @@
     return None
 
 
+def _interleave_addrinfos(addrinfos, first_address_family_count=1):
+    """Interleave list of addrinfo tuples by family."""
+    # Group addresses by family
+    addrinfos_by_family = collections.OrderedDict()
+    for addr in addrinfos:
+        family = addr[0]
+        if family not in addrinfos_by_family:
+            addrinfos_by_family[family] = []
+        addrinfos_by_family[family].append(addr)
+    addrinfos_lists = list(addrinfos_by_family.values())
+
+    reordered = []
+    if first_address_family_count > 1:
+        reordered.extend(addrinfos_lists[0][:first_address_family_count - 1])
+        del addrinfos_lists[0][:first_address_family_count - 1]
+    reordered.extend(
+        a for a in itertools.chain.from_iterable(
+            itertools.zip_longest(*addrinfos_lists)
+        ) if a is not None)
+    return reordered
+
+
 def _run_until_complete_cb(fut):
     if not fut.cancelled():
         exc = fut.exception()
@@ -871,12 +895,49 @@
                 "offset must be a non-negative integer (got {!r})".format(
                     offset))
 
+    async def _connect_sock(self, exceptions, addr_info, local_addr_infos=None):
+        """Create, bind and connect one socket."""
+        my_exceptions = []
+        exceptions.append(my_exceptions)
+        family, type_, proto, _, address = addr_info
+        sock = None
+        try:
+            sock = socket.socket(family=family, type=type_, proto=proto)
+            sock.setblocking(False)
+            if local_addr_infos is not None:
+                for _, _, _, _, laddr in local_addr_infos:
+                    try:
+                        sock.bind(laddr)
+                        break
+                    except OSError as exc:
+                        msg = (
+                            f'error while attempting to bind on '
+                            f'address {laddr!r}: '
+                            f'{exc.strerror.lower()}'
+                        )
+                        exc = OSError(exc.errno, msg)
+                        my_exceptions.append(exc)
+                else:  # all bind attempts failed
+                    raise my_exceptions.pop()
+            await self.sock_connect(sock, address)
+            return sock
+        except OSError as exc:
+            my_exceptions.append(exc)
+            if sock is not None:
+                sock.close()
+            raise
+        except:
+            if sock is not None:
+                sock.close()
+            raise
+
     async def create_connection(
             self, protocol_factory, host=None, port=None,
             *, ssl=None, family=0,
             proto=0, flags=0, sock=None,
             local_addr=None, server_hostname=None,
-            ssl_handshake_timeout=None):
+            ssl_handshake_timeout=None,
+            happy_eyeballs_delay=None, interleave=None):
         """Connect to a TCP server.
 
         Create a streaming transport connection to a given Internet host and
@@ -911,6 +972,10 @@
             raise ValueError(
                 'ssl_handshake_timeout is only meaningful with ssl')
 
+        if happy_eyeballs_delay is not None and interleave is None:
+            # If using happy eyeballs, default to interleave addresses by family
+            interleave = 1
+
         if host is not None or port is not None:
             if sock is not None:
                 raise ValueError(
@@ -929,43 +994,31 @@
                     flags=flags, loop=self)
                 if not laddr_infos:
                     raise OSError('getaddrinfo() returned empty list')
+            else:
+                laddr_infos = None
+
+            if interleave:
+                infos = _interleave_addrinfos(infos, interleave)
 
             exceptions = []
-            for family, type, proto, cname, address in infos:
-                try:
-                    sock = socket.socket(family=family, type=type, proto=proto)
-                    sock.setblocking(False)
-                    if local_addr is not None:
-                        for _, _, _, _, laddr in laddr_infos:
-                            try:
-                                sock.bind(laddr)
-                                break
-                            except OSError as exc:
-                                msg = (
-                                    f'error while attempting to bind on '
-                                    f'address {laddr!r}: '
-                                    f'{exc.strerror.lower()}'
-                                )
-                                exc = OSError(exc.errno, msg)
-                                exceptions.append(exc)
-                        else:
-                            sock.close()
-                            sock = None
-                            continue
-                    if self._debug:
-                        logger.debug("connect %r to %r", sock, address)
-                    await self.sock_connect(sock, address)
-                except OSError as exc:
-                    if sock is not None:
-                        sock.close()
-                    exceptions.append(exc)
-                except:
-                    if sock is not None:
-                        sock.close()
-                    raise
-                else:
-                    break
-            else:
+            if happy_eyeballs_delay is None:
+                # not using happy eyeballs
+                for addrinfo in infos:
+                    try:
+                        sock = await self._connect_sock(
+                            exceptions, addrinfo, laddr_infos)
+                        break
+                    except OSError:
+                        continue
+            else:  # using happy eyeballs
+                sock, _, _ = await staggered.staggered_race(
+                    (functools.partial(self._connect_sock,
+                                       exceptions, addrinfo, laddr_infos)
+                     for addrinfo in infos),
+                    happy_eyeballs_delay, loop=self)
+
+            if sock is None:
+                exceptions = [exc for sub in exceptions for exc in sub]
                 if len(exceptions) == 1:
                     raise exceptions[0]
                 else: