blob: caea98a13b06b7efecfdbda0015d59e1a8fb32ab [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
3 make_dataclass, replace, InitVar, Field
4)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
12
# Just any custom exception we can catch.
class CustomError(Exception):
    pass
15
16class TestCase(unittest.TestCase):
17 def test_no_fields(self):
18 @dataclass
19 class C:
20 pass
21
22 o = C()
23 self.assertEqual(len(fields(C)), 0)
24
25 def test_one_field_no_default(self):
26 @dataclass
27 class C:
28 x: int
29
30 o = C(42)
31 self.assertEqual(o.x, 42)
32
33 def test_named_init_params(self):
34 @dataclass
35 class C:
36 x: int
37
38 o = C(x=32)
39 self.assertEqual(o.x, 32)
40
41 def test_two_fields_one_default(self):
42 @dataclass
43 class C:
44 x: int
45 y: int = 0
46
47 o = C(3)
48 self.assertEqual((o.x, o.y), (3, 0))
49
50 # Non-defaults following defaults.
51 with self.assertRaisesRegex(TypeError,
52 "non-default argument 'y' follows "
53 "default argument"):
54 @dataclass
55 class C:
56 x: int = 0
57 y: int
58
59 # A derived class adds a non-default field after a default one.
60 with self.assertRaisesRegex(TypeError,
61 "non-default argument 'y' follows "
62 "default argument"):
63 @dataclass
64 class B:
65 x: int = 0
66
67 @dataclass
68 class C(B):
69 y: int
70
71 # Override a base class field and add a default to
72 # a field which didn't use to have a default.
73 with self.assertRaisesRegex(TypeError,
74 "non-default argument 'y' follows "
75 "default argument"):
76 @dataclass
77 class B:
78 x: int
79 y: int
80
81 @dataclass
82 class C(B):
83 x: int = 0
84
85 def test_overwriting_init(self):
86 with self.assertRaisesRegex(TypeError,
87 'Cannot overwrite attribute __init__ '
88 'in C'):
89 @dataclass
90 class C:
91 x: int
92 def __init__(self, x):
93 self.x = 2 * x
94
95 @dataclass(init=False)
96 class C:
97 x: int
98 def __init__(self, x):
99 self.x = 2 * x
100 self.assertEqual(C(5).x, 10)
101
102 def test_overwriting_repr(self):
103 with self.assertRaisesRegex(TypeError,
104 'Cannot overwrite attribute __repr__ '
105 'in C'):
106 @dataclass
107 class C:
108 x: int
109 def __repr__(self):
110 pass
111
112 @dataclass(repr=False)
113 class C:
114 x: int
115 def __repr__(self):
116 return 'x'
117 self.assertEqual(repr(C(0)), 'x')
118
119 def test_overwriting_cmp(self):
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __eq__ '
122 'in C'):
123 # This will generate the comparison functions, make sure we can't
124 # overwrite them.
125 @dataclass(hash=False, frozen=False)
126 class C:
127 x: int
128 def __eq__(self):
129 pass
130
131 @dataclass(order=False, eq=False)
132 class C:
133 x: int
134 def __eq__(self, other):
135 return True
136 self.assertEqual(C(0), 'x')
137
138 def test_overwriting_hash(self):
139 with self.assertRaisesRegex(TypeError,
140 'Cannot overwrite attribute __hash__ '
141 'in C'):
142 @dataclass(frozen=True)
143 class C:
144 x: int
145 def __hash__(self):
146 pass
147
148 @dataclass(frozen=True,hash=False)
149 class C:
150 x: int
151 def __hash__(self):
152 return 600
153 self.assertEqual(hash(C(0)), 600)
154
155 with self.assertRaisesRegex(TypeError,
156 'Cannot overwrite attribute __hash__ '
157 'in C'):
158 @dataclass(frozen=True)
159 class C:
160 x: int
161 def __hash__(self):
162 pass
163
164 @dataclass(frozen=True, hash=False)
165 class C:
166 x: int
167 def __hash__(self):
168 return 600
169 self.assertEqual(hash(C(0)), 600)
170
171 def test_overwriting_frozen(self):
172 # frozen uses __setattr__ and __delattr__
173 with self.assertRaisesRegex(TypeError,
174 'Cannot overwrite attribute __setattr__ '
175 'in C'):
176 @dataclass(frozen=True)
177 class C:
178 x: int
179 def __setattr__(self):
180 pass
181
182 with self.assertRaisesRegex(TypeError,
183 'Cannot overwrite attribute __delattr__ '
184 'in C'):
185 @dataclass(frozen=True)
186 class C:
187 x: int
188 def __delattr__(self):
189 pass
190
191 @dataclass(frozen=False)
192 class C:
193 x: int
194 def __setattr__(self, name, value):
195 self.__dict__['x'] = value * 2
196 self.assertEqual(C(10).x, 20)
197
198 def test_overwrite_fields_in_derived_class(self):
199 # Note that x from C1 replaces x in Base, but the order remains
200 # the same as defined in Base.
201 @dataclass
202 class Base:
203 x: Any = 15.0
204 y: int = 0
205
206 @dataclass
207 class C1(Base):
208 z: int = 10
209 x: int = 15
210
211 o = Base()
212 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
213
214 o = C1()
215 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
216
217 o = C1(x=5)
218 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
219
220 def test_field_named_self(self):
221 @dataclass
222 class C:
223 self: str
224 c=C('foo')
225 self.assertEqual(c.self, 'foo')
226
227 # Make sure the first parameter is not named 'self'.
228 sig = inspect.signature(C.__init__)
229 first = next(iter(sig.parameters))
230 self.assertNotEqual('self', first)
231
232 # But we do use 'self' if no field named self.
233 @dataclass
234 class C:
235 selfx: str
236
237 # Make sure the first parameter is named 'self'.
238 sig = inspect.signature(C.__init__)
239 first = next(iter(sig.parameters))
240 self.assertEqual('self', first)
241
242 def test_repr(self):
243 @dataclass
244 class B:
245 x: int
246
247 @dataclass
248 class C(B):
249 y: int = 10
250
251 o = C(4)
252 self.assertEqual(repr(o), 'TestCase.test_repr.<locals>.C(x=4, y=10)')
253
254 @dataclass
255 class D(C):
256 x: int = 20
257 self.assertEqual(repr(D()), 'TestCase.test_repr.<locals>.D(x=20, y=10)')
258
259 @dataclass
260 class C:
261 @dataclass
262 class D:
263 i: int
264 @dataclass
265 class E:
266 pass
267 self.assertEqual(repr(C.D(0)), 'TestCase.test_repr.<locals>.C.D(i=0)')
268 self.assertEqual(repr(C.E()), 'TestCase.test_repr.<locals>.C.E()')
269
270 def test_0_field_compare(self):
271 # Ensure that order=False is the default.
272 @dataclass
273 class C0:
274 pass
275
276 @dataclass(order=False)
277 class C1:
278 pass
279
280 for cls in [C0, C1]:
281 with self.subTest(cls=cls):
282 self.assertEqual(cls(), cls())
283 for idx, fn in enumerate([lambda a, b: a < b,
284 lambda a, b: a <= b,
285 lambda a, b: a > b,
286 lambda a, b: a >= b]):
287 with self.subTest(idx=idx):
288 with self.assertRaisesRegex(TypeError,
289 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
290 fn(cls(), cls())
291
292 @dataclass(order=True)
293 class C:
294 pass
295 self.assertLessEqual(C(), C())
296 self.assertGreaterEqual(C(), C())
297
298 def test_1_field_compare(self):
299 # Ensure that order=False is the default.
300 @dataclass
301 class C0:
302 x: int
303
304 @dataclass(order=False)
305 class C1:
306 x: int
307
308 for cls in [C0, C1]:
309 with self.subTest(cls=cls):
310 self.assertEqual(cls(1), cls(1))
311 self.assertNotEqual(cls(0), cls(1))
312 for idx, fn in enumerate([lambda a, b: a < b,
313 lambda a, b: a <= b,
314 lambda a, b: a > b,
315 lambda a, b: a >= b]):
316 with self.subTest(idx=idx):
317 with self.assertRaisesRegex(TypeError,
318 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
319 fn(cls(0), cls(0))
320
321 @dataclass(order=True)
322 class C:
323 x: int
324 self.assertLess(C(0), C(1))
325 self.assertLessEqual(C(0), C(1))
326 self.assertLessEqual(C(1), C(1))
327 self.assertGreater(C(1), C(0))
328 self.assertGreaterEqual(C(1), C(0))
329 self.assertGreaterEqual(C(1), C(1))
330
331 def test_simple_compare(self):
332 # Ensure that order=False is the default.
333 @dataclass
334 class C0:
335 x: int
336 y: int
337
338 @dataclass(order=False)
339 class C1:
340 x: int
341 y: int
342
343 for cls in [C0, C1]:
344 with self.subTest(cls=cls):
345 self.assertEqual(cls(0, 0), cls(0, 0))
346 self.assertEqual(cls(1, 2), cls(1, 2))
347 self.assertNotEqual(cls(1, 0), cls(0, 0))
348 self.assertNotEqual(cls(1, 0), cls(1, 1))
349 for idx, fn in enumerate([lambda a, b: a < b,
350 lambda a, b: a <= b,
351 lambda a, b: a > b,
352 lambda a, b: a >= b]):
353 with self.subTest(idx=idx):
354 with self.assertRaisesRegex(TypeError,
355 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
356 fn(cls(0, 0), cls(0, 0))
357
358 @dataclass(order=True)
359 class C:
360 x: int
361 y: int
362
363 for idx, fn in enumerate([lambda a, b: a == b,
364 lambda a, b: a <= b,
365 lambda a, b: a >= b]):
366 with self.subTest(idx=idx):
367 self.assertTrue(fn(C(0, 0), C(0, 0)))
368
369 for idx, fn in enumerate([lambda a, b: a < b,
370 lambda a, b: a <= b,
371 lambda a, b: a != b]):
372 with self.subTest(idx=idx):
373 self.assertTrue(fn(C(0, 0), C(0, 1)))
374 self.assertTrue(fn(C(0, 1), C(1, 0)))
375 self.assertTrue(fn(C(1, 0), C(1, 1)))
376
377 for idx, fn in enumerate([lambda a, b: a > b,
378 lambda a, b: a >= b,
379 lambda a, b: a != b]):
380 with self.subTest(idx=idx):
381 self.assertTrue(fn(C(0, 1), C(0, 0)))
382 self.assertTrue(fn(C(1, 0), C(0, 1)))
383 self.assertTrue(fn(C(1, 1), C(1, 0)))
384
385 def test_compare_subclasses(self):
386 # Comparisons fail for subclasses, even if no fields
387 # are added.
388 @dataclass
389 class B:
390 i: int
391
392 @dataclass
393 class C(B):
394 pass
395
396 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
397 (lambda a, b: a != b, True)]):
398 with self.subTest(idx=idx):
399 self.assertEqual(fn(B(0), C(0)), expected)
400
401 for idx, fn in enumerate([lambda a, b: a < b,
402 lambda a, b: a <= b,
403 lambda a, b: a > b,
404 lambda a, b: a >= b]):
405 with self.subTest(idx=idx):
406 with self.assertRaisesRegex(TypeError,
407 "not supported between instances of 'B' and 'C'"):
408 fn(B(0), C(0))
409
410 def test_0_field_hash(self):
411 @dataclass(hash=True)
412 class C:
413 pass
414 self.assertEqual(hash(C()), hash(()))
415
416 def test_1_field_hash(self):
417 @dataclass(hash=True)
418 class C:
419 x: int
420 self.assertEqual(hash(C(4)), hash((4,)))
421 self.assertEqual(hash(C(42)), hash((42,)))
422
423 def test_hash(self):
424 @dataclass(hash=True)
425 class C:
426 x: int
427 y: str
428 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
429
430 def test_no_hash(self):
431 @dataclass(hash=None)
432 class C:
433 x: int
434 with self.assertRaisesRegex(TypeError,
435 "unhashable type: 'C'"):
436 hash(C(1))
437
438 def test_hash_rules(self):
439 # There are 24 cases of:
440 # hash=True/False/None
441 # eq=True/False
442 # order=True/False
443 # frozen=True/False
444 for (hash, eq, order, frozen, result ) in [
445 (False, False, False, False, 'absent'),
446 (False, False, False, True, 'absent'),
447 (False, False, True, False, 'exception'),
448 (False, False, True, True, 'exception'),
449 (False, True, False, False, 'absent'),
450 (False, True, False, True, 'absent'),
451 (False, True, True, False, 'absent'),
452 (False, True, True, True, 'absent'),
453 (True, False, False, False, 'fn'),
454 (True, False, False, True, 'fn'),
455 (True, False, True, False, 'exception'),
456 (True, False, True, True, 'exception'),
457 (True, True, False, False, 'fn'),
458 (True, True, False, True, 'fn'),
459 (True, True, True, False, 'fn'),
460 (True, True, True, True, 'fn'),
461 (None, False, False, False, 'absent'),
462 (None, False, False, True, 'absent'),
463 (None, False, True, False, 'exception'),
464 (None, False, True, True, 'exception'),
465 (None, True, False, False, 'none'),
466 (None, True, False, True, 'fn'),
467 (None, True, True, False, 'none'),
468 (None, True, True, True, 'fn'),
469 ]:
470 with self.subTest(hash=hash, eq=eq, order=order, frozen=frozen):
471 if result == 'exception':
472 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
473 @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
474 class C:
475 pass
476 else:
477 @dataclass(hash=hash, eq=eq, order=order, frozen=frozen)
478 class C:
479 pass
480
481 # See if the result matches what's expected.
482 if result == 'fn':
483 # __hash__ contains the function we generated.
484 self.assertIn('__hash__', C.__dict__)
485 self.assertIsNotNone(C.__dict__['__hash__'])
486 elif result == 'absent':
487 # __hash__ is not present in our class.
488 self.assertNotIn('__hash__', C.__dict__)
489 elif result == 'none':
490 # __hash__ is set to None.
491 self.assertIn('__hash__', C.__dict__)
492 self.assertIsNone(C.__dict__['__hash__'])
493 else:
494 assert False, f'unknown result {result!r}'
495
496 def test_eq_order(self):
497 for (eq, order, result ) in [
498 (False, False, 'neither'),
499 (False, True, 'exception'),
500 (True, False, 'eq_only'),
501 (True, True, 'both'),
502 ]:
503 with self.subTest(eq=eq, order=order):
504 if result == 'exception':
505 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
506 @dataclass(eq=eq, order=order)
507 class C:
508 pass
509 else:
510 @dataclass(eq=eq, order=order)
511 class C:
512 pass
513
514 if result == 'neither':
515 self.assertNotIn('__eq__', C.__dict__)
516 self.assertNotIn('__ne__', C.__dict__)
517 self.assertNotIn('__lt__', C.__dict__)
518 self.assertNotIn('__le__', C.__dict__)
519 self.assertNotIn('__gt__', C.__dict__)
520 self.assertNotIn('__ge__', C.__dict__)
521 elif result == 'both':
522 self.assertIn('__eq__', C.__dict__)
523 self.assertIn('__ne__', C.__dict__)
524 self.assertIn('__lt__', C.__dict__)
525 self.assertIn('__le__', C.__dict__)
526 self.assertIn('__gt__', C.__dict__)
527 self.assertIn('__ge__', C.__dict__)
528 elif result == 'eq_only':
529 self.assertIn('__eq__', C.__dict__)
530 self.assertIn('__ne__', C.__dict__)
531 self.assertNotIn('__lt__', C.__dict__)
532 self.assertNotIn('__le__', C.__dict__)
533 self.assertNotIn('__gt__', C.__dict__)
534 self.assertNotIn('__ge__', C.__dict__)
535 else:
536 assert False, f'unknown result {result!r}'
537
538 def test_field_no_default(self):
539 @dataclass
540 class C:
541 x: int = field()
542
543 self.assertEqual(C(5).x, 5)
544
545 with self.assertRaisesRegex(TypeError,
546 r"__init__\(\) missing 1 required "
547 "positional argument: 'x'"):
548 C()
549
550 def test_field_default(self):
551 default = object()
552 @dataclass
553 class C:
554 x: object = field(default=default)
555
556 self.assertIs(C.x, default)
557 c = C(10)
558 self.assertEqual(c.x, 10)
559
560 # If we delete the instance attribute, we should then see the
561 # class attribute.
562 del c.x
563 self.assertIs(c.x, default)
564
565 self.assertIs(C().x, default)
566
567 def test_not_in_repr(self):
568 @dataclass
569 class C:
570 x: int = field(repr=False)
571 with self.assertRaises(TypeError):
572 C()
573 c = C(10)
574 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
575
576 @dataclass
577 class C:
578 x: int = field(repr=False)
579 y: int
580 c = C(10, 20)
581 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
582
583 def test_not_in_compare(self):
584 @dataclass
585 class C:
586 x: int = 0
587 y: int = field(compare=False, default=4)
588
589 self.assertEqual(C(), C(0, 20))
590 self.assertEqual(C(1, 10), C(1, 20))
591 self.assertNotEqual(C(3), C(4, 10))
592 self.assertNotEqual(C(3, 10), C(4, 10))
593
594 def test_hash_field_rules(self):
595 # Test all 6 cases of:
596 # hash=True/False/None
597 # compare=True/False
598 for (hash_val, compare, result ) in [
599 (True, False, 'field' ),
600 (True, True, 'field' ),
601 (False, False, 'absent'),
602 (False, True, 'absent'),
603 (None, False, 'absent'),
604 (None, True, 'field' ),
605 ]:
606 with self.subTest(hash_val=hash_val, compare=compare):
607 @dataclass(hash=True)
608 class C:
609 x: int = field(compare=compare, hash=hash_val, default=5)
610
611 if result == 'field':
612 # __hash__ contains the field.
613 self.assertEqual(C(5).__hash__(), hash((5,)))
614 elif result == 'absent':
615 # The field is not present in the hash.
616 self.assertEqual(C(5).__hash__(), hash(()))
617 else:
618 assert False, f'unknown result {result!r}'
619
620 def test_init_false_no_default(self):
621 # If init=False and no default value, then the field won't be
622 # present in the instance.
623 @dataclass
624 class C:
625 x: int = field(init=False)
626
627 self.assertNotIn('x', C().__dict__)
628
629 @dataclass
630 class C:
631 x: int
632 y: int = 0
633 z: int = field(init=False)
634 t: int = 10
635
636 self.assertNotIn('z', C(0).__dict__)
637 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
638
639 def test_class_marker(self):
640 @dataclass
641 class C:
642 x: int
643 y: str = field(init=False, default=None)
644 z: str = field(repr=False)
645
646 the_fields = fields(C)
647 # the_fields is a tuple of 3 items, each value
648 # is in __annotations__.
649 self.assertIsInstance(the_fields, tuple)
650 for f in the_fields:
651 self.assertIs(type(f), Field)
652 self.assertIn(f.name, C.__annotations__)
653
654 self.assertEqual(len(the_fields), 3)
655
656 self.assertEqual(the_fields[0].name, 'x')
657 self.assertEqual(the_fields[0].type, int)
658 self.assertFalse(hasattr(C, 'x'))
659 self.assertTrue (the_fields[0].init)
660 self.assertTrue (the_fields[0].repr)
661 self.assertEqual(the_fields[1].name, 'y')
662 self.assertEqual(the_fields[1].type, str)
663 self.assertIsNone(getattr(C, 'y'))
664 self.assertFalse(the_fields[1].init)
665 self.assertTrue (the_fields[1].repr)
666 self.assertEqual(the_fields[2].name, 'z')
667 self.assertEqual(the_fields[2].type, str)
668 self.assertFalse(hasattr(C, 'z'))
669 self.assertTrue (the_fields[2].init)
670 self.assertFalse(the_fields[2].repr)
671
672 def test_field_order(self):
673 @dataclass
674 class B:
675 a: str = 'B:a'
676 b: str = 'B:b'
677 c: str = 'B:c'
678
679 @dataclass
680 class C(B):
681 b: str = 'C:b'
682
683 self.assertEqual([(f.name, f.default) for f in fields(C)],
684 [('a', 'B:a'),
685 ('b', 'C:b'),
686 ('c', 'B:c')])
687
688 @dataclass
689 class D(B):
690 c: str = 'D:c'
691
692 self.assertEqual([(f.name, f.default) for f in fields(D)],
693 [('a', 'B:a'),
694 ('b', 'B:b'),
695 ('c', 'D:c')])
696
697 @dataclass
698 class E(D):
699 a: str = 'E:a'
700 d: str = 'E:d'
701
702 self.assertEqual([(f.name, f.default) for f in fields(E)],
703 [('a', 'E:a'),
704 ('b', 'B:b'),
705 ('c', 'D:c'),
706 ('d', 'E:d')])
707
708 def test_class_attrs(self):
709 # We only have a class attribute if a default value is
710 # specified, either directly or via a field with a default.
711 default = object()
712 @dataclass
713 class C:
714 x: int
715 y: int = field(repr=False)
716 z: object = default
717 t: int = field(default=100)
718
719 self.assertFalse(hasattr(C, 'x'))
720 self.assertFalse(hasattr(C, 'y'))
721 self.assertIs (C.z, default)
722 self.assertEqual(C.t, 100)
723
724 def test_disallowed_mutable_defaults(self):
725 # For the known types, don't allow mutable default values.
726 for typ, empty, non_empty in [(list, [], [1]),
727 (dict, {}, {0:1}),
728 (set, set(), set([1])),
729 ]:
730 with self.subTest(typ=typ):
731 # Can't use a zero-length value.
732 with self.assertRaisesRegex(ValueError,
733 f'mutable default {typ} for field '
734 'x is not allowed'):
735 @dataclass
736 class Point:
737 x: typ = empty
738
739
740 # Nor a non-zero-length value
741 with self.assertRaisesRegex(ValueError,
742 f'mutable default {typ} for field '
743 'y is not allowed'):
744 @dataclass
745 class Point:
746 y: typ = non_empty
747
748 # Check subtypes also fail.
749 class Subclass(typ): pass
750
751 with self.assertRaisesRegex(ValueError,
752 f"mutable default .*Subclass'>"
753 ' for field z is not allowed'
754 ):
755 @dataclass
756 class Point:
757 z: typ = Subclass()
758
759 # Because this is a ClassVar, it can be mutable.
760 @dataclass
761 class C:
762 z: ClassVar[typ] = typ()
763
764 # Because this is a ClassVar, it can be mutable.
765 @dataclass
766 class C:
767 x: ClassVar[typ] = Subclass()
768
769
770 def test_deliberately_mutable_defaults(self):
771 # If a mutable default isn't in the known list of
772 # (list, dict, set), then it's okay.
773 class Mutable:
774 def __init__(self):
775 self.l = []
776
777 @dataclass
778 class C:
779 x: Mutable
780
781 # These 2 instances will share this value of x.
782 lst = Mutable()
783 o1 = C(lst)
784 o2 = C(lst)
785 self.assertEqual(o1, o2)
786 o1.x.l.extend([1, 2])
787 self.assertEqual(o1, o2)
788 self.assertEqual(o1.x.l, [1, 2])
789 self.assertIs(o1.x, o2.x)
790
791 def test_no_options(self):
792 # call with dataclass()
793 @dataclass()
794 class C:
795 x: int
796
797 self.assertEqual(C(42).x, 42)
798
799 def test_not_tuple(self):
800 # Make sure we can't be compared to a tuple.
801 @dataclass
802 class Point:
803 x: int
804 y: int
805 self.assertNotEqual(Point(1, 2), (1, 2))
806
807 # And that we can't compare to another unrelated dataclass
808 @dataclass
809 class C:
810 x: int
811 y: int
812 self.assertNotEqual(Point(1, 3), C(1, 3))
813
814 def test_base_has_init(self):
815 class B:
816 def __init__(self):
817 pass
818
819 # Make sure that declaring this class doesn't raise an error.
820 # The issue is that we can't override __init__ in our class,
821 # but it should be okay to add __init__ to us if our base has
822 # an __init__.
823 @dataclass
824 class C(B):
825 x: int = 0
826
827 def test_frozen(self):
828 @dataclass(frozen=True)
829 class C:
830 i: int
831
832 c = C(10)
833 self.assertEqual(c.i, 10)
834 with self.assertRaises(FrozenInstanceError):
835 c.i = 5
836 self.assertEqual(c.i, 10)
837
838 # Check that a derived class is still frozen, even if not
839 # marked so.
840 @dataclass
841 class D(C):
842 pass
843
844 d = D(20)
845 self.assertEqual(d.i, 20)
846 with self.assertRaises(FrozenInstanceError):
847 d.i = 5
848 self.assertEqual(d.i, 20)
849
850 def test_not_tuple(self):
851 # Test that some of the problems with namedtuple don't happen
852 # here.
853 @dataclass
854 class Point3D:
855 x: int
856 y: int
857 z: int
858
859 @dataclass
860 class Date:
861 year: int
862 month: int
863 day: int
864
865 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
866 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
867
868 # Make sure we can't unpack
869 with self.assertRaisesRegex(TypeError, 'is not iterable'):
870 x, y, z = Point3D(4, 5, 6)
871
872 # Maka sure another class with the same field names isn't
873 # equal.
874 @dataclass
875 class Point3Dv1:
876 x: int = 0
877 y: int = 0
878 z: int = 0
879 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
880
881 def test_function_annotations(self):
882 # Some dummy class and instance to use as a default.
883 class F:
884 pass
885 f = F()
886
887 def validate_class(cls):
888 # First, check __annotations__, even though they're not
889 # function annotations.
890 self.assertEqual(cls.__annotations__['i'], int)
891 self.assertEqual(cls.__annotations__['j'], str)
892 self.assertEqual(cls.__annotations__['k'], F)
893 self.assertEqual(cls.__annotations__['l'], float)
894 self.assertEqual(cls.__annotations__['z'], complex)
895
896 # Verify __init__.
897
898 signature = inspect.signature(cls.__init__)
899 # Check the return type, should be None
900 self.assertIs(signature.return_annotation, None)
901
902 # Check each parameter.
903 params = iter(signature.parameters.values())
904 param = next(params)
905 # This is testing an internal name, and probably shouldn't be tested.
906 self.assertEqual(param.name, 'self')
907 param = next(params)
908 self.assertEqual(param.name, 'i')
909 self.assertIs (param.annotation, int)
910 self.assertEqual(param.default, inspect.Parameter.empty)
911 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
912 param = next(params)
913 self.assertEqual(param.name, 'j')
914 self.assertIs (param.annotation, str)
915 self.assertEqual(param.default, inspect.Parameter.empty)
916 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
917 param = next(params)
918 self.assertEqual(param.name, 'k')
919 self.assertIs (param.annotation, F)
920 # Don't test for the default, since it's set to _MISSING
921 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
922 param = next(params)
923 self.assertEqual(param.name, 'l')
924 self.assertIs (param.annotation, float)
925 # Don't test for the default, since it's set to _MISSING
926 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
927 self.assertRaises(StopIteration, next, params)
928
929
930 @dataclass
931 class C:
932 i: int
933 j: str
934 k: F = f
935 l: float=field(default=None)
936 z: complex=field(default=3+4j, init=False)
937
938 validate_class(C)
939
940 # Now repeat with __hash__.
941 @dataclass(frozen=True, hash=True)
942 class C:
943 i: int
944 j: str
945 k: F = f
946 l: float=field(default=None)
947 z: complex=field(default=3+4j, init=False)
948
949 validate_class(C)
950
951 def test_dont_include_other_annotations(self):
952 @dataclass
953 class C:
954 i: int
955 def foo(self) -> int:
956 return 4
957 @property
958 def bar(self) -> int:
959 return 5
960 self.assertEqual(list(C.__annotations__), ['i'])
961 self.assertEqual(C(10).foo(), 4)
962 self.assertEqual(C(10).bar, 5)
963
964 def test_post_init(self):
965 # Just make sure it gets called
966 @dataclass
967 class C:
968 def __post_init__(self):
969 raise CustomError()
970 with self.assertRaises(CustomError):
971 C()
972
973 @dataclass
974 class C:
975 i: int = 10
976 def __post_init__(self):
977 if self.i == 10:
978 raise CustomError()
979 with self.assertRaises(CustomError):
980 C()
981 # post-init gets called, but doesn't raise. This is just
982 # checking that self is used correctly.
983 C(5)
984
985 # If there's not an __init__, then post-init won't get called.
986 @dataclass(init=False)
987 class C:
988 def __post_init__(self):
989 raise CustomError()
990 # Creating the class won't raise
991 C()
992
993 @dataclass
994 class C:
995 x: int = 0
996 def __post_init__(self):
997 self.x *= 2
998 self.assertEqual(C().x, 0)
999 self.assertEqual(C(2).x, 4)
1000
1001 # Make sure that if we'r frozen, post-init can't set
1002 # attributes.
1003 @dataclass(frozen=True)
1004 class C:
1005 x: int = 0
1006 def __post_init__(self):
1007 self.x *= 2
1008 with self.assertRaises(FrozenInstanceError):
1009 C()
1010
1011 def test_post_init_super(self):
1012 # Make sure super() post-init isn't called by default.
1013 class B:
1014 def __post_init__(self):
1015 raise CustomError()
1016
1017 @dataclass
1018 class C(B):
1019 def __post_init__(self):
1020 self.x = 5
1021
1022 self.assertEqual(C().x, 5)
1023
1024 # Now call super(), and it will raise
1025 @dataclass
1026 class C(B):
1027 def __post_init__(self):
1028 super().__post_init__()
1029
1030 with self.assertRaises(CustomError):
1031 C()
1032
1033 # Make sure post-init is called, even if not defined in our
1034 # class.
1035 @dataclass
1036 class C(B):
1037 pass
1038
1039 with self.assertRaises(CustomError):
1040 C()
1041
1042 def test_post_init_staticmethod(self):
1043 flag = False
1044 @dataclass
1045 class C:
1046 x: int
1047 y: int
1048 @staticmethod
1049 def __post_init__():
1050 nonlocal flag
1051 flag = True
1052
1053 self.assertFalse(flag)
1054 c = C(3, 4)
1055 self.assertEqual((c.x, c.y), (3, 4))
1056 self.assertTrue(flag)
1057
1058 def test_post_init_classmethod(self):
1059 @dataclass
1060 class C:
1061 flag = False
1062 x: int
1063 y: int
1064 @classmethod
1065 def __post_init__(cls):
1066 cls.flag = True
1067
1068 self.assertFalse(C.flag)
1069 c = C(3, 4)
1070 self.assertEqual((c.x, c.y), (3, 4))
1071 self.assertTrue(C.flag)
1072
1073 def test_class_var(self):
1074 # Make sure ClassVars are ignored in __init__, __repr__, etc.
1075 @dataclass
1076 class C:
1077 x: int
1078 y: int = 10
1079 z: ClassVar[int] = 1000
1080 w: ClassVar[int] = 2000
1081 t: ClassVar[int] = 3000
1082
1083 c = C(5)
1084 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
1085 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1086 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1087 self.assertEqual(c.z, 1000)
1088 self.assertEqual(c.w, 2000)
1089 self.assertEqual(c.t, 3000)
1090 C.z += 1
1091 self.assertEqual(c.z, 1001)
1092 c = C(20)
1093 self.assertEqual((c.x, c.y), (20, 10))
1094 self.assertEqual(c.z, 1001)
1095 self.assertEqual(c.w, 2000)
1096 self.assertEqual(c.t, 3000)
1097
1098 def test_class_var_no_default(self):
1099 # If a ClassVar has no default value, it should not be set on the class.
1100 @dataclass
1101 class C:
1102 x: ClassVar[int]
1103
1104 self.assertNotIn('x', C.__dict__)
1105
1106 def test_class_var_default_factory(self):
1107 # It makes no sense for a ClassVar to have a default factory. When
1108 # would it be called? Call it yourself, since it's class-wide.
1109 with self.assertRaisesRegex(TypeError,
1110 'cannot have a default factory'):
1111 @dataclass
1112 class C:
1113 x: ClassVar[int] = field(default_factory=int)
1114
1115 self.assertNotIn('x', C.__dict__)
1116
1117 def test_class_var_with_default(self):
1118 # If a ClassVar has a default value, it should be set on the class.
1119 @dataclass
1120 class C:
1121 x: ClassVar[int] = 10
1122 self.assertEqual(C.x, 10)
1123
1124 @dataclass
1125 class C:
1126 x: ClassVar[int] = field(default=10)
1127 self.assertEqual(C.x, 10)
1128
1129 def test_class_var_frozen(self):
1130 # Make sure ClassVars work even if we're frozen.
1131 @dataclass(frozen=True)
1132 class C:
1133 x: int
1134 y: int = 10
1135 z: ClassVar[int] = 1000
1136 w: ClassVar[int] = 2000
1137 t: ClassVar[int] = 3000
1138
1139 c = C(5)
1140 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1141 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1142 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1143 self.assertEqual(c.z, 1000)
1144 self.assertEqual(c.w, 2000)
1145 self.assertEqual(c.t, 3000)
1146 # We can still modify the ClassVar, it's only instances that are
1147 # frozen.
1148 C.z += 1
1149 self.assertEqual(c.z, 1001)
1150 c = C(20)
1151 self.assertEqual((c.x, c.y), (20, 10))
1152 self.assertEqual(c.z, 1001)
1153 self.assertEqual(c.w, 2000)
1154 self.assertEqual(c.t, 3000)
1155
1156 def test_init_var_no_default(self):
1157 # If an InitVar has no default value, it should not be set on the class.
1158 @dataclass
1159 class C:
1160 x: InitVar[int]
1161
1162 self.assertNotIn('x', C.__dict__)
1163
1164 def test_init_var_default_factory(self):
1165 # It makes no sense for an InitVar to have a default factory. When
1166 # would it be called? Call it yourself, since it's class-wide.
1167 with self.assertRaisesRegex(TypeError,
1168 'cannot have a default factory'):
1169 @dataclass
1170 class C:
1171 x: InitVar[int] = field(default_factory=int)
1172
1173 self.assertNotIn('x', C.__dict__)
1174
1175 def test_init_var_with_default(self):
1176 # If an InitVar has a default value, it should be set on the class.
1177 @dataclass
1178 class C:
1179 x: InitVar[int] = 10
1180 self.assertEqual(C.x, 10)
1181
1182 @dataclass
1183 class C:
1184 x: InitVar[int] = field(default=10)
1185 self.assertEqual(C.x, 10)
1186
1187 def test_init_var(self):
1188 @dataclass
1189 class C:
1190 x: int = None
1191 init_param: InitVar[int] = None
1192
1193 def __post_init__(self, init_param):
1194 if self.x is None:
1195 self.x = init_param*2
1196
1197 c = C(init_param=10)
1198 self.assertEqual(c.x, 20)
1199
1200 def test_init_var_inheritance(self):
1201 # Note that this deliberately tests that a dataclass need not
1202 # have a __post_init__ function if it has an InitVar field.
1203 # It could just be used in a derived class, as shown here.
1204 @dataclass
1205 class Base:
1206 x: int
1207 init_base: InitVar[int]
1208
1209 # We can instantiate by passing the InitVar, even though
1210 # it's not used.
1211 b = Base(0, 10)
1212 self.assertEqual(vars(b), {'x': 0})
1213
1214 @dataclass
1215 class C(Base):
1216 y: int
1217 init_derived: InitVar[int]
1218
1219 def __post_init__(self, init_base, init_derived):
1220 self.x = self.x + init_base
1221 self.y = self.y + init_derived
1222
1223 c = C(10, 11, 50, 51)
1224 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1225
1226 def test_default_factory(self):
1227 # Test a factory that returns a new list.
1228 @dataclass
1229 class C:
1230 x: int
1231 y: list = field(default_factory=list)
1232
1233 c0 = C(3)
1234 c1 = C(3)
1235 self.assertEqual(c0.x, 3)
1236 self.assertEqual(c0.y, [])
1237 self.assertEqual(c0, c1)
1238 self.assertIsNot(c0.y, c1.y)
1239 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1240
1241 # Test a factory that returns a shared list.
1242 l = []
1243 @dataclass
1244 class C:
1245 x: int
1246 y: list = field(default_factory=lambda: l)
1247
1248 c0 = C(3)
1249 c1 = C(3)
1250 self.assertEqual(c0.x, 3)
1251 self.assertEqual(c0.y, [])
1252 self.assertEqual(c0, c1)
1253 self.assertIs(c0.y, c1.y)
1254 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1255
1256 # Test various other field flags.
1257 # repr
1258 @dataclass
1259 class C:
1260 x: list = field(default_factory=list, repr=False)
1261 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1262 self.assertEqual(C().x, [])
1263
1264 # hash
1265 @dataclass(hash=True)
1266 class C:
1267 x: list = field(default_factory=list, hash=False)
1268 self.assertEqual(astuple(C()), ([],))
1269 self.assertEqual(hash(C()), hash(()))
1270
1271 # init (see also test_default_factory_with_no_init)
1272 @dataclass
1273 class C:
1274 x: list = field(default_factory=list, init=False)
1275 self.assertEqual(astuple(C()), ([],))
1276
1277 # compare
1278 @dataclass
1279 class C:
1280 x: list = field(default_factory=list, compare=False)
1281 self.assertEqual(C(), C([1]))
1282
1283 def test_default_factory_with_no_init(self):
1284 # We need a factory with a side effect.
1285 factory = Mock()
1286
1287 @dataclass
1288 class C:
1289 x: list = field(default_factory=factory, init=False)
1290
1291 # Make sure the default factory is called for each new instance.
1292 C().x
1293 self.assertEqual(factory.call_count, 1)
1294 C().x
1295 self.assertEqual(factory.call_count, 2)
1296
1297 def test_default_factory_not_called_if_value_given(self):
1298 # We need a factory that we can test if it's been called.
1299 factory = Mock()
1300
1301 @dataclass
1302 class C:
1303 x: int = field(default_factory=factory)
1304
1305 # Make sure that if a field has a default factory function,
1306 # it's not called if a value is specified.
1307 C().x
1308 self.assertEqual(factory.call_count, 1)
1309 self.assertEqual(C(10).x, 10)
1310 self.assertEqual(factory.call_count, 1)
1311 C().x
1312 self.assertEqual(factory.call_count, 2)
1313
def x_test_classvar_default_factory(self):
    # Disabled: the 'x_' prefix keeps unittest from collecting this.
    # XXX: it's an error for a ClassVar to have a factory function
    @dataclass
    class C:
        x: ClassVar[int] = field(default_factory=int)

    instance = C()
    self.assertIs(instance.x, int)
1321
1322 def test_isdataclass(self):
1323 # There is no isdataclass() helper any more, but the PEP
1324 # describes how to write it, so make sure that works. Note
1325 # that this version returns True for both classes and
1326 # instances.
1327 def isdataclass(obj):
1328 try:
1329 fields(obj)
1330 return True
1331 except TypeError:
1332 return False
1333
1334 self.assertFalse(isdataclass(0))
1335 self.assertFalse(isdataclass(int))
1336
1337 @dataclass
1338 class C:
1339 x: int
1340
1341 self.assertTrue(isdataclass(C))
1342 self.assertTrue(isdataclass(C(0)))
1343
1344 def test_helper_fields_with_class_instance(self):
1345 # Check that we can call fields() on either a class or instance,
1346 # and get back the same thing.
1347 @dataclass
1348 class C:
1349 x: int
1350 y: float
1351
1352 self.assertEqual(fields(C), fields(C(0, 0.0)))
1353
1354 def test_helper_fields_exception(self):
1355 # Check that TypeError is raised if not passed a dataclass or
1356 # instance.
1357 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1358 fields(0)
1359
1360 class C: pass
1361 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1362 fields(C)
1363 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1364 fields(C())
1365
1366 def test_helper_asdict(self):
1367 # Basic tests for asdict(), it should return a new dictionary
1368 @dataclass
1369 class C:
1370 x: int
1371 y: int
1372 c = C(1, 2)
1373
1374 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1375 self.assertEqual(asdict(c), asdict(c))
1376 self.assertIsNot(asdict(c), asdict(c))
1377 c.x = 42
1378 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1379 self.assertIs(type(asdict(c)), dict)
1380
1381 def test_helper_asdict_raises_on_classes(self):
1382 # asdict() should raise on a class object
1383 @dataclass
1384 class C:
1385 x: int
1386 y: int
1387 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1388 asdict(C)
1389 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1390 asdict(int)
1391
1392 def test_helper_asdict_copy_values(self):
1393 @dataclass
1394 class C:
1395 x: int
1396 y: List[int] = field(default_factory=list)
1397 initial = []
1398 c = C(1, initial)
1399 d = asdict(c)
1400 self.assertEqual(d['y'], initial)
1401 self.assertIsNot(d['y'], initial)
1402 c = C(1)
1403 d = asdict(c)
1404 d['y'].append(1)
1405 self.assertEqual(c.y, [])
1406
1407 def test_helper_asdict_nested(self):
1408 @dataclass
1409 class UserId:
1410 token: int
1411 group: int
1412 @dataclass
1413 class User:
1414 name: str
1415 id: UserId
1416 u = User('Joe', UserId(123, 1))
1417 d = asdict(u)
1418 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1419 self.assertIsNot(asdict(u), asdict(u))
1420 u.id.group = 2
1421 self.assertEqual(asdict(u), {'name': 'Joe',
1422 'id': {'token': 123, 'group': 2}})
1423
1424 def test_helper_asdict_builtin_containers(self):
1425 @dataclass
1426 class User:
1427 name: str
1428 id: int
1429 @dataclass
1430 class GroupList:
1431 id: int
1432 users: List[User]
1433 @dataclass
1434 class GroupTuple:
1435 id: int
1436 users: Tuple[User, ...]
1437 @dataclass
1438 class GroupDict:
1439 id: int
1440 users: Dict[str, User]
1441 a = User('Alice', 1)
1442 b = User('Bob', 2)
1443 gl = GroupList(0, [a, b])
1444 gt = GroupTuple(0, (a, b))
1445 gd = GroupDict(0, {'first': a, 'second': b})
1446 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1447 {'name': 'Bob', 'id': 2}]})
1448 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1449 {'name': 'Bob', 'id': 2})})
1450 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1451 'second': {'name': 'Bob', 'id': 2}}})
1452
1453 def test_helper_asdict_builtin_containers(self):
1454 @dataclass
1455 class Child:
1456 d: object
1457
1458 @dataclass
1459 class Parent:
1460 child: Child
1461
1462 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1463 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1464
1465 def test_helper_asdict_factory(self):
1466 @dataclass
1467 class C:
1468 x: int
1469 y: int
1470 c = C(1, 2)
1471 d = asdict(c, dict_factory=OrderedDict)
1472 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1473 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1474 c.x = 42
1475 d = asdict(c, dict_factory=OrderedDict)
1476 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1477 self.assertIs(type(d), OrderedDict)
1478
1479 def test_helper_astuple(self):
1480 # Basic tests for astuple(), it should return a new tuple
1481 @dataclass
1482 class C:
1483 x: int
1484 y: int = 0
1485 c = C(1)
1486
1487 self.assertEqual(astuple(c), (1, 0))
1488 self.assertEqual(astuple(c), astuple(c))
1489 self.assertIsNot(astuple(c), astuple(c))
1490 c.y = 42
1491 self.assertEqual(astuple(c), (1, 42))
1492 self.assertIs(type(astuple(c)), tuple)
1493
1494 def test_helper_astuple_raises_on_classes(self):
1495 # astuple() should raise on a class object
1496 @dataclass
1497 class C:
1498 x: int
1499 y: int
1500 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1501 astuple(C)
1502 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1503 astuple(int)
1504
1505 def test_helper_astuple_copy_values(self):
1506 @dataclass
1507 class C:
1508 x: int
1509 y: List[int] = field(default_factory=list)
1510 initial = []
1511 c = C(1, initial)
1512 t = astuple(c)
1513 self.assertEqual(t[1], initial)
1514 self.assertIsNot(t[1], initial)
1515 c = C(1)
1516 t = astuple(c)
1517 t[1].append(1)
1518 self.assertEqual(c.y, [])
1519
1520 def test_helper_astuple_nested(self):
1521 @dataclass
1522 class UserId:
1523 token: int
1524 group: int
1525 @dataclass
1526 class User:
1527 name: str
1528 id: UserId
1529 u = User('Joe', UserId(123, 1))
1530 t = astuple(u)
1531 self.assertEqual(t, ('Joe', (123, 1)))
1532 self.assertIsNot(astuple(u), astuple(u))
1533 u.id.group = 2
1534 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1535
1536 def test_helper_astuple_builtin_containers(self):
1537 @dataclass
1538 class User:
1539 name: str
1540 id: int
1541 @dataclass
1542 class GroupList:
1543 id: int
1544 users: List[User]
1545 @dataclass
1546 class GroupTuple:
1547 id: int
1548 users: Tuple[User, ...]
1549 @dataclass
1550 class GroupDict:
1551 id: int
1552 users: Dict[str, User]
1553 a = User('Alice', 1)
1554 b = User('Bob', 2)
1555 gl = GroupList(0, [a, b])
1556 gt = GroupTuple(0, (a, b))
1557 gd = GroupDict(0, {'first': a, 'second': b})
1558 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1559 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1560 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1561
1562 def test_helper_astuple_builtin_containers(self):
1563 @dataclass
1564 class Child:
1565 d: object
1566
1567 @dataclass
1568 class Parent:
1569 child: Child
1570
1571 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1572 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1573
1574 def test_helper_astuple_factory(self):
1575 @dataclass
1576 class C:
1577 x: int
1578 y: int
1579 NT = namedtuple('NT', 'x y')
1580 def nt(lst):
1581 return NT(*lst)
1582 c = C(1, 2)
1583 t = astuple(c, tuple_factory=nt)
1584 self.assertEqual(t, NT(1, 2))
1585 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1586 c.x = 42
1587 t = astuple(c, tuple_factory=nt)
1588 self.assertEqual(t, NT(42, 2))
1589 self.assertIs(type(t), NT)
1590
1591 def test_dynamic_class_creation(self):
1592 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1593 }
1594
1595 # Create the class.
1596 cls = type('C', (), cls_dict)
1597
1598 # Make it a dataclass.
1599 cls1 = dataclass(cls)
1600
1601 self.assertEqual(cls1, cls)
1602 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1603
1604 def test_dynamic_class_creation_using_field(self):
1605 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1606 'y': field(default=5),
1607 }
1608
1609 # Create the class.
1610 cls = type('C', (), cls_dict)
1611
1612 # Make it a dataclass.
1613 cls1 = dataclass(cls)
1614
1615 self.assertEqual(cls1, cls)
1616 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1617
1618 def test_init_in_order(self):
1619 @dataclass
1620 class C:
1621 a: int
1622 b: int = field()
1623 c: list = field(default_factory=list, init=False)
1624 d: list = field(default_factory=list)
1625 e: int = field(default=4, init=False)
1626 f: int = 4
1627
1628 calls = []
1629 def setattr(self, name, value):
1630 calls.append((name, value))
1631
1632 C.__setattr__ = setattr
1633 c = C(0, 1)
1634 self.assertEqual(('a', 0), calls[0])
1635 self.assertEqual(('b', 1), calls[1])
1636 self.assertEqual(('c', []), calls[2])
1637 self.assertEqual(('d', []), calls[3])
1638 self.assertNotIn(('e', 4), calls)
1639 self.assertEqual(('f', 4), calls[4])
1640
1641 def test_items_in_dicts(self):
1642 @dataclass
1643 class C:
1644 a: int
1645 b: list = field(default_factory=list, init=False)
1646 c: list = field(default_factory=list)
1647 d: int = field(default=4, init=False)
1648 e: int = 0
1649
1650 c = C(0)
1651 # Class dict
1652 self.assertNotIn('a', C.__dict__)
1653 self.assertNotIn('b', C.__dict__)
1654 self.assertNotIn('c', C.__dict__)
1655 self.assertIn('d', C.__dict__)
1656 self.assertEqual(C.d, 4)
1657 self.assertIn('e', C.__dict__)
1658 self.assertEqual(C.e, 0)
1659 # Instance dict
1660 self.assertIn('a', c.__dict__)
1661 self.assertEqual(c.a, 0)
1662 self.assertIn('b', c.__dict__)
1663 self.assertEqual(c.b, [])
1664 self.assertIn('c', c.__dict__)
1665 self.assertEqual(c.c, [])
1666 self.assertNotIn('d', c.__dict__)
1667 self.assertIn('e', c.__dict__)
1668 self.assertEqual(c.e, 0)
1669
1670 def test_alternate_classmethod_constructor(self):
1671 # Since __post_init__ can't take params, use a classmethod
1672 # alternate constructor. This is mostly an example to show how
1673 # to use this technique.
1674 @dataclass
1675 class C:
1676 x: int
1677 @classmethod
1678 def from_file(cls, filename):
1679 # In a real example, create a new instance
1680 # and populate 'x' from contents of a file.
1681 value_in_file = 20
1682 return cls(value_in_file)
1683
1684 self.assertEqual(C.from_file('filename').x, 20)
1685
1686 def test_field_metadata_default(self):
1687 # Make sure the default metadata is read-only and of
1688 # zero length.
1689 @dataclass
1690 class C:
1691 i: int
1692
1693 self.assertFalse(fields(C)[0].metadata)
1694 self.assertEqual(len(fields(C)[0].metadata), 0)
1695 with self.assertRaisesRegex(TypeError,
1696 'does not support item assignment'):
1697 fields(C)[0].metadata['test'] = 3
1698
1699 def test_field_metadata_mapping(self):
1700 # Make sure only a mapping can be passed as metadata
1701 # zero length.
1702 with self.assertRaises(TypeError):
1703 @dataclass
1704 class C:
1705 i: int = field(metadata=0)
1706
1707 # Make sure an empty dict works
1708 @dataclass
1709 class C:
1710 i: int = field(metadata={})
1711 self.assertFalse(fields(C)[0].metadata)
1712 self.assertEqual(len(fields(C)[0].metadata), 0)
1713 with self.assertRaisesRegex(TypeError,
1714 'does not support item assignment'):
1715 fields(C)[0].metadata['test'] = 3
1716
1717 # Make sure a non-empty dict works.
1718 @dataclass
1719 class C:
1720 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1721 self.assertEqual(len(fields(C)[0].metadata), 3)
1722 self.assertEqual(fields(C)[0].metadata['test'], 10)
1723 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1724 self.assertEqual(fields(C)[0].metadata[3], 'three')
1725 with self.assertRaises(KeyError):
1726 # Non-existent key.
1727 fields(C)[0].metadata['baz']
1728 with self.assertRaisesRegex(TypeError,
1729 'does not support item assignment'):
1730 fields(C)[0].metadata['test'] = 3
1731
1732 def test_field_metadata_custom_mapping(self):
1733 # Try a custom mapping.
1734 class SimpleNameSpace:
1735 def __init__(self, **kw):
1736 self.__dict__.update(kw)
1737
1738 def __getitem__(self, item):
1739 if item == 'xyzzy':
1740 return 'plugh'
1741 return getattr(self, item)
1742
1743 def __len__(self):
1744 return self.__dict__.__len__()
1745
1746 @dataclass
1747 class C:
1748 i: int = field(metadata=SimpleNameSpace(a=10))
1749
1750 self.assertEqual(len(fields(C)[0].metadata), 1)
1751 self.assertEqual(fields(C)[0].metadata['a'], 10)
1752 with self.assertRaises(AttributeError):
1753 fields(C)[0].metadata['b']
1754 # Make sure we're still talking to our custom mapping.
1755 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1756
1757 def test_generic_dataclasses(self):
1758 T = TypeVar('T')
1759
1760 @dataclass
1761 class LabeledBox(Generic[T]):
1762 content: T
1763 label: str = '<unknown>'
1764
1765 box = LabeledBox(42)
1766 self.assertEqual(box.content, 42)
1767 self.assertEqual(box.label, '<unknown>')
1768
1769 # subscripting the resulting class should work, etc.
1770 Alias = List[LabeledBox[int]]
1771
1772 def test_generic_extending(self):
1773 S = TypeVar('S')
1774 T = TypeVar('T')
1775
1776 @dataclass
1777 class Base(Generic[T, S]):
1778 x: T
1779 y: S
1780
1781 @dataclass
1782 class DataDerived(Base[int, T]):
1783 new_field: str
1784 Alias = DataDerived[str]
1785 c = Alias(0, 'test1', 'test2')
1786 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1787
1788 class NonDataDerived(Base[int, T]):
1789 def new_method(self):
1790 return self.y
1791 Alias = NonDataDerived[float]
1792 c = Alias(10, 1.0)
1793 self.assertEqual(c.new_method(), 1.0)
1794
1795 def test_helper_replace(self):
1796 @dataclass(frozen=True)
1797 class C:
1798 x: int
1799 y: int
1800
1801 c = C(1, 2)
1802 c1 = replace(c, x=3)
1803 self.assertEqual(c1.x, 3)
1804 self.assertEqual(c1.y, 2)
1805
1806 def test_helper_replace_frozen(self):
1807 @dataclass(frozen=True)
1808 class C:
1809 x: int
1810 y: int
1811 z: int = field(init=False, default=10)
1812 t: int = field(init=False, default=100)
1813
1814 c = C(1, 2)
1815 c1 = replace(c, x=3)
1816 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1817 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1818
1819
1820 with self.assertRaisesRegex(ValueError, 'init=False'):
1821 replace(c, x=3, z=20, t=50)
1822 with self.assertRaisesRegex(ValueError, 'init=False'):
1823 replace(c, z=20)
1824 replace(c, x=3, z=20, t=50)
1825
1826 # Make sure the result is still frozen.
1827 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1828 c1.x = 3
1829
1830 # Make sure we can't replace an attribute that doesn't exist,
1831 # if we're also replacing one that does exist. Test this
1832 # here, because setting attributes on frozen instances is
1833 # handled slightly differently from non-frozen ones.
1834 with self.assertRaisesRegex(TypeError, "__init__\(\) got an unexpected "
1835 "keyword argument 'a'"):
1836 c1 = replace(c, x=20, a=5)
1837
1838 def test_helper_replace_invalid_field_name(self):
1839 @dataclass(frozen=True)
1840 class C:
1841 x: int
1842 y: int
1843
1844 c = C(1, 2)
1845 with self.assertRaisesRegex(TypeError, "__init__\(\) got an unexpected "
1846 "keyword argument 'z'"):
1847 c1 = replace(c, z=3)
1848
1849 def test_helper_replace_invalid_object(self):
1850 @dataclass(frozen=True)
1851 class C:
1852 x: int
1853 y: int
1854
1855 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1856 replace(C, x=3)
1857
1858 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1859 replace(0, x=3)
1860
1861 def test_helper_replace_no_init(self):
1862 @dataclass
1863 class C:
1864 x: int
1865 y: int = field(init=False, default=10)
1866
1867 c = C(1)
1868 c.y = 20
1869
1870 # Make sure y gets the default value.
1871 c1 = replace(c, x=5)
1872 self.assertEqual((c1.x, c1.y), (5, 10))
1873
1874 # Trying to replace y is an error.
1875 with self.assertRaisesRegex(ValueError, 'init=False'):
1876 replace(c, x=2, y=30)
1877 with self.assertRaisesRegex(ValueError, 'init=False'):
1878 replace(c, y=30)
1879
1880 def test_dataclassses_pickleable(self):
1881 global P, Q, R
1882 @dataclass
1883 class P:
1884 x: int
1885 y: int = 0
1886 @dataclass
1887 class Q:
1888 x: int
1889 y: int = field(default=0, init=False)
1890 @dataclass
1891 class R:
1892 x: int
1893 y: List[int] = field(default_factory=list)
1894 q = Q(1)
1895 q.y = 2
1896 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1897 for sample in samples:
1898 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1899 with self.subTest(sample=sample, proto=proto):
1900 new_sample = pickle.loads(pickle.dumps(sample, proto))
1901 self.assertEqual(sample.x, new_sample.x)
1902 self.assertEqual(sample.y, new_sample.y)
1903 self.assertIsNot(sample, new_sample)
1904 new_sample.x = 42
1905 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1906 self.assertEqual(new_sample.x, another_new_sample.x)
1907 self.assertEqual(sample.y, another_new_sample.y)
1908
1909 def test_helper_make_dataclass(self):
1910 C = make_dataclass('C',
1911 [('x', int),
1912 ('y', int, field(default=5))],
1913 namespace={'add_one': lambda self: self.x + 1})
1914 c = C(10)
1915 self.assertEqual((c.x, c.y), (10, 5))
1916 self.assertEqual(c.add_one(), 11)
1917
1918
1919 def test_helper_make_dataclass_no_mutate_namespace(self):
1920 # Make sure a provided namespace isn't mutated.
1921 ns = {}
1922 C = make_dataclass('C',
1923 [('x', int),
1924 ('y', int, field(default=5))],
1925 namespace=ns)
1926 self.assertEqual(ns, {})
1927
1928 def test_helper_make_dataclass_base(self):
1929 class Base1:
1930 pass
1931 class Base2:
1932 pass
1933 C = make_dataclass('C',
1934 [('x', int)],
1935 bases=(Base1, Base2))
1936 c = C(2)
1937 self.assertIsInstance(c, C)
1938 self.assertIsInstance(c, Base1)
1939 self.assertIsInstance(c, Base2)
1940
1941 def test_helper_make_dataclass_base_dataclass(self):
1942 @dataclass
1943 class Base1:
1944 x: int
1945 class Base2:
1946 pass
1947 C = make_dataclass('C',
1948 [('y', int)],
1949 bases=(Base1, Base2))
1950 with self.assertRaisesRegex(TypeError, 'required positional'):
1951 c = C(2)
1952 c = C(1, 2)
1953 self.assertIsInstance(c, C)
1954 self.assertIsInstance(c, Base1)
1955 self.assertIsInstance(c, Base2)
1956
1957 self.assertEqual((c.x, c.y), (1, 2))
1958
1959 def test_helper_make_dataclass_init_var(self):
1960 def post_init(self, y):
1961 self.x *= y
1962
1963 C = make_dataclass('C',
1964 [('x', int),
1965 ('y', InitVar[int]),
1966 ],
1967 namespace={'__post_init__': post_init},
1968 )
1969 c = C(2, 3)
1970 self.assertEqual(vars(c), {'x': 6})
1971 self.assertEqual(len(fields(c)), 1)
1972
1973 def test_helper_make_dataclass_class_var(self):
1974 C = make_dataclass('C',
1975 [('x', int),
1976 ('y', ClassVar[int], 10),
1977 ('z', ClassVar[int], field(default=20)),
1978 ])
1979 c = C(1)
1980 self.assertEqual(vars(c), {'x': 1})
1981 self.assertEqual(len(fields(c)), 1)
1982 self.assertEqual(C.y, 10)
1983 self.assertEqual(C.z, 20)
1984
1985
1986class TestDocString(unittest.TestCase):
1987 def assertDocStrEqual(self, a, b):
1988 # Because 3.6 and 3.7 differ in how inspect.signature work
1989 # (see bpo #32108), for the time being just compare them with
1990 # whitespace stripped.
1991 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1992
1993 def test_existing_docstring_not_overridden(self):
1994 @dataclass
1995 class C:
1996 """Lorem ipsum"""
1997 x: int
1998
1999 self.assertEqual(C.__doc__, "Lorem ipsum")
2000
2001 def test_docstring_no_fields(self):
2002 @dataclass
2003 class C:
2004 pass
2005
2006 self.assertDocStrEqual(C.__doc__, "C()")
2007
2008 def test_docstring_one_field(self):
2009 @dataclass
2010 class C:
2011 x: int
2012
2013 self.assertDocStrEqual(C.__doc__, "C(x:int)")
2014
2015 def test_docstring_two_fields(self):
2016 @dataclass
2017 class C:
2018 x: int
2019 y: int
2020
2021 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
2022
2023 def test_docstring_three_fields(self):
2024 @dataclass
2025 class C:
2026 x: int
2027 y: int
2028 z: str
2029
2030 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
2031
2032 def test_docstring_one_field_with_default(self):
2033 @dataclass
2034 class C:
2035 x: int = 3
2036
2037 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2038
2039 def test_docstring_one_field_with_default_none(self):
2040 @dataclass
2041 class C:
2042 x: Union[int, type(None)] = None
2043
2044 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2045
2046 def test_docstring_list_field(self):
2047 @dataclass
2048 class C:
2049 x: List[int]
2050
2051 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2052
2053 def test_docstring_list_field_with_default_factory(self):
2054 @dataclass
2055 class C:
2056 x: List[int] = field(default_factory=list)
2057
2058 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2059
2060 def test_docstring_deque_field(self):
2061 @dataclass
2062 class C:
2063 x: deque
2064
2065 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2066
2067 def test_docstring_deque_field_with_default_factory(self):
2068 @dataclass
2069 class C:
2070 x: deque = field(default_factory=deque)
2071
2072 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2073
2074
# Run every test in this file when it is executed as a script.
if __name__ == "__main__":
    unittest.main()