blob: c44c53d039d5bc157a1102bb38a6a9ace8bf20b4 [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
Eric V. Smithe7ba0132018-01-06 12:41:53 -05003 make_dataclass, replace, InitVar, Field, MISSING, is_dataclass,
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
12
# A dedicated exception type that tests raise from user hooks (such as
# __post_init__) so they can catch exactly that failure and nothing else.
class CustomError(Exception):
    pass
15
16class TestCase(unittest.TestCase):
17 def test_no_fields(self):
18 @dataclass
19 class C:
20 pass
21
22 o = C()
23 self.assertEqual(len(fields(C)), 0)
24
25 def test_one_field_no_default(self):
26 @dataclass
27 class C:
28 x: int
29
30 o = C(42)
31 self.assertEqual(o.x, 42)
32
33 def test_named_init_params(self):
34 @dataclass
35 class C:
36 x: int
37
38 o = C(x=32)
39 self.assertEqual(o.x, 32)
40
41 def test_two_fields_one_default(self):
42 @dataclass
43 class C:
44 x: int
45 y: int = 0
46
47 o = C(3)
48 self.assertEqual((o.x, o.y), (3, 0))
49
50 # Non-defaults following defaults.
51 with self.assertRaisesRegex(TypeError,
52 "non-default argument 'y' follows "
53 "default argument"):
54 @dataclass
55 class C:
56 x: int = 0
57 y: int
58
59 # A derived class adds a non-default field after a default one.
60 with self.assertRaisesRegex(TypeError,
61 "non-default argument 'y' follows "
62 "default argument"):
63 @dataclass
64 class B:
65 x: int = 0
66
67 @dataclass
68 class C(B):
69 y: int
70
71 # Override a base class field and add a default to
72 # a field which didn't use to have a default.
73 with self.assertRaisesRegex(TypeError,
74 "non-default argument 'y' follows "
75 "default argument"):
76 @dataclass
77 class B:
78 x: int
79 y: int
80
81 @dataclass
82 class C(B):
83 x: int = 0
84
85 def test_overwriting_init(self):
86 with self.assertRaisesRegex(TypeError,
87 'Cannot overwrite attribute __init__ '
88 'in C'):
89 @dataclass
90 class C:
91 x: int
92 def __init__(self, x):
93 self.x = 2 * x
94
95 @dataclass(init=False)
96 class C:
97 x: int
98 def __init__(self, x):
99 self.x = 2 * x
100 self.assertEqual(C(5).x, 10)
101
102 def test_overwriting_repr(self):
103 with self.assertRaisesRegex(TypeError,
104 'Cannot overwrite attribute __repr__ '
105 'in C'):
106 @dataclass
107 class C:
108 x: int
109 def __repr__(self):
110 pass
111
112 @dataclass(repr=False)
113 class C:
114 x: int
115 def __repr__(self):
116 return 'x'
117 self.assertEqual(repr(C(0)), 'x')
118
119 def test_overwriting_cmp(self):
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __eq__ '
122 'in C'):
123 # This will generate the comparison functions, make sure we can't
124 # overwrite them.
125 @dataclass(hash=False, frozen=False)
126 class C:
127 x: int
128 def __eq__(self):
129 pass
130
131 @dataclass(order=False, eq=False)
132 class C:
133 x: int
134 def __eq__(self, other):
135 return True
136 self.assertEqual(C(0), 'x')
137
138 def test_overwriting_hash(self):
139 with self.assertRaisesRegex(TypeError,
140 'Cannot overwrite attribute __hash__ '
141 'in C'):
142 @dataclass(frozen=True)
143 class C:
144 x: int
145 def __hash__(self):
146 pass
147
148 @dataclass(frozen=True,hash=False)
149 class C:
150 x: int
151 def __hash__(self):
152 return 600
153 self.assertEqual(hash(C(0)), 600)
154
155 with self.assertRaisesRegex(TypeError,
156 'Cannot overwrite attribute __hash__ '
157 'in C'):
158 @dataclass(frozen=True)
159 class C:
160 x: int
161 def __hash__(self):
162 pass
163
164 @dataclass(frozen=True, hash=False)
165 class C:
166 x: int
167 def __hash__(self):
168 return 600
169 self.assertEqual(hash(C(0)), 600)
170
171 def test_overwriting_frozen(self):
172 # frozen uses __setattr__ and __delattr__
173 with self.assertRaisesRegex(TypeError,
174 'Cannot overwrite attribute __setattr__ '
175 'in C'):
176 @dataclass(frozen=True)
177 class C:
178 x: int
179 def __setattr__(self):
180 pass
181
182 with self.assertRaisesRegex(TypeError,
183 'Cannot overwrite attribute __delattr__ '
184 'in C'):
185 @dataclass(frozen=True)
186 class C:
187 x: int
188 def __delattr__(self):
189 pass
190
191 @dataclass(frozen=False)
192 class C:
193 x: int
194 def __setattr__(self, name, value):
195 self.__dict__['x'] = value * 2
196 self.assertEqual(C(10).x, 20)
197
198 def test_overwrite_fields_in_derived_class(self):
199 # Note that x from C1 replaces x in Base, but the order remains
200 # the same as defined in Base.
201 @dataclass
202 class Base:
203 x: Any = 15.0
204 y: int = 0
205
206 @dataclass
207 class C1(Base):
208 z: int = 10
209 x: int = 15
210
211 o = Base()
212 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
213
214 o = C1()
215 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
216
217 o = C1(x=5)
218 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
219
220 def test_field_named_self(self):
221 @dataclass
222 class C:
223 self: str
224 c=C('foo')
225 self.assertEqual(c.self, 'foo')
226
227 # Make sure the first parameter is not named 'self'.
228 sig = inspect.signature(C.__init__)
229 first = next(iter(sig.parameters))
230 self.assertNotEqual('self', first)
231
232 # But we do use 'self' if no field named self.
233 @dataclass
234 class C:
235 selfx: str
236
237 # Make sure the first parameter is named 'self'.
238 sig = inspect.signature(C.__init__)
239 first = next(iter(sig.parameters))
240 self.assertEqual('self', first)
241
242 def test_repr(self):
243 @dataclass
244 class B:
245 x: int
246
247 @dataclass
248 class C(B):
249 y: int = 10
250
251 o = C(4)
252 self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
253
254 @dataclass
255 class D(C):
256 x: int = 20
257 self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
258
259 @dataclass
260 class C:
261 @dataclass
262 class D:
263 i: int
264 @dataclass
265 class E:
266 pass
267 self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
268 self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
269
270 def test_0_field_compare(self):
271 # Ensure that order=False is the default.
272 @dataclass
273 class C0:
274 pass
275
276 @dataclass(order=False)
277 class C1:
278 pass
279
280 for cls in [C0, C1]:
281 with self.subTest(cls=cls):
282 self.assertEqual(cls(), cls())
283 for idx, fn in enumerate([lambda a, b: a < b,
284 lambda a, b: a <= b,
285 lambda a, b: a > b,
286 lambda a, b: a >= b]):
287 with self.subTest(idx=idx):
288 with self.assertRaisesRegex(TypeError,
289 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
290 fn(cls(), cls())
291
292 @dataclass(order=True)
293 class C:
294 pass
295 self.assertLessEqual(C(), C())
296 self.assertGreaterEqual(C(), C())
297
298 def test_1_field_compare(self):
299 # Ensure that order=False is the default.
300 @dataclass
301 class C0:
302 x: int
303
304 @dataclass(order=False)
305 class C1:
306 x: int
307
308 for cls in [C0, C1]:
309 with self.subTest(cls=cls):
310 self.assertEqual(cls(1), cls(1))
311 self.assertNotEqual(cls(0), cls(1))
312 for idx, fn in enumerate([lambda a, b: a < b,
313 lambda a, b: a <= b,
314 lambda a, b: a > b,
315 lambda a, b: a >= b]):
316 with self.subTest(idx=idx):
317 with self.assertRaisesRegex(TypeError,
318 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
319 fn(cls(0), cls(0))
320
321 @dataclass(order=True)
322 class C:
323 x: int
324 self.assertLess(C(0), C(1))
325 self.assertLessEqual(C(0), C(1))
326 self.assertLessEqual(C(1), C(1))
327 self.assertGreater(C(1), C(0))
328 self.assertGreaterEqual(C(1), C(0))
329 self.assertGreaterEqual(C(1), C(1))
330
331 def test_simple_compare(self):
332 # Ensure that order=False is the default.
333 @dataclass
334 class C0:
335 x: int
336 y: int
337
338 @dataclass(order=False)
339 class C1:
340 x: int
341 y: int
342
343 for cls in [C0, C1]:
344 with self.subTest(cls=cls):
345 self.assertEqual(cls(0, 0), cls(0, 0))
346 self.assertEqual(cls(1, 2), cls(1, 2))
347 self.assertNotEqual(cls(1, 0), cls(0, 0))
348 self.assertNotEqual(cls(1, 0), cls(1, 1))
349 for idx, fn in enumerate([lambda a, b: a < b,
350 lambda a, b: a <= b,
351 lambda a, b: a > b,
352 lambda a, b: a >= b]):
353 with self.subTest(idx=idx):
354 with self.assertRaisesRegex(TypeError,
355 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
356 fn(cls(0, 0), cls(0, 0))
357
358 @dataclass(order=True)
359 class C:
360 x: int
361 y: int
362
363 for idx, fn in enumerate([lambda a, b: a == b,
364 lambda a, b: a <= b,
365 lambda a, b: a >= b]):
366 with self.subTest(idx=idx):
367 self.assertTrue(fn(C(0, 0), C(0, 0)))
368
369 for idx, fn in enumerate([lambda a, b: a < b,
370 lambda a, b: a <= b,
371 lambda a, b: a != b]):
372 with self.subTest(idx=idx):
373 self.assertTrue(fn(C(0, 0), C(0, 1)))
374 self.assertTrue(fn(C(0, 1), C(1, 0)))
375 self.assertTrue(fn(C(1, 0), C(1, 1)))
376
377 for idx, fn in enumerate([lambda a, b: a > b,
378 lambda a, b: a >= b,
379 lambda a, b: a != b]):
380 with self.subTest(idx=idx):
381 self.assertTrue(fn(C(0, 1), C(0, 0)))
382 self.assertTrue(fn(C(1, 0), C(0, 1)))
383 self.assertTrue(fn(C(1, 1), C(1, 0)))
384
385 def test_compare_subclasses(self):
386 # Comparisons fail for subclasses, even if no fields
387 # are added.
388 @dataclass
389 class B:
390 i: int
391
392 @dataclass
393 class C(B):
394 pass
395
396 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
397 (lambda a, b: a != b, True)]):
398 with self.subTest(idx=idx):
399 self.assertEqual(fn(B(0), C(0)), expected)
400
401 for idx, fn in enumerate([lambda a, b: a < b,
402 lambda a, b: a <= b,
403 lambda a, b: a > b,
404 lambda a, b: a >= b]):
405 with self.subTest(idx=idx):
406 with self.assertRaisesRegex(TypeError,
407 "not supported between instances of 'B' and 'C'"):
408 fn(B(0), C(0))
409
410 def test_0_field_hash(self):
411 @dataclass(hash=True)
412 class C:
413 pass
414 self.assertEqual(hash(C()), hash(()))
415
416 def test_1_field_hash(self):
417 @dataclass(hash=True)
418 class C:
419 x: int
420 self.assertEqual(hash(C(4)), hash((4,)))
421 self.assertEqual(hash(C(42)), hash((42,)))
422
423 def test_hash(self):
424 @dataclass(hash=True)
425 class C:
426 x: int
427 y: str
428 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
429
430 def test_no_hash(self):
431 @dataclass(hash=None)
432 class C:
433 x: int
434 with self.assertRaisesRegex(TypeError,
435 "unhashable type: 'C'"):
436 hash(C(1))
437
438 def test_hash_rules(self):
439 # There are 24 cases of:
440 # hash=True/False/None
441 # eq=True/False
442 # order=True/False
443 # frozen=True/False
444 for (hash, eq, order, frozen, result ) in [
445 (False, False, False, False, 'absent'),
446 (False, False, False, True, 'absent'),
447 (False, False, True, False, 'exception'),
448 (False, False, True, True, 'exception'),
449 (False, True, False, False, 'absent'),
450 (False, True, False, True, 'absent'),
451 (False, True, True, False, 'absent'),
452 (False, True, True, True, 'absent'),
453 (True, False, False, False, 'fn'),
454 (True, False, False, True, 'fn'),
455 (True, False, True, False, 'exception'),
456 (True, False, True, True, 'exception'),
457 (True, True, False, False, 'fn'),
458 (True, True, False, True, 'fn'),
459 (True, True, True, False, 'fn'),
460 (True, True, True, True, 'fn'),
461 (None, False, False, False, 'absent'),
462 (None, False, False, True, 'absent'),
463 (None, False, True, False, 'exception'),
464 (None, False, True, True, 'exception'),
465 (None, True, False, False, 'none'),
466 (None, True, False, True, 'fn'),
467 (None, True, True, False, 'none'),
468 (None, True, True, True, 'fn'),
469 ]:
470 with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
471 if result == 'exception':
472 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
473 @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
474 class C:
475 pass
476 else:
477 @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
478 class C:
479 pass
480
481 # See if the result matches what's expected.
482 if result == 'fn':
483 # __hash__ contains the function we generated.
484 self.assertIn('__hash__', C.__dict__)
485 self.assertIsNotNone(C.__dict__['__hash__'])
486 elif result == 'absent':
487 # __hash__ is not present in our class.
488 self.assertNotIn('__hash__', C.__dict__)
489 elif result == 'none':
490 # __hash__ is set to None.
491 self.assertIn('__hash__', C.__dict__)
492 self.assertIsNone(C.__dict__['__hash__'])
493 else:
494 assert False, f'unknown result {result!r}'
495
496 def test_eq_order(self):
497 for (eq, order, result ) in [
498 (False, False, 'neither'),
499 (False, True, 'exception'),
500 (True, False, 'eq_only'),
501 (True, True, 'both'),
502 ]:
503 with self.subTest(eq=eq, order=order):
504 if result == 'exception':
505 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
506 @dataclass(eq=eq, order=order)
507 class C:
508 pass
509 else:
510 @dataclass(eq=eq, order=order)
511 class C:
512 pass
513
514 if result == 'neither':
515 self.assertNotIn('__eq__', C.__dict__)
516 self.assertNotIn('__ne__', C.__dict__)
517 self.assertNotIn('__lt__', C.__dict__)
518 self.assertNotIn('__le__', C.__dict__)
519 self.assertNotIn('__gt__', C.__dict__)
520 self.assertNotIn('__ge__', C.__dict__)
521 elif result == 'both':
522 self.assertIn('__eq__', C.__dict__)
523 self.assertIn('__ne__', C.__dict__)
524 self.assertIn('__lt__', C.__dict__)
525 self.assertIn('__le__', C.__dict__)
526 self.assertIn('__gt__', C.__dict__)
527 self.assertIn('__ge__', C.__dict__)
528 elif result == 'eq_only':
529 self.assertIn('__eq__', C.__dict__)
530 self.assertIn('__ne__', C.__dict__)
531 self.assertNotIn('__lt__', C.__dict__)
532 self.assertNotIn('__le__', C.__dict__)
533 self.assertNotIn('__gt__', C.__dict__)
534 self.assertNotIn('__ge__', C.__dict__)
535 else:
536 assert False, f'unknown result {result!r}'
537
538 def test_field_no_default(self):
539 @dataclass
540 class C:
541 x: int = field()
542
543 self.assertEqual(C(5).x, 5)
544
545 with self.assertRaisesRegex(TypeError,
546 r"__init__\(\) missing 1 required "
547 "positional argument: 'x'"):
548 C()
549
550 def test_field_default(self):
551 default = object()
552 @dataclass
553 class C:
554 x: object = field(default=default)
555
556 self.assertIs(C.x, default)
557 c = C(10)
558 self.assertEqual(c.x, 10)
559
560 # If we delete the instance attribute, we should then see the
561 # class attribute.
562 del c.x
563 self.assertIs(c.x, default)
564
565 self.assertIs(C().x, default)
566
567 def test_not_in_repr(self):
568 @dataclass
569 class C:
570 x: int = field(repr=False)
571 with self.assertRaises(TypeError):
572 C()
573 c = C(10)
574 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
575
576 @dataclass
577 class C:
578 x: int = field(repr=False)
579 y: int
580 c = C(10, 20)
581 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
582
583 def test_not_in_compare(self):
584 @dataclass
585 class C:
586 x: int = 0
587 y: int = field(compare=False, default=4)
588
589 self.assertEqual(C(), C(0, 20))
590 self.assertEqual(C(1, 10), C(1, 20))
591 self.assertNotEqual(C(3), C(4, 10))
592 self.assertNotEqual(C(3, 10), C(4, 10))
593
594 def test_hash_field_rules(self):
595 # Test all 6 cases of:
596 # hash=True/False/None
597 # compare=True/False
598 for (hash_val, compare, result ) in [
599 (True, False, 'field' ),
600 (True, True, 'field' ),
601 (False, False, 'absent'),
602 (False, True, 'absent'),
603 (None, False, 'absent'),
604 (None, True, 'field' ),
605 ]:
606 with self.subTest(hash_val=hash_val, compare=compare):
607 @dataclass(hash=True)
608 class C:
609 x: int = field(compare=compare, hash=hash_val, default=5)
610
611 if result == 'field':
612 # __hash__ contains the field.
613 self.assertEqual(C(5).__hash__(), hash((5,)))
614 elif result == 'absent':
615 # The field is not present in the hash.
616 self.assertEqual(C(5).__hash__(), hash(()))
617 else:
618 assert False, f'unknown result {result!r}'
619
620 def test_init_false_no_default(self):
621 # If init=False and no default value, then the field won't be
622 # present in the instance.
623 @dataclass
624 class C:
625 x: int = field(init=False)
626
627 self.assertNotIn('x', C().__dict__)
628
629 @dataclass
630 class C:
631 x: int
632 y: int = 0
633 z: int = field(init=False)
634 t: int = 10
635
636 self.assertNotIn('z', C(0).__dict__)
637 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
638
639 def test_class_marker(self):
640 @dataclass
641 class C:
642 x: int
643 y: str = field(init=False, default=None)
644 z: str = field(repr=False)
645
646 the_fields = fields(C)
647 # the_fields is a tuple of 3 items, each value
648 # is in __annotations__.
649 self.assertIsInstance(the_fields, tuple)
650 for f in the_fields:
651 self.assertIs(type(f), Field)
652 self.assertIn(f.name, C.__annotations__)
653
654 self.assertEqual(len(the_fields), 3)
655
656 self.assertEqual(the_fields[0].name, 'x')
657 self.assertEqual(the_fields[0].type, int)
658 self.assertFalse(hasattr(C, 'x'))
659 self.assertTrue (the_fields[0].init)
660 self.assertTrue (the_fields[0].repr)
661 self.assertEqual(the_fields[1].name, 'y')
662 self.assertEqual(the_fields[1].type, str)
663 self.assertIsNone(getattr(C, 'y'))
664 self.assertFalse(the_fields[1].init)
665 self.assertTrue (the_fields[1].repr)
666 self.assertEqual(the_fields[2].name, 'z')
667 self.assertEqual(the_fields[2].type, str)
668 self.assertFalse(hasattr(C, 'z'))
669 self.assertTrue (the_fields[2].init)
670 self.assertFalse(the_fields[2].repr)
671
672 def test_field_order(self):
673 @dataclass
674 class B:
675 a: str = 'B:a'
676 b: str = 'B:b'
677 c: str = 'B:c'
678
679 @dataclass
680 class C(B):
681 b: str = 'C:b'
682
683 self.assertEqual([(f.name, f.default) for f in fields(C)],
684 [('a', 'B:a'),
685 ('b', 'C:b'),
686 ('c', 'B:c')])
687
688 @dataclass
689 class D(B):
690 c: str = 'D:c'
691
692 self.assertEqual([(f.name, f.default) for f in fields(D)],
693 [('a', 'B:a'),
694 ('b', 'B:b'),
695 ('c', 'D:c')])
696
697 @dataclass
698 class E(D):
699 a: str = 'E:a'
700 d: str = 'E:d'
701
702 self.assertEqual([(f.name, f.default) for f in fields(E)],
703 [('a', 'E:a'),
704 ('b', 'B:b'),
705 ('c', 'D:c'),
706 ('d', 'E:d')])
707
708 def test_class_attrs(self):
709 # We only have a class attribute if a default value is
710 # specified, either directly or via a field with a default.
711 default = object()
712 @dataclass
713 class C:
714 x: int
715 y: int = field(repr=False)
716 z: object = default
717 t: int = field(default=100)
718
719 self.assertFalse(hasattr(C, 'x'))
720 self.assertFalse(hasattr(C, 'y'))
721 self.assertIs (C.z, default)
722 self.assertEqual(C.t, 100)
723
724 def test_disallowed_mutable_defaults(self):
725 # For the known types, don't allow mutable default values.
726 for typ, empty, non_empty in [(list, [], [1]),
727 (dict, {}, {0:1}),
728 (set, set(), set([1])),
729 ]:
730 with self.subTest(typ=typ):
731 # Can't use a zero-length value.
732 with self.assertRaisesRegex(ValueError,
733 f'mutable default {typ} for field '
734 'x is not allowed'):
735 @dataclass
736 class Point:
737 x: typ = empty
738
739
740 # Nor a non-zero-length value
741 with self.assertRaisesRegex(ValueError,
742 f'mutable default {typ} for field '
743 'y is not allowed'):
744 @dataclass
745 class Point:
746 y: typ = non_empty
747
748 # Check subtypes also fail.
749 class Subclass(typ): pass
750
751 with self.assertRaisesRegex(ValueError,
752 f"mutable default .*Subclass'>"
753 ' for field z is not allowed'
754 ):
755 @dataclass
756 class Point:
757 z: typ = Subclass()
758
759 # Because this is a ClassVar, it can be mutable.
760 @dataclass
761 class C:
762 z: ClassVar[typ] = typ()
763
764 # Because this is a ClassVar, it can be mutable.
765 @dataclass
766 class C:
767 x: ClassVar[typ] = Subclass()
768
769
770 def test_deliberately_mutable_defaults(self):
771 # If a mutable default isn't in the known list of
772 # (list, dict, set), then it's okay.
773 class Mutable:
774 def __init__(self):
775 self.l = []
776
777 @dataclass
778 class C:
779 x: Mutable
780
781 # These 2 instances will share this value of x.
782 lst = Mutable()
783 o1 = C(lst)
784 o2 = C(lst)
785 self.assertEqual(o1, o2)
786 o1.x.l.extend([1, 2])
787 self.assertEqual(o1, o2)
788 self.assertEqual(o1.x.l, [1, 2])
789 self.assertIs(o1.x, o2.x)
790
791 def test_no_options(self):
792 # call with dataclass()
793 @dataclass()
794 class C:
795 x: int
796
797 self.assertEqual(C(42).x, 42)
798
799 def test_not_tuple(self):
800 # Make sure we can't be compared to a tuple.
801 @dataclass
802 class Point:
803 x: int
804 y: int
805 self.assertNotEqual(Point(1, 2), (1, 2))
806
807 # And that we can't compare to another unrelated dataclass
808 @dataclass
809 class C:
810 x: int
811 y: int
812 self.assertNotEqual(Point(1, 3), C(1, 3))
813
814 def test_base_has_init(self):
815 class B:
816 def __init__(self):
817 pass
818
819 # Make sure that declaring this class doesn't raise an error.
820 # The issue is that we can't override __init__ in our class,
821 # but it should be okay to add __init__ to us if our base has
822 # an __init__.
823 @dataclass
824 class C(B):
825 x: int = 0
826
827 def test_frozen(self):
828 @dataclass(frozen=True)
829 class C:
830 i: int
831
832 c = C(10)
833 self.assertEqual(c.i, 10)
834 with self.assertRaises(FrozenInstanceError):
835 c.i = 5
836 self.assertEqual(c.i, 10)
837
838 # Check that a derived class is still frozen, even if not
839 # marked so.
840 @dataclass
841 class D(C):
842 pass
843
844 d = D(20)
845 self.assertEqual(d.i, 20)
846 with self.assertRaises(FrozenInstanceError):
847 d.i = 5
848 self.assertEqual(d.i, 20)
849
850 def test_not_tuple(self):
851 # Test that some of the problems with namedtuple don't happen
852 # here.
853 @dataclass
854 class Point3D:
855 x: int
856 y: int
857 z: int
858
859 @dataclass
860 class Date:
861 year: int
862 month: int
863 day: int
864
865 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
866 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
867
868 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200869 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500870 x, y, z = Point3D(4, 5, 6)
871
872 # Maka sure another class with the same field names isn't
873 # equal.
874 @dataclass
875 class Point3Dv1:
876 x: int = 0
877 y: int = 0
878 z: int = 0
879 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
880
881 def test_function_annotations(self):
882 # Some dummy class and instance to use as a default.
883 class F:
884 pass
885 f = F()
886
887 def validate_class(cls):
888 # First, check __annotations__, even though they're not
889 # function annotations.
890 self.assertEqual(cls.__annotations__['i'], int)
891 self.assertEqual(cls.__annotations__['j'], str)
892 self.assertEqual(cls.__annotations__['k'], F)
893 self.assertEqual(cls.__annotations__['l'], float)
894 self.assertEqual(cls.__annotations__['z'], complex)
895
896 # Verify __init__.
897
898 signature = inspect.signature(cls.__init__)
899 # Check the return type, should be None
900 self.assertIs(signature.return_annotation, None)
901
902 # Check each parameter.
903 params = iter(signature.parameters.values())
904 param = next(params)
905 # This is testing an internal name, and probably shouldn't be tested.
906 self.assertEqual(param.name, 'self')
907 param = next(params)
908 self.assertEqual(param.name, 'i')
909 self.assertIs (param.annotation, int)
910 self.assertEqual(param.default, inspect.Parameter.empty)
911 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
912 param = next(params)
913 self.assertEqual(param.name, 'j')
914 self.assertIs (param.annotation, str)
915 self.assertEqual(param.default, inspect.Parameter.empty)
916 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
917 param = next(params)
918 self.assertEqual(param.name, 'k')
919 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500920 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500921 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
922 param = next(params)
923 self.assertEqual(param.name, 'l')
924 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500925 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500926 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
927 self.assertRaises(StopIteration, next, params)
928
929
930 @dataclass
931 class C:
932 i: int
933 j: str
934 k: F = f
935 l: float=field(default=None)
936 z: complex=field(default=3+4j, init=False)
937
938 validate_class(C)
939
940 # Now repeat with __hash__.
941 @dataclass(frozen=True, hash=True)
942 class C:
943 i: int
944 j: str
945 k: F = f
946 l: float=field(default=None)
947 z: complex=field(default=3+4j, init=False)
948
949 validate_class(C)
950
Eric V. Smith03220fd2017-12-29 13:59:58 -0500951 def test_missing_default(self):
952 # Test that MISSING works the same as a default not being
953 # specified.
954 @dataclass
955 class C:
956 x: int=field(default=MISSING)
957 with self.assertRaisesRegex(TypeError,
958 r'__init__\(\) missing 1 required '
959 'positional argument'):
960 C()
961 self.assertNotIn('x', C.__dict__)
962
963 @dataclass
964 class D:
965 x: int
966 with self.assertRaisesRegex(TypeError,
967 r'__init__\(\) missing 1 required '
968 'positional argument'):
969 D()
970 self.assertNotIn('x', D.__dict__)
971
972 def test_missing_default_factory(self):
973 # Test that MISSING works the same as a default factory not
974 # being specified (which is really the same as a default not
975 # being specified, too).
976 @dataclass
977 class C:
978 x: int=field(default_factory=MISSING)
979 with self.assertRaisesRegex(TypeError,
980 r'__init__\(\) missing 1 required '
981 'positional argument'):
982 C()
983 self.assertNotIn('x', C.__dict__)
984
985 @dataclass
986 class D:
987 x: int=field(default=MISSING, default_factory=MISSING)
988 with self.assertRaisesRegex(TypeError,
989 r'__init__\(\) missing 1 required '
990 'positional argument'):
991 D()
992 self.assertNotIn('x', D.__dict__)
993
994 def test_missing_repr(self):
995 self.assertIn('MISSING_TYPE object', repr(MISSING))
996
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500997 def test_dont_include_other_annotations(self):
998 @dataclass
999 class C:
1000 i: int
1001 def foo(self) -> int:
1002 return 4
1003 @property
1004 def bar(self) -> int:
1005 return 5
1006 self.assertEqual(list(C.__annotations__), ['i'])
1007 self.assertEqual(C(10).foo(), 4)
1008 self.assertEqual(C(10).bar, 5)
1009
1010 def test_post_init(self):
1011 # Just make sure it gets called
1012 @dataclass
1013 class C:
1014 def __post_init__(self):
1015 raise CustomError()
1016 with self.assertRaises(CustomError):
1017 C()
1018
1019 @dataclass
1020 class C:
1021 i: int = 10
1022 def __post_init__(self):
1023 if self.i == 10:
1024 raise CustomError()
1025 with self.assertRaises(CustomError):
1026 C()
1027 # post-init gets called, but doesn't raise. This is just
1028 # checking that self is used correctly.
1029 C(5)
1030
1031 # If there's not an __init__, then post-init won't get called.
1032 @dataclass(init=False)
1033 class C:
1034 def __post_init__(self):
1035 raise CustomError()
1036 # Creating the class won't raise
1037 C()
1038
1039 @dataclass
1040 class C:
1041 x: int = 0
1042 def __post_init__(self):
1043 self.x *= 2
1044 self.assertEqual(C().x, 0)
1045 self.assertEqual(C(2).x, 4)
1046
Mike53f7a7c2017-12-14 14:04:53 +03001047 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001048 # attributes.
1049 @dataclass(frozen=True)
1050 class C:
1051 x: int = 0
1052 def __post_init__(self):
1053 self.x *= 2
1054 with self.assertRaises(FrozenInstanceError):
1055 C()
1056
1057 def test_post_init_super(self):
1058 # Make sure super() post-init isn't called by default.
1059 class B:
1060 def __post_init__(self):
1061 raise CustomError()
1062
1063 @dataclass
1064 class C(B):
1065 def __post_init__(self):
1066 self.x = 5
1067
1068 self.assertEqual(C().x, 5)
1069
1070 # Now call super(), and it will raise
1071 @dataclass
1072 class C(B):
1073 def __post_init__(self):
1074 super().__post_init__()
1075
1076 with self.assertRaises(CustomError):
1077 C()
1078
1079 # Make sure post-init is called, even if not defined in our
1080 # class.
1081 @dataclass
1082 class C(B):
1083 pass
1084
1085 with self.assertRaises(CustomError):
1086 C()
1087
1088 def test_post_init_staticmethod(self):
1089 flag = False
1090 @dataclass
1091 class C:
1092 x: int
1093 y: int
1094 @staticmethod
1095 def __post_init__():
1096 nonlocal flag
1097 flag = True
1098
1099 self.assertFalse(flag)
1100 c = C(3, 4)
1101 self.assertEqual((c.x, c.y), (3, 4))
1102 self.assertTrue(flag)
1103
1104 def test_post_init_classmethod(self):
1105 @dataclass
1106 class C:
1107 flag = False
1108 x: int
1109 y: int
1110 @classmethod
1111 def __post_init__(cls):
1112 cls.flag = True
1113
1114 self.assertFalse(C.flag)
1115 c = C(3, 4)
1116 self.assertEqual((c.x, c.y), (3, 4))
1117 self.assertTrue(C.flag)
1118
1119 def test_class_var(self):
1120 # Make sure ClassVars are ignored in __init__, __repr__, etc.
1121 @dataclass
1122 class C:
1123 x: int
1124 y: int = 10
1125 z: ClassVar[int] = 1000
1126 w: ClassVar[int] = 2000
1127 t: ClassVar[int] = 3000
1128
1129 c = C(5)
1130 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
1131 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1132 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1133 self.assertEqual(c.z, 1000)
1134 self.assertEqual(c.w, 2000)
1135 self.assertEqual(c.t, 3000)
1136 C.z += 1
1137 self.assertEqual(c.z, 1001)
1138 c = C(20)
1139 self.assertEqual((c.x, c.y), (20, 10))
1140 self.assertEqual(c.z, 1001)
1141 self.assertEqual(c.w, 2000)
1142 self.assertEqual(c.t, 3000)
1143
1144 def test_class_var_no_default(self):
1145 # If a ClassVar has no default value, it should not be set on the class.
1146 @dataclass
1147 class C:
1148 x: ClassVar[int]
1149
1150 self.assertNotIn('x', C.__dict__)
1151
1152 def test_class_var_default_factory(self):
1153 # It makes no sense for a ClassVar to have a default factory. When
1154 # would it be called? Call it yourself, since it's class-wide.
1155 with self.assertRaisesRegex(TypeError,
1156 'cannot have a default factory'):
1157 @dataclass
1158 class C:
1159 x: ClassVar[int] = field(default_factory=int)
1160
1161 self.assertNotIn('x', C.__dict__)
1162
1163 def test_class_var_with_default(self):
1164 # If a ClassVar has a default value, it should be set on the class.
1165 @dataclass
1166 class C:
1167 x: ClassVar[int] = 10
1168 self.assertEqual(C.x, 10)
1169
1170 @dataclass
1171 class C:
1172 x: ClassVar[int] = field(default=10)
1173 self.assertEqual(C.x, 10)
1174
1175 def test_class_var_frozen(self):
1176 # Make sure ClassVars work even if we're frozen.
1177 @dataclass(frozen=True)
1178 class C:
1179 x: int
1180 y: int = 10
1181 z: ClassVar[int] = 1000
1182 w: ClassVar[int] = 2000
1183 t: ClassVar[int] = 3000
1184
1185 c = C(5)
1186 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1187 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1188 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1189 self.assertEqual(c.z, 1000)
1190 self.assertEqual(c.w, 2000)
1191 self.assertEqual(c.t, 3000)
1192 # We can still modify the ClassVar, it's only instances that are
1193 # frozen.
1194 C.z += 1
1195 self.assertEqual(c.z, 1001)
1196 c = C(20)
1197 self.assertEqual((c.x, c.y), (20, 10))
1198 self.assertEqual(c.z, 1001)
1199 self.assertEqual(c.w, 2000)
1200 self.assertEqual(c.t, 3000)
1201
1202 def test_init_var_no_default(self):
1203 # If an InitVar has no default value, it should not be set on the class.
1204 @dataclass
1205 class C:
1206 x: InitVar[int]
1207
1208 self.assertNotIn('x', C.__dict__)
1209
1210 def test_init_var_default_factory(self):
1211 # It makes no sense for an InitVar to have a default factory. When
1212 # would it be called? Call it yourself, since it's class-wide.
1213 with self.assertRaisesRegex(TypeError,
1214 'cannot have a default factory'):
1215 @dataclass
1216 class C:
1217 x: InitVar[int] = field(default_factory=int)
1218
1219 self.assertNotIn('x', C.__dict__)
1220
1221 def test_init_var_with_default(self):
1222 # If an InitVar has a default value, it should be set on the class.
1223 @dataclass
1224 class C:
1225 x: InitVar[int] = 10
1226 self.assertEqual(C.x, 10)
1227
1228 @dataclass
1229 class C:
1230 x: InitVar[int] = field(default=10)
1231 self.assertEqual(C.x, 10)
1232
1233 def test_init_var(self):
1234 @dataclass
1235 class C:
1236 x: int = None
1237 init_param: InitVar[int] = None
1238
1239 def __post_init__(self, init_param):
1240 if self.x is None:
1241 self.x = init_param*2
1242
1243 c = C(init_param=10)
1244 self.assertEqual(c.x, 20)
1245
1246 def test_init_var_inheritance(self):
1247 # Note that this deliberately tests that a dataclass need not
1248 # have a __post_init__ function if it has an InitVar field.
1249 # It could just be used in a derived class, as shown here.
1250 @dataclass
1251 class Base:
1252 x: int
1253 init_base: InitVar[int]
1254
1255 # We can instantiate by passing the InitVar, even though
1256 # it's not used.
1257 b = Base(0, 10)
1258 self.assertEqual(vars(b), {'x': 0})
1259
1260 @dataclass
1261 class C(Base):
1262 y: int
1263 init_derived: InitVar[int]
1264
1265 def __post_init__(self, init_base, init_derived):
1266 self.x = self.x + init_base
1267 self.y = self.y + init_derived
1268
1269 c = C(10, 11, 50, 51)
1270 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1271
1272 def test_default_factory(self):
1273 # Test a factory that returns a new list.
1274 @dataclass
1275 class C:
1276 x: int
1277 y: list = field(default_factory=list)
1278
1279 c0 = C(3)
1280 c1 = C(3)
1281 self.assertEqual(c0.x, 3)
1282 self.assertEqual(c0.y, [])
1283 self.assertEqual(c0, c1)
1284 self.assertIsNot(c0.y, c1.y)
1285 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1286
1287 # Test a factory that returns a shared list.
1288 l = []
1289 @dataclass
1290 class C:
1291 x: int
1292 y: list = field(default_factory=lambda: l)
1293
1294 c0 = C(3)
1295 c1 = C(3)
1296 self.assertEqual(c0.x, 3)
1297 self.assertEqual(c0.y, [])
1298 self.assertEqual(c0, c1)
1299 self.assertIs(c0.y, c1.y)
1300 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1301
1302 # Test various other field flags.
1303 # repr
1304 @dataclass
1305 class C:
1306 x: list = field(default_factory=list, repr=False)
1307 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1308 self.assertEqual(C().x, [])
1309
1310 # hash
1311 @dataclass(hash=True)
1312 class C:
1313 x: list = field(default_factory=list, hash=False)
1314 self.assertEqual(astuple(C()), ([],))
1315 self.assertEqual(hash(C()), hash(()))
1316
1317 # init (see also test_default_factory_with_no_init)
1318 @dataclass
1319 class C:
1320 x: list = field(default_factory=list, init=False)
1321 self.assertEqual(astuple(C()), ([],))
1322
1323 # compare
1324 @dataclass
1325 class C:
1326 x: list = field(default_factory=list, compare=False)
1327 self.assertEqual(C(), C([1]))
1328
1329 def test_default_factory_with_no_init(self):
1330 # We need a factory with a side effect.
1331 factory = Mock()
1332
1333 @dataclass
1334 class C:
1335 x: list = field(default_factory=factory, init=False)
1336
1337 # Make sure the default factory is called for each new instance.
1338 C().x
1339 self.assertEqual(factory.call_count, 1)
1340 C().x
1341 self.assertEqual(factory.call_count, 2)
1342
1343 def test_default_factory_not_called_if_value_given(self):
1344 # We need a factory that we can test if it's been called.
1345 factory = Mock()
1346
1347 @dataclass
1348 class C:
1349 x: int = field(default_factory=factory)
1350
1351 # Make sure that if a field has a default factory function,
1352 # it's not called if a value is specified.
1353 C().x
1354 self.assertEqual(factory.call_count, 1)
1355 self.assertEqual(C(10).x, 10)
1356 self.assertEqual(factory.call_count, 1)
1357 C().x
1358 self.assertEqual(factory.call_count, 2)
1359
1360 def x_test_classvar_default_factory(self):
1361 # XXX: it's an error for a ClassVar to have a factory function
1362 @dataclass
1363 class C:
1364 x: ClassVar[int] = field(default_factory=int)
1365
1366 self.assertIs(C().x, int)
1367
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001368 def test_is_dataclass(self):
1369 class NotDataClass:
1370 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001371
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001372 self.assertFalse(is_dataclass(0))
1373 self.assertFalse(is_dataclass(int))
1374 self.assertFalse(is_dataclass(NotDataClass))
1375 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001376
1377 @dataclass
1378 class C:
1379 x: int
1380
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001381 @dataclass
1382 class D:
1383 d: C
1384 e: int
1385
1386 c = C(10)
1387 d = D(c, 4)
1388
1389 self.assertTrue(is_dataclass(C))
1390 self.assertTrue(is_dataclass(c))
1391 self.assertFalse(is_dataclass(c.x))
1392 self.assertTrue(is_dataclass(d.d))
1393 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001394
1395 def test_helper_fields_with_class_instance(self):
1396 # Check that we can call fields() on either a class or instance,
1397 # and get back the same thing.
1398 @dataclass
1399 class C:
1400 x: int
1401 y: float
1402
1403 self.assertEqual(fields(C), fields(C(0, 0.0)))
1404
1405 def test_helper_fields_exception(self):
1406 # Check that TypeError is raised if not passed a dataclass or
1407 # instance.
1408 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1409 fields(0)
1410
1411 class C: pass
1412 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1413 fields(C)
1414 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1415 fields(C())
1416
1417 def test_helper_asdict(self):
1418 # Basic tests for asdict(), it should return a new dictionary
1419 @dataclass
1420 class C:
1421 x: int
1422 y: int
1423 c = C(1, 2)
1424
1425 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1426 self.assertEqual(asdict(c), asdict(c))
1427 self.assertIsNot(asdict(c), asdict(c))
1428 c.x = 42
1429 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1430 self.assertIs(type(asdict(c)), dict)
1431
1432 def test_helper_asdict_raises_on_classes(self):
1433 # asdict() should raise on a class object
1434 @dataclass
1435 class C:
1436 x: int
1437 y: int
1438 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1439 asdict(C)
1440 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1441 asdict(int)
1442
1443 def test_helper_asdict_copy_values(self):
1444 @dataclass
1445 class C:
1446 x: int
1447 y: List[int] = field(default_factory=list)
1448 initial = []
1449 c = C(1, initial)
1450 d = asdict(c)
1451 self.assertEqual(d['y'], initial)
1452 self.assertIsNot(d['y'], initial)
1453 c = C(1)
1454 d = asdict(c)
1455 d['y'].append(1)
1456 self.assertEqual(c.y, [])
1457
1458 def test_helper_asdict_nested(self):
1459 @dataclass
1460 class UserId:
1461 token: int
1462 group: int
1463 @dataclass
1464 class User:
1465 name: str
1466 id: UserId
1467 u = User('Joe', UserId(123, 1))
1468 d = asdict(u)
1469 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1470 self.assertIsNot(asdict(u), asdict(u))
1471 u.id.group = 2
1472 self.assertEqual(asdict(u), {'name': 'Joe',
1473 'id': {'token': 123, 'group': 2}})
1474
1475 def test_helper_asdict_builtin_containers(self):
1476 @dataclass
1477 class User:
1478 name: str
1479 id: int
1480 @dataclass
1481 class GroupList:
1482 id: int
1483 users: List[User]
1484 @dataclass
1485 class GroupTuple:
1486 id: int
1487 users: Tuple[User, ...]
1488 @dataclass
1489 class GroupDict:
1490 id: int
1491 users: Dict[str, User]
1492 a = User('Alice', 1)
1493 b = User('Bob', 2)
1494 gl = GroupList(0, [a, b])
1495 gt = GroupTuple(0, (a, b))
1496 gd = GroupDict(0, {'first': a, 'second': b})
1497 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1498 {'name': 'Bob', 'id': 2}]})
1499 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1500 {'name': 'Bob', 'id': 2})})
1501 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1502 'second': {'name': 'Bob', 'id': 2}}})
1503
1504 def test_helper_asdict_builtin_containers(self):
1505 @dataclass
1506 class Child:
1507 d: object
1508
1509 @dataclass
1510 class Parent:
1511 child: Child
1512
1513 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1514 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1515
1516 def test_helper_asdict_factory(self):
1517 @dataclass
1518 class C:
1519 x: int
1520 y: int
1521 c = C(1, 2)
1522 d = asdict(c, dict_factory=OrderedDict)
1523 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1524 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1525 c.x = 42
1526 d = asdict(c, dict_factory=OrderedDict)
1527 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1528 self.assertIs(type(d), OrderedDict)
1529
1530 def test_helper_astuple(self):
1531 # Basic tests for astuple(), it should return a new tuple
1532 @dataclass
1533 class C:
1534 x: int
1535 y: int = 0
1536 c = C(1)
1537
1538 self.assertEqual(astuple(c), (1, 0))
1539 self.assertEqual(astuple(c), astuple(c))
1540 self.assertIsNot(astuple(c), astuple(c))
1541 c.y = 42
1542 self.assertEqual(astuple(c), (1, 42))
1543 self.assertIs(type(astuple(c)), tuple)
1544
1545 def test_helper_astuple_raises_on_classes(self):
1546 # astuple() should raise on a class object
1547 @dataclass
1548 class C:
1549 x: int
1550 y: int
1551 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1552 astuple(C)
1553 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1554 astuple(int)
1555
1556 def test_helper_astuple_copy_values(self):
1557 @dataclass
1558 class C:
1559 x: int
1560 y: List[int] = field(default_factory=list)
1561 initial = []
1562 c = C(1, initial)
1563 t = astuple(c)
1564 self.assertEqual(t[1], initial)
1565 self.assertIsNot(t[1], initial)
1566 c = C(1)
1567 t = astuple(c)
1568 t[1].append(1)
1569 self.assertEqual(c.y, [])
1570
1571 def test_helper_astuple_nested(self):
1572 @dataclass
1573 class UserId:
1574 token: int
1575 group: int
1576 @dataclass
1577 class User:
1578 name: str
1579 id: UserId
1580 u = User('Joe', UserId(123, 1))
1581 t = astuple(u)
1582 self.assertEqual(t, ('Joe', (123, 1)))
1583 self.assertIsNot(astuple(u), astuple(u))
1584 u.id.group = 2
1585 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1586
1587 def test_helper_astuple_builtin_containers(self):
1588 @dataclass
1589 class User:
1590 name: str
1591 id: int
1592 @dataclass
1593 class GroupList:
1594 id: int
1595 users: List[User]
1596 @dataclass
1597 class GroupTuple:
1598 id: int
1599 users: Tuple[User, ...]
1600 @dataclass
1601 class GroupDict:
1602 id: int
1603 users: Dict[str, User]
1604 a = User('Alice', 1)
1605 b = User('Bob', 2)
1606 gl = GroupList(0, [a, b])
1607 gt = GroupTuple(0, (a, b))
1608 gd = GroupDict(0, {'first': a, 'second': b})
1609 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1610 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1611 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1612
1613 def test_helper_astuple_builtin_containers(self):
1614 @dataclass
1615 class Child:
1616 d: object
1617
1618 @dataclass
1619 class Parent:
1620 child: Child
1621
1622 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1623 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1624
1625 def test_helper_astuple_factory(self):
1626 @dataclass
1627 class C:
1628 x: int
1629 y: int
1630 NT = namedtuple('NT', 'x y')
1631 def nt(lst):
1632 return NT(*lst)
1633 c = C(1, 2)
1634 t = astuple(c, tuple_factory=nt)
1635 self.assertEqual(t, NT(1, 2))
1636 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1637 c.x = 42
1638 t = astuple(c, tuple_factory=nt)
1639 self.assertEqual(t, NT(42, 2))
1640 self.assertIs(type(t), NT)
1641
1642 def test_dynamic_class_creation(self):
1643 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1644 }
1645
1646 # Create the class.
1647 cls = type('C', (), cls_dict)
1648
1649 # Make it a dataclass.
1650 cls1 = dataclass(cls)
1651
1652 self.assertEqual(cls1, cls)
1653 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1654
1655 def test_dynamic_class_creation_using_field(self):
1656 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1657 'y': field(default=5),
1658 }
1659
1660 # Create the class.
1661 cls = type('C', (), cls_dict)
1662
1663 # Make it a dataclass.
1664 cls1 = dataclass(cls)
1665
1666 self.assertEqual(cls1, cls)
1667 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1668
1669 def test_init_in_order(self):
1670 @dataclass
1671 class C:
1672 a: int
1673 b: int = field()
1674 c: list = field(default_factory=list, init=False)
1675 d: list = field(default_factory=list)
1676 e: int = field(default=4, init=False)
1677 f: int = 4
1678
1679 calls = []
1680 def setattr(self, name, value):
1681 calls.append((name, value))
1682
1683 C.__setattr__ = setattr
1684 c = C(0, 1)
1685 self.assertEqual(('a', 0), calls[0])
1686 self.assertEqual(('b', 1), calls[1])
1687 self.assertEqual(('c', []), calls[2])
1688 self.assertEqual(('d', []), calls[3])
1689 self.assertNotIn(('e', 4), calls)
1690 self.assertEqual(('f', 4), calls[4])
1691
1692 def test_items_in_dicts(self):
1693 @dataclass
1694 class C:
1695 a: int
1696 b: list = field(default_factory=list, init=False)
1697 c: list = field(default_factory=list)
1698 d: int = field(default=4, init=False)
1699 e: int = 0
1700
1701 c = C(0)
1702 # Class dict
1703 self.assertNotIn('a', C.__dict__)
1704 self.assertNotIn('b', C.__dict__)
1705 self.assertNotIn('c', C.__dict__)
1706 self.assertIn('d', C.__dict__)
1707 self.assertEqual(C.d, 4)
1708 self.assertIn('e', C.__dict__)
1709 self.assertEqual(C.e, 0)
1710 # Instance dict
1711 self.assertIn('a', c.__dict__)
1712 self.assertEqual(c.a, 0)
1713 self.assertIn('b', c.__dict__)
1714 self.assertEqual(c.b, [])
1715 self.assertIn('c', c.__dict__)
1716 self.assertEqual(c.c, [])
1717 self.assertNotIn('d', c.__dict__)
1718 self.assertIn('e', c.__dict__)
1719 self.assertEqual(c.e, 0)
1720
1721 def test_alternate_classmethod_constructor(self):
1722 # Since __post_init__ can't take params, use a classmethod
1723 # alternate constructor. This is mostly an example to show how
1724 # to use this technique.
1725 @dataclass
1726 class C:
1727 x: int
1728 @classmethod
1729 def from_file(cls, filename):
1730 # In a real example, create a new instance
1731 # and populate 'x' from contents of a file.
1732 value_in_file = 20
1733 return cls(value_in_file)
1734
1735 self.assertEqual(C.from_file('filename').x, 20)
1736
1737 def test_field_metadata_default(self):
1738 # Make sure the default metadata is read-only and of
1739 # zero length.
1740 @dataclass
1741 class C:
1742 i: int
1743
1744 self.assertFalse(fields(C)[0].metadata)
1745 self.assertEqual(len(fields(C)[0].metadata), 0)
1746 with self.assertRaisesRegex(TypeError,
1747 'does not support item assignment'):
1748 fields(C)[0].metadata['test'] = 3
1749
1750 def test_field_metadata_mapping(self):
1751 # Make sure only a mapping can be passed as metadata
1752 # zero length.
1753 with self.assertRaises(TypeError):
1754 @dataclass
1755 class C:
1756 i: int = field(metadata=0)
1757
1758 # Make sure an empty dict works
1759 @dataclass
1760 class C:
1761 i: int = field(metadata={})
1762 self.assertFalse(fields(C)[0].metadata)
1763 self.assertEqual(len(fields(C)[0].metadata), 0)
1764 with self.assertRaisesRegex(TypeError,
1765 'does not support item assignment'):
1766 fields(C)[0].metadata['test'] = 3
1767
1768 # Make sure a non-empty dict works.
1769 @dataclass
1770 class C:
1771 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1772 self.assertEqual(len(fields(C)[0].metadata), 3)
1773 self.assertEqual(fields(C)[0].metadata['test'], 10)
1774 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1775 self.assertEqual(fields(C)[0].metadata[3], 'three')
1776 with self.assertRaises(KeyError):
1777 # Non-existent key.
1778 fields(C)[0].metadata['baz']
1779 with self.assertRaisesRegex(TypeError,
1780 'does not support item assignment'):
1781 fields(C)[0].metadata['test'] = 3
1782
1783 def test_field_metadata_custom_mapping(self):
1784 # Try a custom mapping.
1785 class SimpleNameSpace:
1786 def __init__(self, **kw):
1787 self.__dict__.update(kw)
1788
1789 def __getitem__(self, item):
1790 if item == 'xyzzy':
1791 return 'plugh'
1792 return getattr(self, item)
1793
1794 def __len__(self):
1795 return self.__dict__.__len__()
1796
1797 @dataclass
1798 class C:
1799 i: int = field(metadata=SimpleNameSpace(a=10))
1800
1801 self.assertEqual(len(fields(C)[0].metadata), 1)
1802 self.assertEqual(fields(C)[0].metadata['a'], 10)
1803 with self.assertRaises(AttributeError):
1804 fields(C)[0].metadata['b']
1805 # Make sure we're still talking to our custom mapping.
1806 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1807
1808 def test_generic_dataclasses(self):
1809 T = TypeVar('T')
1810
1811 @dataclass
1812 class LabeledBox(Generic[T]):
1813 content: T
1814 label: str = '<unknown>'
1815
1816 box = LabeledBox(42)
1817 self.assertEqual(box.content, 42)
1818 self.assertEqual(box.label, '<unknown>')
1819
1820 # subscripting the resulting class should work, etc.
1821 Alias = List[LabeledBox[int]]
1822
1823 def test_generic_extending(self):
1824 S = TypeVar('S')
1825 T = TypeVar('T')
1826
1827 @dataclass
1828 class Base(Generic[T, S]):
1829 x: T
1830 y: S
1831
1832 @dataclass
1833 class DataDerived(Base[int, T]):
1834 new_field: str
1835 Alias = DataDerived[str]
1836 c = Alias(0, 'test1', 'test2')
1837 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1838
1839 class NonDataDerived(Base[int, T]):
1840 def new_method(self):
1841 return self.y
1842 Alias = NonDataDerived[float]
1843 c = Alias(10, 1.0)
1844 self.assertEqual(c.new_method(), 1.0)
1845
1846 def test_helper_replace(self):
1847 @dataclass(frozen=True)
1848 class C:
1849 x: int
1850 y: int
1851
1852 c = C(1, 2)
1853 c1 = replace(c, x=3)
1854 self.assertEqual(c1.x, 3)
1855 self.assertEqual(c1.y, 2)
1856
1857 def test_helper_replace_frozen(self):
1858 @dataclass(frozen=True)
1859 class C:
1860 x: int
1861 y: int
1862 z: int = field(init=False, default=10)
1863 t: int = field(init=False, default=100)
1864
1865 c = C(1, 2)
1866 c1 = replace(c, x=3)
1867 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1868 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1869
1870
1871 with self.assertRaisesRegex(ValueError, 'init=False'):
1872 replace(c, x=3, z=20, t=50)
1873 with self.assertRaisesRegex(ValueError, 'init=False'):
1874 replace(c, z=20)
1875 replace(c, x=3, z=20, t=50)
1876
1877 # Make sure the result is still frozen.
1878 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1879 c1.x = 3
1880
1881 # Make sure we can't replace an attribute that doesn't exist,
1882 # if we're also replacing one that does exist. Test this
1883 # here, because setting attributes on frozen instances is
1884 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001885 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001886 "keyword argument 'a'"):
1887 c1 = replace(c, x=20, a=5)
1888
1889 def test_helper_replace_invalid_field_name(self):
1890 @dataclass(frozen=True)
1891 class C:
1892 x: int
1893 y: int
1894
1895 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001896 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001897 "keyword argument 'z'"):
1898 c1 = replace(c, z=3)
1899
1900 def test_helper_replace_invalid_object(self):
1901 @dataclass(frozen=True)
1902 class C:
1903 x: int
1904 y: int
1905
1906 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1907 replace(C, x=3)
1908
1909 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1910 replace(0, x=3)
1911
1912 def test_helper_replace_no_init(self):
1913 @dataclass
1914 class C:
1915 x: int
1916 y: int = field(init=False, default=10)
1917
1918 c = C(1)
1919 c.y = 20
1920
1921 # Make sure y gets the default value.
1922 c1 = replace(c, x=5)
1923 self.assertEqual((c1.x, c1.y), (5, 10))
1924
1925 # Trying to replace y is an error.
1926 with self.assertRaisesRegex(ValueError, 'init=False'):
1927 replace(c, x=2, y=30)
1928 with self.assertRaisesRegex(ValueError, 'init=False'):
1929 replace(c, y=30)
1930
1931 def test_dataclassses_pickleable(self):
1932 global P, Q, R
1933 @dataclass
1934 class P:
1935 x: int
1936 y: int = 0
1937 @dataclass
1938 class Q:
1939 x: int
1940 y: int = field(default=0, init=False)
1941 @dataclass
1942 class R:
1943 x: int
1944 y: List[int] = field(default_factory=list)
1945 q = Q(1)
1946 q.y = 2
1947 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1948 for sample in samples:
1949 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1950 with self.subTest(sample=sample, proto=proto):
1951 new_sample = pickle.loads(pickle.dumps(sample, proto))
1952 self.assertEqual(sample.x, new_sample.x)
1953 self.assertEqual(sample.y, new_sample.y)
1954 self.assertIsNot(sample, new_sample)
1955 new_sample.x = 42
1956 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1957 self.assertEqual(new_sample.x, another_new_sample.x)
1958 self.assertEqual(sample.y, another_new_sample.y)
1959
1960 def test_helper_make_dataclass(self):
1961 C = make_dataclass('C',
1962 [('x', int),
1963 ('y', int, field(default=5))],
1964 namespace={'add_one': lambda self: self.x + 1})
1965 c = C(10)
1966 self.assertEqual((c.x, c.y), (10, 5))
1967 self.assertEqual(c.add_one(), 11)
1968
1969
1970 def test_helper_make_dataclass_no_mutate_namespace(self):
1971 # Make sure a provided namespace isn't mutated.
1972 ns = {}
1973 C = make_dataclass('C',
1974 [('x', int),
1975 ('y', int, field(default=5))],
1976 namespace=ns)
1977 self.assertEqual(ns, {})
1978
1979 def test_helper_make_dataclass_base(self):
1980 class Base1:
1981 pass
1982 class Base2:
1983 pass
1984 C = make_dataclass('C',
1985 [('x', int)],
1986 bases=(Base1, Base2))
1987 c = C(2)
1988 self.assertIsInstance(c, C)
1989 self.assertIsInstance(c, Base1)
1990 self.assertIsInstance(c, Base2)
1991
1992 def test_helper_make_dataclass_base_dataclass(self):
1993 @dataclass
1994 class Base1:
1995 x: int
1996 class Base2:
1997 pass
1998 C = make_dataclass('C',
1999 [('y', int)],
2000 bases=(Base1, Base2))
2001 with self.assertRaisesRegex(TypeError, 'required positional'):
2002 c = C(2)
2003 c = C(1, 2)
2004 self.assertIsInstance(c, C)
2005 self.assertIsInstance(c, Base1)
2006 self.assertIsInstance(c, Base2)
2007
2008 self.assertEqual((c.x, c.y), (1, 2))
2009
2010 def test_helper_make_dataclass_init_var(self):
2011 def post_init(self, y):
2012 self.x *= y
2013
2014 C = make_dataclass('C',
2015 [('x', int),
2016 ('y', InitVar[int]),
2017 ],
2018 namespace={'__post_init__': post_init},
2019 )
2020 c = C(2, 3)
2021 self.assertEqual(vars(c), {'x': 6})
2022 self.assertEqual(len(fields(c)), 1)
2023
2024 def test_helper_make_dataclass_class_var(self):
2025 C = make_dataclass('C',
2026 [('x', int),
2027 ('y', ClassVar[int], 10),
2028 ('z', ClassVar[int], field(default=20)),
2029 ])
2030 c = C(1)
2031 self.assertEqual(vars(c), {'x': 1})
2032 self.assertEqual(len(fields(c)), 1)
2033 self.assertEqual(C.y, 10)
2034 self.assertEqual(C.z, 20)
2035
Eric V. Smithed7d4292018-01-06 16:14:03 -05002036 def test_helper_make_dataclass_no_types(self):
2037 C = make_dataclass('Point', ['x', 'y', 'z'])
2038 c = C(1, 2, 3)
2039 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
2040 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
2041 'y': 'typing.Any',
2042 'z': 'typing.Any'})
2043
2044 C = make_dataclass('Point', ['x', ('y', int), 'z'])
2045 c = C(1, 2, 3)
2046 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
2047 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
2048 'y': int,
2049 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05002050
2051class TestDocString(unittest.TestCase):
2052 def assertDocStrEqual(self, a, b):
2053 # Because 3.6 and 3.7 differ in how inspect.signature work
2054 # (see bpo #32108), for the time being just compare them with
2055 # whitespace stripped.
2056 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
2057
2058 def test_existing_docstring_not_overridden(self):
2059 @dataclass
2060 class C:
2061 """Lorem ipsum"""
2062 x: int
2063
2064 self.assertEqual(C.__doc__, "Lorem ipsum")
2065
2066 def test_docstring_no_fields(self):
2067 @dataclass
2068 class C:
2069 pass
2070
2071 self.assertDocStrEqual(C.__doc__, "C()")
2072
2073 def test_docstring_one_field(self):
2074 @dataclass
2075 class C:
2076 x: int
2077
2078 self.assertDocStrEqual(C.__doc__, "C(x:int)")
2079
2080 def test_docstring_two_fields(self):
2081 @dataclass
2082 class C:
2083 x: int
2084 y: int
2085
2086 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
2087
2088 def test_docstring_three_fields(self):
2089 @dataclass
2090 class C:
2091 x: int
2092 y: int
2093 z: str
2094
2095 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
2096
2097 def test_docstring_one_field_with_default(self):
2098 @dataclass
2099 class C:
2100 x: int = 3
2101
2102 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2103
2104 def test_docstring_one_field_with_default_none(self):
2105 @dataclass
2106 class C:
2107 x: Union[int, type(None)] = None
2108
2109 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2110
2111 def test_docstring_list_field(self):
2112 @dataclass
2113 class C:
2114 x: List[int]
2115
2116 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2117
2118 def test_docstring_list_field_with_default_factory(self):
2119 @dataclass
2120 class C:
2121 x: List[int] = field(default_factory=list)
2122
2123 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2124
2125 def test_docstring_deque_field(self):
2126 @dataclass
2127 class C:
2128 x: deque
2129
2130 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2131
2132 def test_docstring_deque_field_with_default_factory(self):
2133 @dataclass
2134 class C:
2135 x: deque = field(default_factory=deque)
2136
2137 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2138
2139
if __name__ == '__main__':
    # Run as a script: unittest.main() collects and runs this module's tests.
    unittest.main()