Add keyword-only fields to dataclasses. (GH=25608)
diff --git a/Lib/dataclasses.py b/Lib/dataclasses.py
index 036349b..3de50cf 100644
--- a/Lib/dataclasses.py
+++ b/Lib/dataclasses.py
@@ -16,6 +16,7 @@
'Field',
'FrozenInstanceError',
'InitVar',
+ 'KW_ONLY',
'MISSING',
# Helper functions.
@@ -163,8 +164,8 @@
# +-------+-------+-------+
# | True | add | | <- the default
# +=======+=======+=======+
-# __match_args__ is a tuple of __init__ parameter names; non-init fields must
-# be matched by keyword.
+# __match_args__ is always added unless the class already defines it. It is a
+# tuple of __init__ parameter names; non-init fields must be matched by keyword.
# Raised when an attempt is made to modify a frozen class.
@@ -184,6 +185,12 @@ class _MISSING_TYPE:
pass
MISSING = _MISSING_TYPE()
+# A sentinel object to indicate that following fields are keyword-only by
+# default. Use a class to give it a better repr.
+class _KW_ONLY_TYPE:
+ pass
+KW_ONLY = _KW_ONLY_TYPE()
+
# Since most per-field metadata will be unused, create an empty
# read-only proxy that can be shared among all fields.
_EMPTY_METADATA = types.MappingProxyType({})
@@ -232,7 +239,6 @@ def __repr__(self):
def __class_getitem__(cls, type):
return InitVar(type)
-
# Instances of Field are only ever created from within this module,
# and only from the field() function, although Field instances are
# exposed externally as (conceptually) read-only objects.
@@ -253,11 +259,12 @@ class Field:
'init',
'compare',
'metadata',
+ 'kw_only',
'_field_type', # Private: not to be used by user code.
)
def __init__(self, default, default_factory, init, repr, hash, compare,
- metadata):
+ metadata, kw_only):
self.name = None
self.type = None
self.default = default
@@ -269,6 +276,7 @@ def __init__(self, default, default_factory, init, repr, hash, compare,
self.metadata = (_EMPTY_METADATA
if metadata is None else
types.MappingProxyType(metadata))
+ self.kw_only = kw_only
self._field_type = None
def __repr__(self):
@@ -282,6 +290,7 @@ def __repr__(self):
f'hash={self.hash!r},'
f'compare={self.compare!r},'
f'metadata={self.metadata!r},'
+ f'kw_only={self.kw_only!r},'
f'_field_type={self._field_type}'
')')
@@ -335,17 +344,19 @@ def __repr__(self):
# so that a type checker can be told (via overloads) that this is a
# function whose type depends on its parameters.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
- hash=None, compare=True, metadata=None):
+ hash=None, compare=True, metadata=None, kw_only=MISSING):
"""Return an object to identify dataclass fields.
default is the default value of the field. default_factory is a
0-argument function called to initialize a field's value. If init
- is True, the field will be a parameter to the class's __init__()
- function. If repr is True, the field will be included in the
- object's repr(). If hash is True, the field will be included in
- the object's hash(). If compare is True, the field will be used
- in comparison functions. metadata, if specified, must be a
- mapping which is stored but not otherwise examined by dataclass.
+ is true, the field will be a parameter to the class's __init__()
+ function. If repr is true, the field will be included in the
+ object's repr(). If hash is true, the field will be included in the
+ object's hash(). If compare is true, the field will be used in
+ comparison functions. metadata, if specified, must be a mapping
+ which is stored but not otherwise examined by dataclass. If kw_only
+ is true, the field will become a keyword-only parameter to
+ __init__().
It is an error to specify both default and default_factory.
"""
@@ -353,7 +364,16 @@ def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
if default is not MISSING and default_factory is not MISSING:
raise ValueError('cannot specify both default and default_factory')
return Field(default, default_factory, init, repr, hash, compare,
- metadata)
+ metadata, kw_only)
+
+
+def _fields_in_init_order(fields):
+ # Returns the fields in the order __init__ will declare them. It returns 2
+ # tuples: the first for normal args, and the second for keyword-only args.
+
+ return (tuple(f for f in fields if f.init and not f.kw_only),
+ tuple(f for f in fields if f.init and f.kw_only)
+ )
def _tuple_str(obj_name, fields):
@@ -410,7 +430,6 @@ def _create_fn(name, args, body, *, globals=None, locals=None,
local_vars = ', '.join(locals.keys())
txt = f"def __create_fn__({local_vars}):\n{txt}\n return {name}"
-
ns = {}
exec(txt, globals, ns)
return ns['__create_fn__'](**locals)
@@ -501,7 +520,8 @@ def _init_param(f):
return f'{f.name}:_type_{f.name}{default}'
-def _init_fn(fields, frozen, has_post_init, self_name, globals):
+def _init_fn(fields, std_fields, kw_only_fields, frozen, has_post_init,
+ self_name, globals):
# fields contains both real fields and InitVar pseudo-fields.
# Make sure we don't have fields without defaults following fields
@@ -509,9 +529,10 @@ def _init_fn(fields, frozen, has_post_init, self_name, globals):
# function source code, but catching it here gives a better error
# message, and future-proofs us in case we build up the function
# using ast.
+
seen_default = False
- for f in fields:
- # Only consider fields in the __init__ call.
+ for f in std_fields:
+ # Only consider the non-kw-only fields in the __init__ call.
if f.init:
if not (f.default is MISSING and f.default_factory is MISSING):
seen_default = True
@@ -543,8 +564,15 @@ def _init_fn(fields, frozen, has_post_init, self_name, globals):
if not body_lines:
body_lines = ['pass']
+ _init_params = [_init_param(f) for f in std_fields]
+ if kw_only_fields:
+ # Add the keyword-only args. Because the * can only be added if
+ # there's at least one keyword-only arg, there needs to be a test here
+ # (instead of just concatenating the lists together).
+ _init_params += ['*']
+ _init_params += [_init_param(f) for f in kw_only_fields]
return _create_fn('__init__',
- [self_name] + [_init_param(f) for f in fields if f.init],
+ [self_name] + _init_params,
body_lines,
locals=locals,
globals=globals,
@@ -623,6 +651,9 @@ def _is_initvar(a_type, dataclasses):
return (a_type is dataclasses.InitVar
or type(a_type) is dataclasses.InitVar)
+def _is_kw_only(a_type, dataclasses):
+ return a_type is dataclasses.KW_ONLY
+
def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
# Given a type annotation string, does it refer to a_type in
@@ -683,10 +714,11 @@ def _is_type(annotation, cls, a_module, a_type, is_type_predicate):
return False
-def _get_field(cls, a_name, a_type):
- # Return a Field object for this field name and type. ClassVars
- # and InitVars are also returned, but marked as such (see
- # f._field_type).
+def _get_field(cls, a_name, a_type, default_kw_only):
+ # Return a Field object for this field name and type. ClassVars and
+ # InitVars are also returned, but marked as such (see f._field_type).
+ # default_kw_only is the value of kw_only to use if there isn't a field()
+ # that defines it.
# If the default value isn't derived from Field, then it's only a
# normal default value. Convert it to a Field().
@@ -757,6 +789,19 @@ def _get_field(cls, a_name, a_type):
# init=<not-the-default-init-value>)? It makes no sense for
# ClassVar and InitVar to specify init=<anything>.
+ # kw_only validation and assignment.
+ if f._field_type in (_FIELD, _FIELD_INITVAR):
+ # For real and InitVar fields, if kw_only wasn't specified use the
+ # default value.
+ if f.kw_only is MISSING:
+ f.kw_only = default_kw_only
+ else:
+ # Make sure kw_only isn't set for ClassVars
+ assert f._field_type is _FIELD_CLASSVAR
+ if f.kw_only is not MISSING:
+ raise TypeError(f'field {f.name} is a ClassVar but specifies '
+ 'kw_only')
+
# For real fields, disallow mutable defaults for known types.
if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
raise ValueError(f'mutable default {type(f.default)} for field '
@@ -829,7 +874,7 @@ def _hash_exception(cls, fields, globals):
def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
- match_args):
+ match_args, kw_only):
# Now that dicts retain insertion order, there's no reason to use
# an ordered dict. I am leveraging that ordering here, because
# derived class fields overwrite base class fields, but the order
@@ -883,8 +928,22 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
# Now find fields in our class. While doing so, validate some
# things, and set the default values (as class attributes) where
# we can.
- cls_fields = [_get_field(cls, name, type)
- for name, type in cls_annotations.items()]
+ cls_fields = []
+ # Get a reference to this module for the _is_kw_only() test.
+ dataclasses = sys.modules[__name__]
+ for name, type in cls_annotations.items():
+ # See if this is a marker to change the value of kw_only.
+ if (_is_kw_only(type, dataclasses)
+ or (isinstance(type, str)
+ and _is_type(type, cls, dataclasses, dataclasses.KW_ONLY,
+ _is_kw_only))):
+ # Switch the default to kw_only=True, and ignore this
+ # annotation: it's not a real field.
+ kw_only = True
+ else:
+ # Otherwise it's a field of some type.
+ cls_fields.append(_get_field(cls, name, type, kw_only))
+
for f in cls_fields:
fields[f.name] = f
@@ -939,15 +998,22 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
if order and not eq:
raise ValueError('eq must be true if order is true')
+ # Include InitVars and regular fields (so, not ClassVars). This is
+ # initialized here, outside of the "if init:" test, because std_init_fields
+ # is used with match_args, below.
+ all_init_fields = [f for f in fields.values()
+ if f._field_type in (_FIELD, _FIELD_INITVAR)]
+ (std_init_fields,
+ kw_only_init_fields) = _fields_in_init_order(all_init_fields)
+
if init:
# Does this class have a post-init function?
has_post_init = hasattr(cls, _POST_INIT_NAME)
- # Include InitVars and regular fields (so, not ClassVars).
- flds = [f for f in fields.values()
- if f._field_type in (_FIELD, _FIELD_INITVAR)]
_set_new_attribute(cls, '__init__',
- _init_fn(flds,
+ _init_fn(all_init_fields,
+ std_init_fields,
+ kw_only_init_fields,
frozen,
has_post_init,
# The name to use for the "self"
@@ -1016,8 +1082,9 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
str(inspect.signature(cls)).replace(' -> None', ''))
if match_args:
+ # I could probably compute this once
_set_new_attribute(cls, '__match_args__',
- tuple(f.name for f in field_list if f.init))
+ tuple(f.name for f in std_init_fields))
abc.update_abstractmethods(cls)
@@ -1025,7 +1092,8 @@ def _process_class(cls, init, repr, eq, order, unsafe_hash, frozen,
def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
- unsafe_hash=False, frozen=False, match_args=True):
+ unsafe_hash=False, frozen=False, match_args=True,
+ kw_only=False):
"""Returns the same class as was passed in, with dunder methods
added based on the fields defined in the class.
@@ -1036,12 +1104,13 @@ def dataclass(cls=None, /, *, init=True, repr=True, eq=True, order=False,
comparison dunder methods are added. If unsafe_hash is true, a
__hash__() method function is added. If frozen is true, fields may
not be assigned to after instance creation. If match_args is true,
- the __match_args__ tuple is added.
+ the __match_args__ tuple is added. If kw_only is true, then by
+ default all fields are keyword-only.
"""
def wrap(cls):
return _process_class(cls, init, repr, eq, order, unsafe_hash,
- frozen, match_args)
+ frozen, match_args, kw_only)
# See if we're being called as @dataclass or @dataclass().
if cls is None:
diff --git a/Lib/test/test_dataclasses.py b/Lib/test/test_dataclasses.py
index f35f466..edb0848 100644
--- a/Lib/test/test_dataclasses.py
+++ b/Lib/test/test_dataclasses.py
@@ -61,6 +61,7 @@ def test_field_repr(self):
f"default=1,default_factory={MISSING!r}," \
"init=True,repr=False,hash=None," \
"compare=True,metadata=mappingproxy({})," \
+ f"kw_only={MISSING!r}," \
"_field_type=None)"
self.assertEqual(repr_output, expected_output)
@@ -3501,5 +3502,163 @@ def test_make_dataclasses(self):
self.assertEqual(C.__match_args__, ('z',))
+class TestKwArgs(unittest.TestCase):
+ def test_no_classvar_kwarg(self):
+ msg = 'field a is a ClassVar but specifies kw_only'
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: ClassVar[int] = field(kw_only=True)
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass
+ class A:
+ a: ClassVar[int] = field(kw_only=False)
+
+ with self.assertRaisesRegex(TypeError, msg):
+ @dataclass(kw_only=True)
+ class A:
+ a: ClassVar[int] = field(kw_only=False)
+
+ def test_field_marked_as_kwonly(self):
+ #######################
+ # Using dataclass(kw_only=True)
+ @dataclass(kw_only=True)
+ class A:
+ a: int
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=True)
+ class A:
+ a: int = field(kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=True)
+ class A:
+ a: int = field(kw_only=False)
+ self.assertFalse(fields(A)[0].kw_only)
+
+ #######################
+ # Using dataclass(kw_only=False)
+ @dataclass(kw_only=False)
+ class A:
+ a: int
+ self.assertFalse(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=False)
+ class A:
+ a: int = field(kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass(kw_only=False)
+ class A:
+ a: int = field(kw_only=False)
+ self.assertFalse(fields(A)[0].kw_only)
+
+ #######################
+ # Not specifying dataclass(kw_only)
+ @dataclass
+ class A:
+ a: int
+ self.assertFalse(fields(A)[0].kw_only)
+
+ @dataclass
+ class A:
+ a: int = field(kw_only=True)
+ self.assertTrue(fields(A)[0].kw_only)
+
+ @dataclass
+ class A:
+ a: int = field(kw_only=False)
+ self.assertFalse(fields(A)[0].kw_only)
+
+ def test_match_args(self):
+ # kw fields don't show up in __match_args__.
+ @dataclass(kw_only=True)
+ class C:
+ a: int
+ self.assertEqual(C(a=42).__match_args__, ())
+
+ @dataclass
+ class C:
+ a: int
+ b: int = field(kw_only=True)
+ self.assertEqual(C(42, b=10).__match_args__, ('a',))
+
+ def test_KW_ONLY(self):
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int
+ A(3, c=5, b=4)
+ msg = "takes 2 positional arguments but 4 were given"
+ with self.assertRaisesRegex(TypeError, msg):
+ A(3, 4, 5)
+
+
+ @dataclass(kw_only=True)
+ class B:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int
+ B(a=3, b=4, c=5)
+ msg = "takes 1 positional argument but 4 were given"
+ with self.assertRaisesRegex(TypeError, msg):
+ B(3, 4, 5)
+
+ # Explicitly make a field that follows KW_ONLY be non-keyword-only.
+ @dataclass
+ class C:
+ a: int
+ _: KW_ONLY
+ b: int
+ c: int = field(kw_only=False)
+ c = C(1, 2, b=3)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+ c = C(1, b=3, c=2)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+ c = C(1, b=3, c=2)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+ c = C(c=2, b=3, a=1)
+ self.assertEqual(c.a, 1)
+ self.assertEqual(c.b, 3)
+ self.assertEqual(c.c, 2)
+
+ def test_post_init(self):
+ @dataclass
+ class A:
+ a: int
+ _: KW_ONLY
+ b: InitVar[int]
+ c: int
+ d: InitVar[int]
+ def __post_init__(self, b, d):
+ raise CustomError(f'{b=} {d=}')
+ with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
+ A(1, c=2, b=3, d=4)
+
+ @dataclass
+ class B:
+ a: int
+ _: KW_ONLY
+ b: InitVar[int]
+ c: int
+ d: InitVar[int]
+ def __post_init__(self, b, d):
+ self.a = b
+ self.c = d
+ b = B(1, c=2, b=3, d=4)
+ self.assertEqual(asdict(b), {'a': 3, 'c': 4})
+
+
if __name__ == '__main__':
unittest.main()