blob: c4e94b8f667381b3c9460ee595117faaf73c3b71 [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001import sys
2import types
3from copy import deepcopy
4import collections
5import inspect
6
# Names exported by a star-import of this module.
__all__ = [
    'dataclass',
    'field',
    'FrozenInstanceError',
    'InitVar',
    'MISSING',

    # Helper functions.
    'fields',
    'asdict',
    'astuple',
    'make_dataclass',
    'replace',
    'is_dataclass',
]
21
class FrozenInstanceError(AttributeError):
    """Raised when an attempt is made to modify a frozen class."""
24
25# A sentinel object for default values to signal that a
26# default-factory will be used.
27# This is given a nice repr() which will appear in the function
28# signature of dataclasses' constructors.
29class _HAS_DEFAULT_FACTORY_CLASS:
30 def __repr__(self):
31 return '<factory>'
32_HAS_DEFAULT_FACTORY = _HAS_DEFAULT_FACTORY_CLASS()
33
Eric V. Smith03220fd2017-12-29 13:59:58 -050034# A sentinel object to detect if a parameter is supplied or not. Use
35# a class to give it a better repr.
36class _MISSING_TYPE:
37 pass
38MISSING = _MISSING_TYPE()
Eric V. Smithf0db54a2017-12-04 16:58:55 -050039
40# Since most per-field metadata will be unused, create an empty
41# read-only proxy that can be shared among all fields.
42_EMPTY_METADATA = types.MappingProxyType({})
43
44# Markers for the various kinds of fields and pseudo-fields.
45_FIELD = object() # An actual field.
46_FIELD_CLASSVAR = object() # Not a field, but a ClassVar.
47_FIELD_INITVAR = object() # Not a field, but an InitVar.
48
49# The name of an attribute on the class where we store the Field
50# objects. Also used to check if a class is a Data Class.
51_MARKER = '__dataclass_fields__'
52
53# The name of the function, that if it exists, is called at the end of
54# __init__.
55_POST_INIT_NAME = '__post_init__'
56
57
58class _InitVarMeta(type):
59 def __getitem__(self, params):
60 return self
61
62class InitVar(metaclass=_InitVarMeta):
63 pass
64
65
# Field instances are meant to be created only inside this module, via
# the field() function; outside code should treat them as read-only.
# 'name' and 'type' are unknown when the object is constructed, so
# __init__ leaves them as None and they are filled in afterwards.  By
# the time the Field objects are stored on the class under _MARKER,
# both have been populated.
class Field:
    __slots__ = (
        'name',
        'type',
        'default',
        'default_factory',
        'repr',
        'hash',
        'init',
        'compare',
        'metadata',
        '_field_type',  # Private: not to be used by user code.
    )

    def __init__(self, default, default_factory, init, repr, hash, compare,
                 metadata):
        # Filled in later, once the owning class is being processed.
        self.name = None
        self.type = None
        self._field_type = None

        self.default = default
        self.default_factory = default_factory
        self.init = init
        self.repr = repr
        self.hash = hash
        self.compare = compare

        # Missing or empty metadata shares the module-wide empty proxy;
        # anything else is wrapped in its own read-only view.
        if metadata is None or len(metadata) == 0:
            self.metadata = _EMPTY_METADATA
        else:
            self.metadata = types.MappingProxyType(metadata)

    def __repr__(self):
        parts = (
            f'name={self.name!r}',
            f'type={self.type}',
            f'default={self.default}',
            f'default_factory={self.default_factory}',
            f'init={self.init}',
            f'repr={self.repr}',
            f'hash={self.hash}',
            f'compare={self.compare}',
            f'metadata={self.metadata}',
        )
        return 'Field(' + ','.join(parts) + ')'
114
115
# Field objects are created through this function rather than by
# instantiating Field directly, so that a type checker can be told
# (via overloads) that the result type depends on the arguments.
def field(*, default=MISSING, default_factory=MISSING, init=True, repr=True,
          hash=None, compare=True, metadata=None):
    """Return an object to identify dataclass fields.

    default is the field's default value.  default_factory is a
    0-argument callable invoked to produce the field's initial value.
    init, repr, hash and compare control whether the field takes part
    in the generated __init__(), repr(), hash() and comparison
    functions respectively.  metadata, if given, must be a mapping; it
    is stored but not otherwise examined by dataclass.

    Raises ValueError if both default and default_factory are given.
    """

    has_default = default is not MISSING
    has_factory = default_factory is not MISSING
    if has_default and has_factory:
        raise ValueError('cannot specify both default and default_factory')
    return Field(default, default_factory, init, repr, hash, compare,
                 metadata)
139
140
141def _tuple_str(obj_name, fields):
142 # Return a string representing each field of obj_name as a tuple
143 # member. So, if fields is ['x', 'y'] and obj_name is "self",
144 # return "(self.x,self.y)".
145
146 # Special case for the 0-tuple.
147 if len(fields) == 0:
148 return '()'
149 # Note the trailing comma, needed if this turns out to be a 1-tuple.
150 return f'({",".join([f"{obj_name}.{f.name}" for f in fields])},)'
151
152
def _create_fn(name, args, body, globals=None, locals=None,
               return_type=MISSING):
    # Assemble the source of "def name(args): body", exec() it and
    # return the resulting function object.
    # Note that the locals dict is mutated, both here and by exec().
    # Caller beware!
    if locals is None:
        locals = {}

    if return_type is MISSING:
        return_annotation = ''
    else:
        # The annotation is referenced by name from the generated code.
        locals['_return_type'] = return_type
        return_annotation = '->_return_type'

    arg_list = ','.join(args)
    indented_body = '\n'.join(f' {b}' for b in body)
    source = f'def {name}({arg_list}){return_annotation}:\n{indented_body}'

    exec(source, globals, locals)
    return locals[name]
169
170
171def _field_assign(frozen, name, value, self_name):
172 # If we're a frozen class, then assign to our fields in __init__
173 # via object.__setattr__. Otherwise, just use a simple
174 # assignment.
175 # self_name is what "self" is called in this function: don't
176 # hard-code "self", since that might be a field name.
177 if frozen:
178 return f'object.__setattr__({self_name},{name!r},{value})'
179 return f'{self_name}.{name}={value}'
180
181
def _field_init(f, frozen, globals, self_name):
    # Return the text of the line in the body of __init__ that will
    # initialize this field, or None if no such line is needed.
    #
    # As a side effect, the field's default value or default factory
    # (if any) is stored in 'globals' under the name '_dflt_<field>',
    # which is how the generated code refers to it.

    default_name = f'_dflt_{f.name}'
    if f.default_factory is not MISSING:
        globals[default_name] = f.default_factory
        if f.init:
            # This field has a default factory.  If a parameter is
            # given, use it.  If not, call the factory.
            value = (f'{default_name}() '
                     f'if {f.name} is _HAS_DEFAULT_FACTORY '
                     f'else {f.name}')
        else:
            # Not an __init__ parameter, but it has a default factory.
            # A plain default can live on the class, but a factory has
            # to be called for each instance, so the only place to
            # initialize the field is right here in __init__.
            value = f'{default_name}()'
    elif f.init:
        # No default factory: the value is always the parameter.  A
        # plain default, if any, is only registered so the parameter
        # list can refer to it.
        if f.default is not MISSING:
            globals[default_name] = f.default
        value = f.name
    else:
        # This field does not need initialization.  Signify that to
        # the caller by returning None.
        return None

    # Only test this now, so that the default variables above are
    # still created for InitVars.  InitVars get no assignment
    # statement of their own.
    if f._field_type is _FIELD_INITVAR:
        return None

    # Now, actually generate the field assignment.
    return _field_assign(frozen, f.name, value, self_name)
234
235
def _init_param(f):
    # Return the __init__ parameter text for this field: the
    # equivalent of 'x:int=3', except that the type and the default are
    # spelled as references to the variables '_type_x' and '_dflt_x'.
    if f.default is not MISSING:
        # A plain default: refer to the variable that holds it.
        default = f'=_dflt_{f.name}'
    elif f.default_factory is not MISSING:
        # A factory: use the marker meaning "call the factory".
        default = '=_HAS_DEFAULT_FACTORY'
    else:
        # Neither: the parameter is just name and type.
        default = ''
    return f'{f.name}:_type_{f.name}{default}'
252
253
def _init_fn(fields, frozen, has_post_init, self_name):
    # Build and return the __init__ function.
    # fields contains both real fields and InitVar pseudo-fields.

    # Reject a parameter without a default that follows one with a
    # default.  exec-ing the generated source would catch this too, but
    # checking here gives a better error message, and future-proofs us
    # in case the function is ever built up using ast.
    seen_default = False
    for f in fields:
        # Only fields that are __init__ parameters matter here.
        if not f.init:
            continue
        if f.default is not MISSING or f.default_factory is not MISSING:
            seen_default = True
        elif seen_default:
            raise TypeError(f'non-default argument {f.name!r} '
                            'follows default argument')

    globals = {'MISSING': MISSING,
               '_HAS_DEFAULT_FACTORY': _HAS_DEFAULT_FACTORY}

    # One body line per field that needs one; _field_init() returns
    # None for those that don't (it also registers defaults in globals).
    candidate_lines = (_field_init(f, frozen, globals, self_name)
                       for f in fields)
    body_lines = [line for line in candidate_lines if line is not None]

    # Does this class have a post-init function?  If so, call it last,
    # passing the InitVar pseudo-fields.
    if has_post_init:
        params_str = ','.join(f.name for f in fields
                              if f._field_type is _FIELD_INITVAR)
        body_lines.append(f'{self_name}.{_POST_INIT_NAME}({params_str})')

    # If no body lines, use 'pass'.
    if len(body_lines) == 0:
        body_lines = ['pass']

    locals = {f'_type_{f.name}': f.type for f in fields}
    init_params = [self_name] + [_init_param(f) for f in fields if f.init]
    return _create_fn('__init__',
                      init_params,
                      body_lines,
                      locals=locals,
                      globals=globals,
                      return_type=None)
301
302
def _repr_fn(fields):
    # Build __repr__: the class's qualified name followed by a
    # parenthesized "name=value" list, each value formatted with !r.
    shown = ', '.join(f'{f.name}={{self.{f.name}!r}}' for f in fields)
    body_line = 'return self.__class__.__qualname__ + f"(' + shown + ')"'
    return _create_fn('__repr__', ['self'], [body_line])
310
311
def _frozen_setattr(self, name, value):
    # Installed as __setattr__ on frozen classes: every attempt to
    # assign is refused.
    message = f'cannot assign to field {name!r}'
    raise FrozenInstanceError(message)
314
315
def _frozen_delattr(self, name):
    # Installed as __delattr__ on frozen classes: every attempt to
    # delete is refused.
    message = f'cannot delete field {name!r}'
    raise FrozenInstanceError(message)
318
319
def _cmp_fn(name, op, self_tuple, other_tuple):
    # Create a comparison function.  If the fields in the object are
    # named 'x' and 'y', then self_tuple is the string
    # '(self.x,self.y)' and other_tuple is the string
    # '(other.x,other.y)'.
    # The generated function compares the two tuples with 'op' when
    # both objects have exactly the same class, and returns
    # NotImplemented otherwise.
    body = [
        'if other.__class__ is self.__class__:',
        f' return {self_tuple}{op}{other_tuple}',
        'return NotImplemented',
    ]
    return _create_fn(name, ['self', 'other'], body)
331
332
def _set_eq_fns(cls, fields):
    # Create the equality methods and set them on cls.
    # The two tuple strings are computed once and shared by both
    # functions.
    lhs = _tuple_str('self', fields)
    rhs = _tuple_str('other', fields)
    operators = (
        ('__eq__', '=='),
        ('__ne__', '!='),
    )
    for method_name, symbol in operators:
        method = _cmp_fn(method_name, symbol, lhs, rhs)
        _set_attribute(cls, method_name, method)
343
344
def _set_order_fns(cls, fields):
    # Create the ordering methods and set them on cls.
    # The two tuple strings are computed once and shared by all four
    # functions.
    lhs = _tuple_str('self', fields)
    rhs = _tuple_str('other', fields)
    operators = (
        ('__lt__', '<'),
        ('__le__', '<='),
        ('__gt__', '>'),
        ('__ge__', '>='),
    )
    for method_name, symbol in operators:
        method = _cmp_fn(method_name, symbol, lhs, rhs)
        _set_attribute(cls, method_name, method)
357
358
def _hash_fn(fields):
    # Build __hash__: the hash of the tuple of the given fields.
    hashed_tuple = _tuple_str('self', fields)
    body_line = f'return hash({hashed_tuple})'
    return _create_fn('__hash__', ['self'], [body_line])
364
365
def _is_classvar(a_type, typing):
    # Return True if annotation a_type is a ClassVar, given the
    # already-imported typing module.
    legacy_type = getattr(typing, '_ClassVar', None)
    if legacy_type is not None:
        # typing implementations that expose the internal _ClassVar
        # class: this is the test originally used here.
        return type(a_type) is legacy_type

    # Otherwise fall back to the public name: either the bare ClassVar
    # or a subscripted form whose __origin__ is ClassVar.
    classvar = getattr(typing, 'ClassVar', None)
    if classvar is None:
        return False
    return (a_type is classvar
            or getattr(a_type, '__origin__', None) is classvar)


def _get_field(cls, a_name, a_type):
    # Return a Field object, for this field name and type.  ClassVars
    # and InitVars are also returned, but marked as such (see
    # f._field_type).
    #
    # Raises TypeError if a ClassVar or InitVar has a default factory,
    # and ValueError if a real field has a list, dict or set default.

    # If the class attribute isn't already a Field, then it's only a
    # normal default value (or absent).  Convert it to a Field().
    default = getattr(cls, a_name, MISSING)
    if isinstance(default, Field):
        f = default
    else:
        f = field(default=default)

    # Assume it's a normal field until proven otherwise.
    f._field_type = _FIELD

    # Only at this point do we know the name and the type.  Set them.
    f.name = a_name
    f.type = a_type

    # If typing has not been imported, then it's impossible for
    # any annotation to be a ClassVar.  So, only look for ClassVar
    # if typing has been imported.
    typing = sys.modules.get('typing')
    if typing is not None and _is_classvar(a_type, typing):
        # This field is a ClassVar, so it's not a field.
        f._field_type = _FIELD_CLASSVAR

    # Check if this is an InitVar.  InitVars are not fields, either.
    if f._field_type is _FIELD and a_type is InitVar:
        f._field_type = _FIELD_INITVAR

    # Validations for fields.  This is delayed until now, instead of
    # in the Field() constructor, since only here do we know the field
    # name, which allows better error reporting.

    # Special restrictions for ClassVar and InitVar.
    if f._field_type in (_FIELD_CLASSVAR, _FIELD_INITVAR):
        if f.default_factory is not MISSING:
            raise TypeError(f'field {f.name} cannot have a '
                            'default factory')
        # Should I check for other field settings?  default_factory
        # seems the most serious to check for.  Maybe add others.  For
        # example, how about init=False (or really,
        # init=<not-the-default-init-value>)?  It makes no sense for
        # ClassVar and InitVar to specify init=<anything>.

    # For real fields, disallow mutable defaults for known types.
    if f._field_type is _FIELD and isinstance(f.default, (list, dict, set)):
        raise ValueError(f'mutable default {type(f.default)} for field '
                         f'{f.name} is not allowed: use default_factory')

    return f
424
425
426def _find_fields(cls):
427 # Return a list of Field objects, in order, for this class (and no
428 # base classes). Fields are found from __annotations__ (which is
429 # guaranteed to be ordered). Default values are from class
430 # attributes, if a field has a default. If the default value is
431 # a Field(), then it contains additional info beyond (and
432 # possibly including) the actual default value. Pseudo-fields
433 # ClassVars and InitVars are included, despite the fact that
434 # they're not real fields. That's deal with later.
435
436 annotations = getattr(cls, '__annotations__', {})
437
438 return [_get_field(cls, a_name, a_type)
439 for a_name, a_type in annotations.items()]
440
441
442def _set_attribute(cls, name, value):
443 # Raise TypeError if an attribute by this name already exists.
444 if name in cls.__dict__:
445 raise TypeError(f'Cannot overwrite attribute {name} '
446 f'in {cls.__name__}')
447 setattr(cls, name, value)
448
449
def _process_class(cls, repr, eq, order, hash, init, frozen):
    # Turn cls into a dataclass in place and return it: collect its
    # fields (including those of dataclass base classes), store them
    # on the class under _MARKER, and add the requested dunder methods.
    #
    # Raises ValueError if order is true but eq is not, and TypeError
    # (from _set_attribute) if a method to be added is already defined
    # by the class itself.

    # Use an OrderedDict because:
    #  - Order matters!
    #  - Derived class fields overwrite base class fields, but the
    #    order is defined by the base class, which is found first.
    fields = collections.OrderedDict()

    # Find our base classes in reverse MRO order, and exclude
    # ourselves.  In reversed order so that more derived classes
    # override earlier field definitions in base classes.
    for b in cls.__mro__[-1:0:-1]:
        # Only process classes that have been processed by our
        # decorator.  That is, they have a _MARKER attribute.
        base_fields = getattr(b, _MARKER, None)
        if base_fields:
            for f in base_fields.values():
                fields[f.name] = f

    # Now find fields in our class.  While doing so, validate some
    # things, and set the default values (as class attributes)
    # where we can.
    for f in _find_fields(cls):
        fields[f.name] = f

        # If the class attribute (which is the default value for
        # this field) exists and is of type 'Field', replace it
        # with the real default.  This is so that normal class
        # introspection sees a real default value, not a Field.
        if isinstance(getattr(cls, f.name, None), Field):
            if f.default is MISSING:
                # If there's no default, delete the class attribute.
                # This happens if we specify field(repr=False), for
                # example (that is, we specified a field object, but
                # no default value).  Also if we're using a default
                # factory.  The class attribute should not be set at
                # all in the post-processed class.
                delattr(cls, f.name)
            else:
                setattr(cls, f.name, f.default)

    # Remember all of the fields on our class (including bases).  This
    # marks this class as being a dataclass.
    setattr(cls, _MARKER, fields)

    # We also need to check if a parent class is frozen: frozen has to
    # be inherited down.
    is_frozen = frozen or cls.__setattr__ is _frozen_setattr

    # If we're generating ordering methods, we must be generating
    # the eq methods.
    if order and not eq:
        raise ValueError('eq must be true if order is true')

    if init:
        # Does this class have a post-init function?
        has_post_init = hasattr(cls, _POST_INIT_NAME)

        # Include InitVars and regular fields (so, not ClassVars).
        init_fields = [f for f in fields.values()
                       if f._field_type in (_FIELD, _FIELD_INITVAR)]

        # The name to use for the "self" param in __init__.  Use
        # "self" if possible, i.e. unless a field is called that.
        self_name = '__dataclass_self__' if 'self' in fields else 'self'

        _set_attribute(cls, '__init__',
                       _init_fn(init_fields,
                                is_frozen,
                                has_post_init,
                                self_name,
                                ))

    # Get the fields as a list, and include only real fields.  This is
    # used in all of the following methods.
    field_list = [f for f in fields.values() if f._field_type is _FIELD]

    if repr:
        _set_attribute(cls, '__repr__',
                       _repr_fn([f for f in field_list if f.repr]))

    if is_frozen:
        _set_attribute(cls, '__setattr__', _frozen_setattr)
        _set_attribute(cls, '__delattr__', _frozen_delattr)

    # Decide what to do about __hash__.  An explicit hash= value is
    # taken as-is; with hash=None the decision follows eq and frozen.
    # NOTE(review): this looks at the 'frozen' argument, not at
    # is_frozen, so frozen-ness inherited from a base class does not
    # count here -- confirm that this is intended.
    if hash is not None:
        generate_hash = hash
    elif not eq:
        # Use the base class definition of hash().  That is, don't
        # set anything on this class.
        generate_hash = False
    elif frozen:
        # eq and frozen: generate a hash function.
        generate_hash = True
    else:
        # eq but not frozen: not hashable.
        generate_hash = False
        _set_attribute(cls, '__hash__', None)

    if generate_hash:
        # A field's own hash setting wins; when it is None the field
        # is hashed exactly when it takes part in comparisons.
        hash_fields = [f for f in field_list
                       if (f.compare if f.hash is None else f.hash)]
        _set_attribute(cls, '__hash__', _hash_fn(hash_fields))

    if eq:
        # Create and set the __eq__ and __ne__ methods.
        _set_eq_fns(cls, [f for f in field_list if f.compare])

    if order:
        # Create and set the __lt__, __le__, __gt__, and __ge__ methods.
        _set_order_fns(cls, [f for f in field_list if f.compare])

    if not getattr(cls, '__doc__'):
        # Create a class doc-string: the class name followed by the
        # constructor signature (without its return annotation).
        cls.__doc__ = (cls.__name__ +
                       str(inspect.signature(cls)).replace(' -> None', ''))

    return cls
571
572
# _cls should never be specified by keyword, so start it with an
# underscore.  The presence of _cls is used to detect if this
# decorator is being called with parameters or not.
def dataclass(_cls=None, *, init=True, repr=True, eq=True, order=False,
              hash=None, frozen=False):
    """Returns the same class as was passed in, with dunder methods
    added based on the fields defined in the class.

    Examines PEP 526 __annotations__ to determine fields.

    If init is true, an __init__() method is added to the class. If
    repr is true, a __repr__() method is added. If order is true, rich
    comparison dunder methods are added. If hash is true, a __hash__()
    method function is added. If frozen is true, fields may not be
    assigned to after instance creation.

    Usable both as @dataclass and as @dataclass(...): when no class is
    passed in, a decorator that will process the class is returned.
    """

    def wrap(cls):
        return _process_class(cls, repr, eq, order, hash, init, frozen)

    # With parens (@dataclass(...)) there is no class yet: hand back
    # the decorator.  Without parens (@dataclass) apply it right away.
    return wrap if _cls is None else wrap(_cls)
600
601
def fields(class_or_instance):
    """Return a tuple describing the fields of this dataclass.

    Accepts a dataclass or an instance of one. Tuple elements are of
    type Field.  Raises TypeError for anything else.
    """

    # Might it be worth caching this, per class?
    try:
        field_map = getattr(class_or_instance, _MARKER)
    except AttributeError:
        raise TypeError('must be called with a dataclass type or instance')

    # Exclude pseudo-fields (ClassVars and InitVars).
    real_fields = (f for f in field_map.values() if f._field_type is _FIELD)
    return tuple(real_fields)
617
618
def _is_dataclass_instance(obj):
    """Returns True if obj is an instance of a dataclass."""
    # A class is never an "instance" here, even a dataclass itself.
    if isinstance(obj, type):
        return False
    return hasattr(obj, _MARKER)
622
623
def is_dataclass(obj):
    """Returns True if obj is a dataclass or an instance of a
    dataclass."""
    # The marker attribute is visible on the class and its instances.
    has_marker = hasattr(obj, _MARKER)
    return has_marker
628
629
def asdict(obj, *, dict_factory=dict):
    """Return the fields of a dataclass instance as a new dictionary mapping
    field names to field values.

    Example usage:

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert asdict(c) == {'x': 1, 'y': 2}

    If given, 'dict_factory' will be used instead of built-in dict.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """
    if _is_dataclass_instance(obj):
        return _asdict_inner(obj, dict_factory)
    raise TypeError("asdict() should be called on dataclass instances")
652
def _asdict_inner(obj, dict_factory):
    # Recursive worker for asdict().  Dataclass instances become
    # dict_factory([(name, value), ...]); lists, tuples and dicts are
    # rebuilt with the same type around converted contents; anything
    # else is deep-copied.
    if _is_dataclass_instance(obj):
        result = [(f.name, _asdict_inner(getattr(obj, f.name), dict_factory))
                  for f in fields(obj)]
        return dict_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # A namedtuple: its constructor takes one positional argument
        # per field, not a single iterable, so the converted members
        # have to be unpacked.
        return type(obj)(*[_asdict_inner(v, dict_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        return type(obj)(_asdict_inner(v, dict_factory) for v in obj)
    elif isinstance(obj, dict):
        return type(obj)((_asdict_inner(k, dict_factory),
                          _asdict_inner(v, dict_factory))
                         for k, v in obj.items())
    else:
        return deepcopy(obj)
667
668
def astuple(obj, *, tuple_factory=tuple):
    """Return the fields of a dataclass instance as a new tuple of field values.

    Example usage::

      @dataclass
      class C:
          x: int
          y: int

      c = C(1, 2)
      assert astuple(c) == (1, 2)

    If given, 'tuple_factory' will be used instead of built-in tuple.
    The function applies recursively to field values that are
    dataclass instances. This will also look into built-in containers:
    tuples, lists, and dicts.

    Raises TypeError if obj is not a dataclass instance.
    """

    if _is_dataclass_instance(obj):
        return _astuple_inner(obj, tuple_factory)
    raise TypeError("astuple() should be called on dataclass instances")
691
def _astuple_inner(obj, tuple_factory):
    # Recursive worker for astuple().  Dataclass instances become
    # tuple_factory([value, ...]); lists, tuples and dicts are rebuilt
    # with the same type around converted contents; anything else is
    # deep-copied.
    if _is_dataclass_instance(obj):
        result = [_astuple_inner(getattr(obj, f.name), tuple_factory)
                  for f in fields(obj)]
        return tuple_factory(result)
    elif isinstance(obj, tuple) and hasattr(obj, '_fields'):
        # A namedtuple: its constructor takes one positional argument
        # per field, not a single iterable, so the converted members
        # have to be unpacked.
        return type(obj)(*[_astuple_inner(v, tuple_factory) for v in obj])
    elif isinstance(obj, (list, tuple)):
        return type(obj)(_astuple_inner(v, tuple_factory) for v in obj)
    elif isinstance(obj, dict):
        return type(obj)((_astuple_inner(k, tuple_factory),
                          _astuple_inner(v, tuple_factory))
                         for k, v in obj.items())
    else:
        return deepcopy(obj)
706
707
def make_dataclass(cls_name, fields, *, bases=(), namespace=None, init=True,
                   repr=True, eq=True, order=False, hash=None, frozen=False):
    """Return a new dynamically created dataclass.

    The dataclass name will be 'cls_name'.  'fields' is an iterable
    of either name, (name, type) or (name, type, Field) items.  If type
    is omitted, the string 'typing.Any' is used.  Field objects are
    created by calling field().

      C = make_dataclass('C',
                         ['x', ('a', int), ('b', int, field(init=False))],
                         bases=(Base,))

    is equivalent to:

      @dataclass
      class C(Base):
          x: 'typing.Any'
          a: int
          b: int = field(init=False)

    For the bases and namespace parameters, see the builtin type() function.

    The parameters init, repr, eq, order, hash, and frozen are passed to
    dataclass().

    Raises TypeError if an item of 'fields' is not a string and does
    not have exactly 2 or 3 elements.
    """

    # Copy namespace since we're going to mutate it.
    namespace = {} if namespace is None else namespace.copy()

    anns = collections.OrderedDict()
    for item in fields:
        if isinstance(item, str):
            name = item
            tp = 'typing.Any'
        elif len(item) == 2:
            name, tp = item
        elif len(item) == 3:
            # The Field (or plain default) becomes the class attribute.
            name, tp, spec = item
            namespace[name] = spec
        else:
            # Don't silently reuse the name/type of the previous item.
            raise TypeError(f'Invalid field: {item!r}')
        anns[name] = tp

    namespace['__annotations__'] = anns
    cls = type(cls_name, bases, namespace)
    return dataclass(cls, init=init, repr=repr, eq=eq, order=order,
                     hash=hash, frozen=frozen)
749
750
def replace(obj, **changes):
    """Return a new object replacing specified fields with new values.

    This is especially useful for frozen classes.  Example usage:

      @dataclass(frozen=True)
      class C:
          x: int
          y: int

      c = C(1, 2)
      c1 = replace(c, x=3)
      assert c1.x == 3 and c1.y == 2

    Raises TypeError if obj is not a dataclass instance, and
    ValueError if 'changes' names a field declared with init=False.
    """

    # We're going to mutate 'changes', but that's okay because it's a new
    # dict, even if called with 'replace(obj, **my_changes)'.

    if not _is_dataclass_instance(obj):
        raise TypeError("replace() should be called on dataclass instances")

    # It's an error to have init=False fields in 'changes'.
    # If a field is not in 'changes', read its value from the provided obj.

    for f in getattr(obj, _MARKER).values():
        # ClassVar pseudo-fields are stored alongside the real fields
        # but are not __init__ parameters, so they must not be passed
        # along to the constructor.
        if f._field_type is _FIELD_CLASSVAR:
            continue

        if not f.init:
            # Error if this field is specified in changes.
            if f.name in changes:
                raise ValueError(f'field {f.name} is declared with '
                                 'init=False, it cannot be specified with '
                                 'replace()')
            continue

        if f.name not in changes:
            changes[f.name] = getattr(obj, f.name)

    # Create the new object, which calls __init__() and __post_init__
    # (if defined), using all of the init fields we've added and/or
    # left in 'changes'.
    # If there are values supplied in changes that aren't fields, this
    # will correctly raise a TypeError.
    return obj.__class__(**changes)
792 return obj.__class__(**changes)