[3.9] bpo-42345: Fix three issues with typing.Literal parameters (GH-23294) (GH-23335)
Literal equality no longer depends on the order of arguments.
Fix an issue with `typing.Literal` caching by adding a `typed` parameter to the `typing._tp_cache` function, so that e.g. `Literal[1]` and `Literal[True]` are cached separately.
Add deduplication of `typing.Literal` arguments.
(cherry picked from commit f03d318ca42578e45405717aedd4ac26ea52aaed)
Co-authored-by: Yurii Karabas <1998uriyyo@gmail.com>
diff --git a/Lib/typing.py b/Lib/typing.py
index 6fd67b0..14952ec 100644
--- a/Lib/typing.py
+++ b/Lib/typing.py
@@ -200,6 +200,20 @@
f" actual {alen}, expected {elen}")
+def _deduplicate(params):
+ # Drop elements equal to an earlier one, keeping first occurrences in their original order.
+ all_params = set(params)  # raises TypeError if any element is unhashable
+ if len(all_params) < len(params):  # only rebuild when there are duplicates
+ new_params = []
+ for t in params:
+ if t in all_params:
+ new_params.append(t)
+ all_params.remove(t)  # later equal elements are skipped from now on
+ params = new_params
+ assert not all_params, all_params  # every distinct element was emitted exactly once
+ return params  # the input object itself when there were no duplicates, else a new list
+
+
def _remove_dups_flatten(parameters):
"""An internal helper for Union creation and substitution: flatten Unions
among parameters, then remove duplicates.
@@ -213,38 +227,45 @@
params.extend(p[1:])
else:
params.append(p)
- # Weed out strict duplicates, preserving the first of each occurrence.
- all_params = set(params)
- if len(all_params) < len(params):
- new_params = []
- for t in params:
- if t in all_params:
- new_params.append(t)
- all_params.remove(t)
- params = new_params
- assert not all_params, all_params
+
+ return tuple(_deduplicate(params))  # flattened, duplicates removed, order preserved
+
+
+def _flatten_literal_params(parameters):
+ """An internal helper for Literal creation: replace each nested Literal by its args."""
+ params = []
+ for p in parameters:
+ if isinstance(p, _LiteralGenericAlias):  # a nested Literal[...] contributes its own args
+ params.extend(p.__args__)
+ else:
+ params.append(p)
return tuple(params)
_cleanups = []
-def _tp_cache(func):
+def _tp_cache(func=None, /, *, typed=False):  # usable bare (@_tp_cache) or configured (@_tp_cache(typed=True))
"""Internal wrapper caching __getitem__ of generic types with a fallback to
original function for non-hashable arguments.
"""
- cached = functools.lru_cache()(func)
- _cleanups.append(cached.cache_clear)
+ def decorator(func):
+ cached = functools.lru_cache(typed=typed)(func)  # typed=True: top-level args of different types (1 vs True) get separate entries
+ _cleanups.append(cached.cache_clear)
- @functools.wraps(func)
- def inner(*args, **kwds):
- try:
- return cached(*args, **kwds)
- except TypeError:
- pass # All real errors (not unhashable args) are raised below.
- return func(*args, **kwds)
- return inner
+ @functools.wraps(func)
+ def inner(*args, **kwds):
+ try:
+ return cached(*args, **kwds)
+ except TypeError:
+ pass # All real errors (not unhashable args) are raised below.
+ return func(*args, **kwds)
+ return inner
+ if func is not None:  # applied directly as @_tp_cache
+ return decorator(func)
+
+ return decorator  # called as @_tp_cache(...): return the real decorator
def _eval_type(t, globalns, localns, recursive_guard=frozenset()):
"""Evaluate all forward references in the given type t.
@@ -317,6 +338,21 @@
def __getitem__(self, parameters):
return self._getitem(self, parameters)
+
+class _LiteralSpecialForm(_SpecialForm, _root=True):
+ def __getitem__(self, parameters):
+ # Spread the parameters so lru_cache(typed=True) sees each Literal value
+ # as a separate argument; as a single tuple, Literal[1, 2] and
+ # Literal[True, 2] compare equal and would share one cache entry.
+ if not isinstance(parameters, tuple):
+ parameters = (parameters,)
+ return self._getitem_typed(*parameters)
+
+ @_tp_cache(typed=True)  # key includes the type of every individual value
+ def _getitem_typed(self, *parameters):
+ return self._getitem(self, parameters)
+
+
@_SpecialForm
def Any(self, parameters):
"""Special type indicating an unconstrained type.
@@ -434,7 +462,7 @@
arg = _type_check(parameters, f"{self} requires a single type.")
return Union[arg, type(None)]
-@_SpecialForm
+@_LiteralSpecialForm  # subscriptions are cached with typed=True, keeping Literal[1] and Literal[True] apart
def Literal(self, parameters):
"""Special typing form to define literal types (a.k.a. value types).
@@ -458,7 +486,17 @@
"""
# There is no '_type_check' call because arguments to Literal[...] are
# values, not types.
- return _GenericAlias(self, parameters)
+ if not isinstance(parameters, tuple):  # Literal[x] passes a single non-tuple value
+ parameters = (parameters,)
+
+ parameters = _flatten_literal_params(parameters)  # inline the args of nested Literals
+
+ try:  # dedupe on (value, type) pairs so that e.g. 1 and True are both kept
+ parameters = tuple(p for p, _ in _deduplicate(list(_value_and_type_iter(parameters))))
+ except TypeError: # unhashable parameters
+ pass  # keep the args as given, without deduplication
+
+ return _LiteralGenericAlias(self, parameters)
class ForwardRef(_Final, _root=True):
@@ -881,6 +919,22 @@
return super().__repr__()
+def _value_and_type_iter(parameters):  # pair each value with its type: (1, int) != (True, bool)
+ return ((p, type(p)) for p in parameters)
+
+
+class _LiteralGenericAlias(_GenericAlias, _root=True):  # the alias object returned by Literal[...]
+
+ def __eq__(self, other):  # ignores argument order and duplicates
+ if not isinstance(other, _LiteralGenericAlias):
+ return NotImplemented
+
+ return set(_value_and_type_iter(self.__args__)) == set(_value_and_type_iter(other.__args__))
+
+ def __hash__(self):  # must be order-insensitive like __eq__, so equal aliases hash equal
+ return hash(frozenset(_value_and_type_iter(self.__args__)))
+
+
class Generic:
"""Abstract base class for generic types.