blob: db14b0f66c2cbb347603dde14e7b491b52107fbf [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001import sys
2import types
3from copy import deepcopy
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004import inspect
5
# Public names exported by ``from dataclasses import *``.
__all__ = [
    # The decorator, the field specifier, and supporting types/sentinels.
    'dataclass',
    'field',
    'FrozenInstanceError',
    'InitVar',
    'MISSING',

    # Helper functions.
    'fields',
    'asdict',
    'astuple',
    'make_dataclass',
    'replace',
    'is_dataclass',
]
20
class FrozenInstanceError(AttributeError):
    """Raised on an attempt to assign to or delete a field of a frozen
    dataclass instance."""
23
class _HAS_DEFAULT_FACTORY_CLASS:
    """Type of the _HAS_DEFAULT_FACTORY sentinel.

    The sentinel is the default value of __init__ parameters whose field
    uses a default_factory.  Its repr() appears in the signature of the
    generated constructor, so it gets a short, readable form.
    """

    def __repr__(self):
        return '<factory>'


# Sentinel default meaning "no value passed: call the default factory".
_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
32
class _MISSING_TYPE:
    """Type of the MISSING sentinel.

    MISSING means "no value was supplied" for parameters such as
    field()'s default and default_factory, where None is itself a
    legitimate value.  A dedicated class is used so that the sentinel has
    a readable repr() (e.g. when a Field is printed) instead of the
    generic "<... object at 0x...>".
    """

    def __repr__(self):
        return 'MISSING'


MISSING = _MISSING_TYPE()
Eric V. Smithf0db54a2017-12-04 16:58:55 -050038
# Most fields carry no metadata, so they all share this one empty,
# read-only mapping rather than each getting its own.
_EMPTY_METADATA = types.MappingProxyType({})

# Values of Field._field_type, distinguishing real fields from the
# pseudo-fields that are also declared through annotations.
_FIELD = object()           # A real field.
_FIELD_CLASSVAR = object()  # A ClassVar annotation: not a field.
_FIELD_INITVAR = object()   # An InitVar annotation: not a field.

# Class attribute holding the dict of Field objects.  Its presence is
# also how a class (or an instance) is recognized as a dataclass.
_MARKER = '__dataclass_fields__'

# Name of the optional method that the generated __init__ calls last.
_POST_INIT_NAME = '__post_init__'
55
56
class _InitVarMeta(type):
    # InitVar[...] subscription exists only for annotation syntax: the
    # type argument is discarded and InitVar itself is returned.
    def __getitem__(self, params):
        return self


class InitVar(metaclass=_InitVarMeta):
    """Annotation marking an init-only pseudo-field.

    The name becomes a parameter of the generated __init__ and is passed
    on to __post_init__, but it is not stored as a field.
    """
63
64
# Instances of Field are only ever created by field() in this module, and
# are exposed to users as (conceptually) read-only objects.  name and type
# are not known when field() runs; they are filled in later, when the
# decorated class is processed.  By the time the class's _MARKER dict of
# Field objects is populated, every Field in it has name and type set.
class Field:
    __slots__ = ('name',
                 'type',
                 'default',
                 'default_factory',
                 'repr',
                 'hash',
                 'init',
                 'compare',
                 'metadata',
                 '_field_type',  # Private: not to be used by user code.
                 )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        self.name = None
        self.type = None
        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare
        # Share one empty read-only mapping when no metadata is given;
        # otherwise expose the caller's mapping through a read-only proxy.
        self.metadata = (_EMPTY_METADATA
                         if metadata is None or len(metadata) == 0 else
                         types.MappingProxyType(metadata))
        self._field_type = None

    def __repr__(self):
        # Every value uses !r so that, e.g., a string default 'abc' or a
        # string annotation 'int' is shown quoted and cannot be mistaken
        # for a name or a real type.
        return ('Field('
                f'name={self.name!r},'
                f'type={self.type!r},'
                f'default={self.default!r},'
                f'default_factory={self.default_factory!r},'
                f'init={self.init!r},'
                f'repr={self.repr!r},'
                f'hash={self.hash!r},'
                f'compare={self.compare!r},'
                f'metadata={self.metadata!r}'
                ')')
113
114
# field() is the public way to create Field objects.  Exposing a function
# rather than the Field constructor lets a type checker describe, via
# overloads, a return type that depends on the arguments.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return a Field object describing a dataclass field.

    default: value used when the field is not supplied.
    default_factory: zero-argument callable that produces the value
        instead.
    init: include the field as a parameter of the generated __init__().
    repr: include the field in the generated __repr__().
    hash: include the field in the generated __hash__(); None means use
        the value of compare.
    compare: include the field in the generated comparison methods.
    metadata: optional mapping, stored read-only but otherwise ignored
        by dataclass.

    Raises ValueError if both default and default_factory are given.
    """
    both_given = not (default is MISSING or default_factory is MISSING)
    if both_given:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare, metadata)
138
139
def _tuple_str(obj_name, fields):
    # Build source text for a tuple of each field's value on the object
    # named obj_name.  For Field objects named 'x' and 'y' and obj_name
    # 'self', the result is "(self.x,self.y,)".  Every member is followed
    # by a comma so a single field still yields a tuple, not a
    # parenthesized expression.  No fields yields the empty tuple "()".
    if len(fields) == 0:
        return '()'
    members = ''.join(f'{obj_name}.{f.name},' for f in fields)
    return f'({members})'
150
151
def _create_fn(name, args, body, globals=None, locals=None,
               return_type=MISSING):
    # Compile a function from source text and return it.
    #   name: the function's name.
    #   args: parameter strings, joined with commas into the signature.
    #   body: source lines, each indented one level inside the def.
    #   globals / locals: namespaces handed to exec().  exec() stores the
    #     new function into locals under `name` (and '_return_type' is
    #     added when a return annotation is requested), so a caller-supplied
    #     locals dict is mutated.
    #   return_type: if given, used as the return annotation.
    namespace = {} if locals is None else locals
    annotation = ''
    if return_type is not MISSING:
        namespace['_return_type'] = return_type
        annotation = '->_return_type'
    indented_body = '\n'.join('  ' + line for line in body)
    source = f'def {name}({",".join(args)}){annotation}:\n{indented_body}'
    exec(source, globals, namespace)
    return namespace[name]
168
169
def _field_assign(frozen, name, value, self_name):
    # Return one line of __init__ source that stores `value` (itself source
    # text) into attribute `name` of the object called `self_name`.
    # "self" is a parameter because a field may itself be named "self".
    # Frozen classes install a __setattr__ that raises, so their __init__
    # must go through object.__setattr__ instead of plain assignment.
    if not frozen:
        return f'{self_name}.{name}={value}'
    return f'object.__setattr__({self_name},{name!r},{value})'
179
180
def _field_init(f, frozen, globals, self_name):
    # Return the line of __init__ source that initializes field f, or None
    # when no line is needed.  Objects the generated code refers to by
    # name (a default value or default factory, as '_dflt_<name>') are
    # added to `globals`.
    default_name = f'_dflt_{f.name}'

    if f.default_factory is not MISSING:
        # The generated code calls the factory by name in both cases.
        globals[default_name] = f.default_factory
        if f.init:
            # The parameter defaults to the _HAS_DEFAULT_FACTORY sentinel;
            # call the factory only when the caller passed nothing.
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # Not an __init__ parameter, so __init__ is the only place the
            # factory can run.  Unlike a plain default there is no class
            # attribute to fall back on, and each instance needs its own
            # freshly made value anyway.
            value = f'{default_name}()'
    elif f.init:
        # The parameter's value is assigned as-is.  A plain default must
        # still be published: the __init__ signature refers to it as
        # '_dflt_<name>' (see _init_param).
        if f.default is not MISSING:
            globals[default_name] = f.default
        value = f.name
    else:
        # Neither an __init__ parameter nor a factory: nothing to do here.
        return None

    # InitVars are __init__ parameters (their defaults were published
    # above), but they are never stored on the instance.
    if f._field_type is _FIELD_INITVAR:
        return None

    return _field_assign(frozen, f.name, value, self_name)
233
234
def _init_param(f):
    # Return field f's parameter text for the generated __init__
    # signature, e.g. "x:_type_x=_dflt_x".  The annotation and default are
    # names the generated code looks up (_type_<name> among __init__'s
    # locals; _dflt_<name> or _HAS_DEFAULT_FACTORY among its globals),
    # not literal values.
    if f.default is not MISSING:
        suffix = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # Marker value telling __init__ to call the factory.
        suffix = '=_HAS_DEFAULT_FACTORY'
    else:
        # No default of either kind: a required parameter.
        suffix = ''
    return f'{f.name}:_type_{f.name}{suffix}'
251
252
def _init_fn(fields, frozen, has_post_init, self_name):
    # Build the __init__ method.  `fields` holds the real fields and the
    # InitVar pseudo-fields, in definition order.

    # Parameters without defaults may not follow ones with defaults.
    # exec() would reject that as well, but checking here gives a clearer
    # message and does not depend on building the function from source.
    seen_default = False
    for f in fields:
        if not f.init:
            continue
        if f.default is not MISSING or f.default_factory is not MISSING:
            seen_default = True
        elif seen_default:
            raise TypeError(f'non-default argument {f.name!r} '
                            'follows default argument')

    fn_globals = {'MISSING': MISSING,
                  '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}

    # One line per field that needs initializing; _field_init returns None
    # for the rest, and also fills fn_globals with default values/factories.
    body_lines = [line for line in
                  (_field_init(f, frozen, fn_globals, self_name) for f in fields)
                  if line is not None]

    if has_post_init:
        # InitVar values are handed to __post_init__ positionally.
        initvar_names = ','.join(f.name for f in fields
                                 if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({initvar_names})')

    # A function body cannot be empty.
    if not body_lines:
        body_lines = ['pass']

    fn_locals = {f'_type_{f.name}': f.type for f in fields}
    params = [self_name] + [_init_param(f) for f in fields if f.init]
    return _create_fn('__init__',
                      params,
                      body_lines,
                      locals=fn_locals,
                      globals=fn_globals,
                      return_type=None)
300
301
def _repr_fn(fields):
    # Build __repr__, producing "QualName(a=..., b=...)" for the given
    # fields, each value formatted with !r.
    #
    # An instance that (directly or indirectly) contains itself would make
    # the generated __repr__ recurse until RecursionError.
    # reprlib.recursive_repr() renders the nested occurrence as '...'.
    import reprlib

    fn = _create_fn('__repr__',
                    ['self'],
                    ['return self.__class__.__qualname__ + f"(' +
                     ', '.join([f"{f.name}={{self.{f.name}!r}}"
                                for f in fields]) +
                     ')"'])
    return reprlib.recursive_repr()(fn)
309
310
def _frozen_setattr(self, name, value):
    # Installed as __setattr__ on frozen dataclasses, so every attribute
    # assignment on an instance fails.  The generated __init__ bypasses it
    # via object.__setattr__.
    message = f'cannot assign to field {name!r}'
    raise FrozenInstanceError(message)
313
314
def _frozen_delattr(self, name):
    # Installed as __delattr__ on frozen dataclasses: deletion always fails.
    message = f'cannot delete field {name!r}'
    raise FrozenInstanceError(message)
317
318
def _cmp_fn(name, op, self_tuple, other_tuple):
    # Build a comparison method called `name` that applies operator `op`
    # to two tuple expressions, e.g. '(self.x,self.y,)' and
    # '(other.x,other.y,)'.  Only instances of exactly the same class are
    # compared; anything else yields NotImplemented.
    same_class_check = 'if other.__class__ is self.__class__:'
    compare_line = f'    return {self_tuple}{op}{other_tuple}'
    return _create_fn(name,
                      ['self', 'other'],
                      [same_class_check,
                       compare_line,
                       'return NotImplemented'])
330
331
def _set_eq_fns(cls, fields):
    # Add __eq__ and __ne__ to cls, comparing the given fields as tuples.
    # The tuple source strings are built once and shared by both methods.
    lhs = _tuple_str('self', fields)
    rhs = _tuple_str('other', fields)
    operators = {'__eq__': '==', '__ne__': '!='}
    for method_name, op in operators.items():
        _set_attribute(cls, method_name, _cmp_fn(method_name, op, lhs, rhs))
342
343
def _set_order_fns(cls, fields):
    # Add __lt__, __le__, __gt__ and __ge__ to cls, each comparing the
    # given fields as tuples (so ordering follows field order).  The tuple
    # source strings are built once and shared by all four methods.
    lhs = _tuple_str('self', fields)
    rhs = _tuple_str('other', fields)
    operators = {'__lt__': '<', '__le__': '<=', '__gt__': '>', '__ge__': '>='}
    for method_name, op in operators.items():
        _set_attribute(cls, method_name, _cmp_fn(method_name, op, lhs, rhs))
356
357
def _hash_fn(fields):
    # Build __hash__, which hashes a tuple of the given fields' values.
    hashed = _tuple_str('self', fields)
    return _create_fn('__hash__',
                      ['self'],
                      ['return hash(' + hashed + ')'])
363
364
def _get_field(cls, a_name, a_type):
    # Return the Field for the annotation `a_name: a_type` on cls, with
    # name, type and _field_type filled in.  ClassVar and InitVar
    # annotations also produce a Field, marked through _field_type.

    # The class attribute of the same name, if any, is either a Field
    # created by field() or a plain default value that gets wrapped in one.
    default = getattr(cls, a_name, MISSING)
    f = default if isinstance(default, Field) else field(default=default)

    # Only now are the name and type known.
    f.name = a_name
    f.type = a_type

    # Classify: ClassVar, InitVar, or a real field.  A ClassVar annotation
    # can only exist if typing has been imported, so typing is looked up
    # in sys.modules rather than imported here.  The test relies on a
    # typing internal, for lack of a public way to recognize ClassVar.
    typing = sys.modules.get('typing')
    if typing is not None and type(a_type) is typing._ClassVar:
        f._field_type = _FIELD_CLASSVAR
    elif a_type is InitVar:
        f._field_type = _FIELD_INITVAR
    else:
        f._field_type = _FIELD

    # Validation happens here rather than in field() because only here is
    # the field name known, giving better error messages.

    # ClassVars and InitVars may not use a default factory.  (Other
    # options, e.g. init=, are equally meaningless for them but are not
    # checked.)
    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
        if f.default_factory is not MISSING:
            raise TypeError(f'field {f.name} cannot have a '
                            'default factory')

    # A list/dict/set default would be shared by every instance; require
    # default_factory instead.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f
423
424
def _find_fields(cls):
    # Return the Field objects declared by cls itself (not by its bases),
    # in annotation order.  Default values come from class attributes; a
    # Field() default also carries the other field() options.  ClassVar
    # and InitVar pseudo-fields are included too, marked through
    # _field_type; callers filter them out as needed.
    #
    # Only cls's *own* annotations may be used.  Before Python 3.10,
    # getattr(cls, '__annotations__') falls through to a base class's dict
    # when cls declares no annotations.  That would re-create the base's
    # fields from cls's attributes, losing field() options such as
    # default_factory or repr=False.  inspect.get_annotations (3.10+)
    # returns only the object's own annotations; on older versions the
    # class __dict__ is read directly.
    get_annotations = getattr(inspect, 'get_annotations', None)
    if get_annotations is not None:
        annotations = get_annotations(cls)
    else:
        annotations = cls.__dict__.get('__annotations__', {})

    return [_get_field(cls, a_name, a_type)
            for a_name, a_type in annotations.items()]
439
440
def _set_attribute(cls, name, value):
    # Set attribute `name` on cls.  Anything cls defines in its own
    # __dict__ is never overwritten (TypeError); inherited attributes may
    # be overridden.
    if name not in cls.__dict__:
        setattr(cls, name, value)
        return
    raise TypeError(f'Cannot overwrite attribute {name} '
                    f'in {cls.__name__}')
447
448
def _process_class(cls, repr, eq, order, hash, init, frozen):
    # Turn cls into a dataclass in place: collect its fields (including
    # inherited ones), record them under _MARKER, add the requested dunder
    # methods, and return cls.

    # Fields are collected in a dict keyed by name.  Dicts keep insertion
    # order, so a field keeps the position where it was first defined (in
    # the most basic class), while a redefinition in a more derived class
    # replaces its Field object.
    fields = {}

    # Walk the MRO from the most basic class toward cls, excluding cls
    # itself, so that more derived bases override less derived ones.
    # Only bases that are dataclasses (they have _MARKER) contribute.
    for b in cls.__mro__[-1:0:-1]:
        base_fields = getattr(b, _MARKER, None)
        if base_fields:
            for f in base_fields.values():
                fields[f.name] = f

    # Add this class's own fields.  A class attribute holding a Field()
    # is replaced by the real default, or deleted when there is none (for
    # example field(repr=False) or a default_factory), so that normal
    # class introspection never sees Field objects.
    for f in _find_fields(cls):
        fields[f.name] = f
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # Recording all fields (including bases') also marks cls as a dataclass.
    setattr(cls, _MARKER, fields)

    # Frozen-ness is inherited: a class whose __setattr__ is the frozen
    # one (e.g. from a frozen base) is treated as frozen too.
    is_frozen = frozen or cls.__setattr__ is _frozen_setattr

    # Ordering methods require the equality methods.
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    if init:
        has_post_init = hasattr(cls, _POST_INIT_NAME)
        # __init__ takes real fields and InitVars, but not ClassVars.
        init_fields = [f for f in fields.values()
                       if f._field_type in (_FIELD, _FIELD_INITVAR)]
        # A field may be named "self"; then __init__'s instance parameter
        # needs a different name.
        self_name = '__dataclass_self__' if 'self' in fields else 'self'
        _set_attribute(cls, '__init__',
                       _init_fn(init_fields, is_frozen, has_post_init,
                                self_name))

    # From here on only real fields matter.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        _set_attribute(cls, '__repr__',
                       _repr_fn([f for f in field_list if f.repr]))

    if is_frozen:
        _set_attribute(cls, '__setattr__', _frozen_setattr)
        _set_attribute(cls, '__delattr__', _frozen_delattr)

    # Decide about __hash__.  An explicit hash=True/False wins.  With
    # hash=None (the default):
    #   eq and frozen      -> generate __hash__ from the fields;
    #   eq and not frozen  -> set __hash__ = None (unhashable);
    #   not eq             -> leave the inherited __hash__ alone.
    # NOTE(review): this uses the `frozen` argument rather than is_frozen,
    # so a class that is frozen only by inheritance gets __hash__ = None;
    # confirm that is intended.
    generate_hash = False
    if hash is not None:
        generate_hash = hash
    elif eq and frozen:
        generate_hash = True
    elif eq:
        _set_attribute(cls, '__hash__', None)

    if generate_hash:
        # A field is hashed according to its own `hash` option, or its
        # `compare` option when `hash` is None.
        hash_fields = [f for f in field_list
                       if (f.compare if f.hash is None else f.hash)]
        _set_attribute(cls, '__hash__', _hash_fn(hash_fields))

    compare_fields = [f for f in field_list if f.compare]
    if eq:
        # Create the __eq__ and __ne__ methods.
        _set_eq_fns(cls, compare_fields)
    if order:
        # Create the __lt__, __le__, __gt__ and __ge__ methods.
        _set_order_fns(cls, compare_fields)

    if not cls.__doc__:
        # Default docstring: the class name plus its __init__ signature.
        cls.__doc__ = (cls.__name__ +
                       str(inspect.signature(cls)).replace(' -> None', ''))

    return cls
570
571
# _cls starts with an underscore so it is never passed by keyword.
# Whether it was supplied tells us if the decorator was used as
# @dataclass or as @dataclass(...).
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              hash=None, frozen=False):
    """Add generated dunder methods to a class based on its fields, and
    return the same class.

    Fields are taken from the class's PEP 526 __annotations__.

    init: add __init__().
    repr: add __repr__().
    eq: add __eq__() and __ne__().
    order: add __lt__(), __le__(), __gt__() and __ge__(); requires eq.
    hash: True adds __hash__().  None (the default) adds one only when eq
        and frozen are both true, sets __hash__ to None when eq is true
        but frozen is not, and otherwise leaves __hash__ unchanged.
    frozen: make assigning to or deleting fields after instance creation
        raise FrozenInstanceError.
    """
    def wrap(cls):
        return _process_class(cls, repr, eq, order, hash, init, frozen)

    # @dataclass passes the class directly; @dataclass(...) calls us with
    # no class and then applies the returned decorator.
    return wrap if _cls is None else wrap(_cls)
599
600
def fields(class_or_instance):
    """Return a tuple of the Field objects of a dataclass or an instance
    of one.

    Only real fields are included (no ClassVar or InitVar pseudo-fields),
    in the order in which they were defined.  Raises TypeError for
    anything that is not a dataclass or dataclass instance.
    """
    # NOTE: recomputed on every call; could be cached per class.
    try:
        all_fields = getattr(class_or_instance, _MARKER)
    except AttributeError:
        raise TypeError('must be called with a dataclass type or instance')
    return tuple(f for f in all_fields.values() if f._field_type is _FIELD)
617
618
def _is_dataclass_instance(obj):
    """Returns True if obj is an instance of a dataclass."""
    # Classes also carry the marker, so rule them out first.
    if isinstance(obj, type):
        return False
    return hasattr(obj, _MARKER)
622
623
def is_dataclass(obj):
    """Returns True if obj is a dataclass or an instance of a
    dataclass."""
    # Instances see the marker attribute through their class, so one
    # check covers both cases.
    marker_found = hasattr(obj, _MARKER)
    return marker_found
628
629
def asdict(obj, *, dict_factory=dict):
    """Convert a dataclass instance into a new dictionary mapping field
    names to field values.

    Example usage:

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    If given, dict_factory is called with a list of (name, value) pairs in
    place of the built-in dict.  Conversion is recursive: dataclass
    instances found in field values, including inside tuples, lists and
    dicts, are converted as well.  Other values are deep-copied.
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
652
def _asdict_inner(obj, dict_factory):
    # Recursive worker for asdict(): convert dataclass instances into
    # dict_factory results, rebuild lists/tuples/dicts with converted
    # contents, and deep-copy everything else.
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _asdict_inner(getattr(obj, f.name), dict_factory)
            result.append((f.name, value))
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # A namedtuple: its constructor takes the members as separate
        # positional arguments, not as one iterable, so it cannot be
        # rebuilt like a plain tuple below.
        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Rebuild the same container type from the converted items.
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        items = ((_asdict_inner(k, dict_factory),
                  _asdict_inner(v, dict_factory))
                 for k, v in obj.items())
        if hasattr(type(obj), 'default_factory'):
            # A defaultdict: its first constructor argument must be the
            # factory, so pass that explicitly and then add the items.
            result = type(obj)(getattr(obj, 'default_factory'))
            result.update(items)
            return result
        return type(obj)(items)
    else:
        return deepcopy(obj)
667
668
def astuple(obj, *, tuple_factory=tuple):
    """Convert a dataclass instance into a new tuple of its field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    If given, tuple_factory is called with a list of the field values in
    place of the built-in tuple.  Conversion is recursive: dataclass
    instances found in field values, including inside tuples, lists and
    dicts, are converted as well.  Other values are deep-copied.
    """
    if _is_dataclass_instance(obj):
        return _astuple_inner(obj, tuple_factory)
    raise TypeError("astuple() should be called on dataclass instances")
691
def _astuple_inner(obj, tuple_factory):
    # Recursive worker for astuple(): convert dataclass instances into
    # tuple_factory results, rebuild lists/tuples/dicts with converted
    # contents, and deep-copy everything else.
    if _is_dataclass_instance(obj):
        result = []
        for f in fields(obj):
            value = _astuple_inner(getattr(obj, f.name), tuple_factory)
            result.append(value)
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # A namedtuple: its constructor takes the members as separate
        # positional arguments, not as one iterable, so it cannot be
        # rebuilt like a plain tuple below.
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        # Rebuild the same container type from the converted items.
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        items = ((_astuple_inner(k, tuple_factory),
                  _astuple_inner(v, tuple_factory))
                 for k, v in obj.items())
        if hasattr(type(obj), 'default_factory'):
            # A defaultdict: its first constructor argument must be the
            # factory, so pass that explicitly and then add the items.
            result = type(obj)(getattr(obj, 'default_factory'))
            result.update(items)
            return result
        return type(obj)(items)
    else:
        return deepcopy(obj)
706
707
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, hash=None, frozen=False):
    """Return a new dynamically created dataclass.

    The dataclass name will be 'cls_name'.  'fields' is an iterable whose
    items each take one of these forms:

      name                 -- a string; the type is 'typing.Any'
      (name, type)
      (name, type, spec)   -- spec becomes the class attribute, typically
                              the result of field(...)

    For example:

      C = make_dataclass('C', ['x', ('y', int), ('z', int, field(init=False))], bases=(Base,))

    is equivalent to:

      @dataclass
      class C(Base):
          x: 'typing.Any'
          y: int
          z: int = field(init=False)

    For the bases and namespace parameters, see the builtin type() function.

    The parameters init, repr, eq, order, hash, and frozen are passed to
    dataclass().

    Raises TypeError if an item of 'fields' has an unsupported form or if
    a field name appears more than once.
    """
    # Copy namespace since it is mutated below.
    namespace = {} if namespace is None else namespace.copy()

    anns = {}
    for item in fields:
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            name, tp, spec = item
            namespace[name] = spec
        else:
            # Without this check, name/tp would be unbound, or silently
            # left over from the previous item.
            raise TypeError(f'Invalid field: {item!r}')
        if name in anns:
            raise TypeError(f'Field name duplicated: {name!r}')
        anns[name] = tp

    namespace['__annotations__'] = anns
    cls = type(cls_name, bases, namespace)
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     hash=hash, frozen=frozen)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500755
def replace(obj, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes.  Example usage:

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    Raises ValueError if 'changes' names a field declared with
    init=False, or if an InitVar without a default is not given in
    'changes' (its value is not stored on obj, so it cannot be copied).
    """
    # 'changes' is a fresh dict even when called as
    # replace(obj, **my_changes), so mutating it is safe.

    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    # Fill in every __init__ argument not supplied in 'changes' from obj.
    for f in getattr(obj, _MARKER).values():
        # ClassVars are recorded alongside fields but are not __init__
        # parameters; passing them would fail as unexpected keywords.
        if f._field_type is _FIELD_CLASSVAR:
            continue

        if not f.init:
            # Error if this field is specified in changes.
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        if f.name not in changes:
            if f._field_type is _FIELD_INITVAR and f.default is MISSING:
                raise ValueError(f'InitVar {f.name!r} '
                                 'must be specified with replace()')
            changes[f.name] = getattr(obj, f.name)

    # Create the new object, which calls __init__() and __post_init__
    # (if defined) with the init fields collected in 'changes'.  Names in
    # 'changes' that are not parameters make __init__ raise TypeError.
    return obj.__class__(**changes)
797 return obj.__class__(**changes)