blob: 9752f5502c7da093cd21513766194eb191bacfcd [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
Eric V. Smithe7ba0132018-01-06 12:41:53 -05003 make_dataclass, replace, InitVar, Field, MISSING, is_dataclass,
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050012from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013
# Arbitrary exception type that tests raise on purpose and then catch.
class CustomError(Exception):
    pass
16
17class TestCase(unittest.TestCase):
18 def test_no_fields(self):
19 @dataclass
20 class C:
21 pass
22
23 o = C()
24 self.assertEqual(len(fields(C)), 0)
25
26 def test_one_field_no_default(self):
27 @dataclass
28 class C:
29 x: int
30
31 o = C(42)
32 self.assertEqual(o.x, 42)
33
34 def test_named_init_params(self):
35 @dataclass
36 class C:
37 x: int
38
39 o = C(x=32)
40 self.assertEqual(o.x, 32)
41
42 def test_two_fields_one_default(self):
43 @dataclass
44 class C:
45 x: int
46 y: int = 0
47
48 o = C(3)
49 self.assertEqual((o.x, o.y), (3, 0))
50
51 # Non-defaults following defaults.
52 with self.assertRaisesRegex(TypeError,
53 "non-default argument 'y' follows "
54 "default argument"):
55 @dataclass
56 class C:
57 x: int = 0
58 y: int
59
60 # A derived class adds a non-default field after a default one.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class B:
66 x: int = 0
67
68 @dataclass
69 class C(B):
70 y: int
71
72 # Override a base class field and add a default to
73 # a field which didn't use to have a default.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int
80 y: int
81
82 @dataclass
83 class C(B):
84 x: int = 0
85
def test_overwriting_hash(self):
    # A user-written __hash__ must be left in place by the decorator.
    # The same pair of class definitions is exercised twice, as in the
    # original test.

    # First pass.
    @dataclass(frozen=True)
    class C:
        x: int

        def __hash__(self):
            pass

    @dataclass(frozen=True, hash=False)
    class C:
        x: int

        def __hash__(self):
            return 600

    first_result = hash(C(0))
    self.assertEqual(first_result, 600)

    # Second pass.
    @dataclass(frozen=True)
    class C:
        x: int

        def __hash__(self):
            pass

    @dataclass(frozen=True, hash=False)
    class C:
        x: int

        def __hash__(self):
            return 600

    second_result = hash(C(0))
    self.assertEqual(second_result, 600)
112
def test_overwrite_fields_in_derived_class(self):
    # Note that x from C1 replaces x in Base, but the order remains
    # the same as defined in Base.
    @dataclass
    class Base:
        x: Any = 15.0
        y: int = 0

    @dataclass
    class C1(Base):
        z: int = 10
        x: int = 15

    cases = [
        (Base(), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)'),
        (C1(), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)'),
        (C1(x=5), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)'),
    ]
    for instance, expected_repr in cases:
        self.assertEqual(repr(instance), expected_repr)
134
135 def test_field_named_self(self):
136 @dataclass
137 class C:
138 self: str
139 c=C('foo')
140 self.assertEqual(c.self, 'foo')
141
142 # Make sure the first parameter is not named 'self'.
143 sig = inspect.signature(C.__init__)
144 first = next(iter(sig.parameters))
145 self.assertNotEqual('self', first)
146
147 # But we do use 'self' if no field named self.
148 @dataclass
149 class C:
150 selfx: str
151
152 # Make sure the first parameter is named 'self'.
153 sig = inspect.signature(C.__init__)
154 first = next(iter(sig.parameters))
155 self.assertEqual('self', first)
156
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500157 def test_0_field_compare(self):
158 # Ensure that order=False is the default.
159 @dataclass
160 class C0:
161 pass
162
163 @dataclass(order=False)
164 class C1:
165 pass
166
167 for cls in [C0, C1]:
168 with self.subTest(cls=cls):
169 self.assertEqual(cls(), cls())
170 for idx, fn in enumerate([lambda a, b: a < b,
171 lambda a, b: a <= b,
172 lambda a, b: a > b,
173 lambda a, b: a >= b]):
174 with self.subTest(idx=idx):
175 with self.assertRaisesRegex(TypeError,
176 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
177 fn(cls(), cls())
178
179 @dataclass(order=True)
180 class C:
181 pass
182 self.assertLessEqual(C(), C())
183 self.assertGreaterEqual(C(), C())
184
185 def test_1_field_compare(self):
186 # Ensure that order=False is the default.
187 @dataclass
188 class C0:
189 x: int
190
191 @dataclass(order=False)
192 class C1:
193 x: int
194
195 for cls in [C0, C1]:
196 with self.subTest(cls=cls):
197 self.assertEqual(cls(1), cls(1))
198 self.assertNotEqual(cls(0), cls(1))
199 for idx, fn in enumerate([lambda a, b: a < b,
200 lambda a, b: a <= b,
201 lambda a, b: a > b,
202 lambda a, b: a >= b]):
203 with self.subTest(idx=idx):
204 with self.assertRaisesRegex(TypeError,
205 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
206 fn(cls(0), cls(0))
207
208 @dataclass(order=True)
209 class C:
210 x: int
211 self.assertLess(C(0), C(1))
212 self.assertLessEqual(C(0), C(1))
213 self.assertLessEqual(C(1), C(1))
214 self.assertGreater(C(1), C(0))
215 self.assertGreaterEqual(C(1), C(0))
216 self.assertGreaterEqual(C(1), C(1))
217
218 def test_simple_compare(self):
219 # Ensure that order=False is the default.
220 @dataclass
221 class C0:
222 x: int
223 y: int
224
225 @dataclass(order=False)
226 class C1:
227 x: int
228 y: int
229
230 for cls in [C0, C1]:
231 with self.subTest(cls=cls):
232 self.assertEqual(cls(0, 0), cls(0, 0))
233 self.assertEqual(cls(1, 2), cls(1, 2))
234 self.assertNotEqual(cls(1, 0), cls(0, 0))
235 self.assertNotEqual(cls(1, 0), cls(1, 1))
236 for idx, fn in enumerate([lambda a, b: a < b,
237 lambda a, b: a <= b,
238 lambda a, b: a > b,
239 lambda a, b: a >= b]):
240 with self.subTest(idx=idx):
241 with self.assertRaisesRegex(TypeError,
242 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
243 fn(cls(0, 0), cls(0, 0))
244
245 @dataclass(order=True)
246 class C:
247 x: int
248 y: int
249
250 for idx, fn in enumerate([lambda a, b: a == b,
251 lambda a, b: a <= b,
252 lambda a, b: a >= b]):
253 with self.subTest(idx=idx):
254 self.assertTrue(fn(C(0, 0), C(0, 0)))
255
256 for idx, fn in enumerate([lambda a, b: a < b,
257 lambda a, b: a <= b,
258 lambda a, b: a != b]):
259 with self.subTest(idx=idx):
260 self.assertTrue(fn(C(0, 0), C(0, 1)))
261 self.assertTrue(fn(C(0, 1), C(1, 0)))
262 self.assertTrue(fn(C(1, 0), C(1, 1)))
263
264 for idx, fn in enumerate([lambda a, b: a > b,
265 lambda a, b: a >= b,
266 lambda a, b: a != b]):
267 with self.subTest(idx=idx):
268 self.assertTrue(fn(C(0, 1), C(0, 0)))
269 self.assertTrue(fn(C(1, 0), C(0, 1)))
270 self.assertTrue(fn(C(1, 1), C(1, 0)))
271
272 def test_compare_subclasses(self):
273 # Comparisons fail for subclasses, even if no fields
274 # are added.
275 @dataclass
276 class B:
277 i: int
278
279 @dataclass
280 class C(B):
281 pass
282
283 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
284 (lambda a, b: a != b, True)]):
285 with self.subTest(idx=idx):
286 self.assertEqual(fn(B(0), C(0)), expected)
287
288 for idx, fn in enumerate([lambda a, b: a < b,
289 lambda a, b: a <= b,
290 lambda a, b: a > b,
291 lambda a, b: a >= b]):
292 with self.subTest(idx=idx):
293 with self.assertRaisesRegex(TypeError,
294 "not supported between instances of 'B' and 'C'"):
295 fn(B(0), C(0))
296
def test_0_field_hash(self):
    # With no fields, the generated hash equals that of an empty tuple.
    @dataclass(hash=True)
    class C:
        pass

    empty_tuple_hash = hash(())
    self.assertEqual(hash(C()), empty_tuple_hash)
302
def test_1_field_hash(self):
    # With one field, the generated hash equals that of a 1-tuple
    # holding the field value.
    @dataclass(hash=True)
    class C:
        x: int

    for value in (4, 42):
        self.assertEqual(hash(C(value)), hash((value,)))
309
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500310 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500311 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500312 for (eq, order, result ) in [
313 (False, False, 'neither'),
314 (False, True, 'exception'),
315 (True, False, 'eq_only'),
316 (True, True, 'both'),
317 ]:
318 with self.subTest(eq=eq, order=order):
319 if result == 'exception':
320 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
321 @dataclass(eq=eq, order=order)
322 class C:
323 pass
324 else:
325 @dataclass(eq=eq, order=order)
326 class C:
327 pass
328
329 if result == 'neither':
330 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500331 self.assertNotIn('__lt__', C.__dict__)
332 self.assertNotIn('__le__', C.__dict__)
333 self.assertNotIn('__gt__', C.__dict__)
334 self.assertNotIn('__ge__', C.__dict__)
335 elif result == 'both':
336 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500337 self.assertIn('__lt__', C.__dict__)
338 self.assertIn('__le__', C.__dict__)
339 self.assertIn('__gt__', C.__dict__)
340 self.assertIn('__ge__', C.__dict__)
341 elif result == 'eq_only':
342 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500343 self.assertNotIn('__lt__', C.__dict__)
344 self.assertNotIn('__le__', C.__dict__)
345 self.assertNotIn('__gt__', C.__dict__)
346 self.assertNotIn('__ge__', C.__dict__)
347 else:
348 assert False, f'unknown result {result!r}'
349
350 def test_field_no_default(self):
351 @dataclass
352 class C:
353 x: int = field()
354
355 self.assertEqual(C(5).x, 5)
356
357 with self.assertRaisesRegex(TypeError,
358 r"__init__\(\) missing 1 required "
359 "positional argument: 'x'"):
360 C()
361
362 def test_field_default(self):
363 default = object()
364 @dataclass
365 class C:
366 x: object = field(default=default)
367
368 self.assertIs(C.x, default)
369 c = C(10)
370 self.assertEqual(c.x, 10)
371
372 # If we delete the instance attribute, we should then see the
373 # class attribute.
374 del c.x
375 self.assertIs(c.x, default)
376
377 self.assertIs(C().x, default)
378
def test_not_in_repr(self):
    # Fields declared with repr=False are left out of the generated
    # __repr__ but remain required by __init__.
    @dataclass
    class C:
        x: int = field(repr=False)

    self.assertRaises(TypeError, C)
    only_hidden = C(10)
    self.assertEqual(repr(only_hidden), 'TestCase.test_not_in_repr.<locals>.C()')

    @dataclass
    class C:
        x: int = field(repr=False)
        y: int

    hidden_and_shown = C(10, 20)
    self.assertEqual(repr(hidden_and_shown), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
394
395 def test_not_in_compare(self):
396 @dataclass
397 class C:
398 x: int = 0
399 y: int = field(compare=False, default=4)
400
401 self.assertEqual(C(), C(0, 20))
402 self.assertEqual(C(1, 10), C(1, 20))
403 self.assertNotEqual(C(3), C(4, 10))
404 self.assertNotEqual(C(3, 10), C(4, 10))
405
def test_hash_field_rules(self):
    # Test all 6 cases of:
    #  hash=True/False/None
    #  compare=True/False
    rules = [
        (True,  False, 'field' ),
        (True,  True,  'field' ),
        (False, False, 'absent'),
        (False, True,  'absent'),
        (None,  False, 'absent'),
        (None,  True,  'field' ),
    ]
    for hash_val, compare, result in rules:
        with self.subTest(hash_val=hash_val, compare=compare):
            @dataclass(hash=True)
            class C:
                x: int = field(compare=compare, hash=hash_val, default=5)

            if result == 'field':
                # __hash__ contains the field.
                expected = hash((5,))
            elif result == 'absent':
                # The field is not present in the hash.
                expected = hash(())
            else:
                assert False, f'unknown result {result!r}'

            self.assertEqual(C(5).__hash__(), expected)
431
432 def test_init_false_no_default(self):
433 # If init=False and no default value, then the field won't be
434 # present in the instance.
435 @dataclass
436 class C:
437 x: int = field(init=False)
438
439 self.assertNotIn('x', C().__dict__)
440
441 @dataclass
442 class C:
443 x: int
444 y: int = 0
445 z: int = field(init=False)
446 t: int = 10
447
448 self.assertNotIn('z', C(0).__dict__)
449 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
450
451 def test_class_marker(self):
452 @dataclass
453 class C:
454 x: int
455 y: str = field(init=False, default=None)
456 z: str = field(repr=False)
457
458 the_fields = fields(C)
459 # the_fields is a tuple of 3 items, each value
460 # is in __annotations__.
461 self.assertIsInstance(the_fields, tuple)
462 for f in the_fields:
463 self.assertIs(type(f), Field)
464 self.assertIn(f.name, C.__annotations__)
465
466 self.assertEqual(len(the_fields), 3)
467
468 self.assertEqual(the_fields[0].name, 'x')
469 self.assertEqual(the_fields[0].type, int)
470 self.assertFalse(hasattr(C, 'x'))
471 self.assertTrue (the_fields[0].init)
472 self.assertTrue (the_fields[0].repr)
473 self.assertEqual(the_fields[1].name, 'y')
474 self.assertEqual(the_fields[1].type, str)
475 self.assertIsNone(getattr(C, 'y'))
476 self.assertFalse(the_fields[1].init)
477 self.assertTrue (the_fields[1].repr)
478 self.assertEqual(the_fields[2].name, 'z')
479 self.assertEqual(the_fields[2].type, str)
480 self.assertFalse(hasattr(C, 'z'))
481 self.assertTrue (the_fields[2].init)
482 self.assertFalse(the_fields[2].repr)
483
484 def test_field_order(self):
485 @dataclass
486 class B:
487 a: str = 'B:a'
488 b: str = 'B:b'
489 c: str = 'B:c'
490
491 @dataclass
492 class C(B):
493 b: str = 'C:b'
494
495 self.assertEqual([(f.name, f.default) for f in fields(C)],
496 [('a', 'B:a'),
497 ('b', 'C:b'),
498 ('c', 'B:c')])
499
500 @dataclass
501 class D(B):
502 c: str = 'D:c'
503
504 self.assertEqual([(f.name, f.default) for f in fields(D)],
505 [('a', 'B:a'),
506 ('b', 'B:b'),
507 ('c', 'D:c')])
508
509 @dataclass
510 class E(D):
511 a: str = 'E:a'
512 d: str = 'E:d'
513
514 self.assertEqual([(f.name, f.default) for f in fields(E)],
515 [('a', 'E:a'),
516 ('b', 'B:b'),
517 ('c', 'D:c'),
518 ('d', 'E:d')])
519
520 def test_class_attrs(self):
521 # We only have a class attribute if a default value is
522 # specified, either directly or via a field with a default.
523 default = object()
524 @dataclass
525 class C:
526 x: int
527 y: int = field(repr=False)
528 z: object = default
529 t: int = field(default=100)
530
531 self.assertFalse(hasattr(C, 'x'))
532 self.assertFalse(hasattr(C, 'y'))
533 self.assertIs (C.z, default)
534 self.assertEqual(C.t, 100)
535
536 def test_disallowed_mutable_defaults(self):
537 # For the known types, don't allow mutable default values.
538 for typ, empty, non_empty in [(list, [], [1]),
539 (dict, {}, {0:1}),
540 (set, set(), set([1])),
541 ]:
542 with self.subTest(typ=typ):
543 # Can't use a zero-length value.
544 with self.assertRaisesRegex(ValueError,
545 f'mutable default {typ} for field '
546 'x is not allowed'):
547 @dataclass
548 class Point:
549 x: typ = empty
550
551
552 # Nor a non-zero-length value
553 with self.assertRaisesRegex(ValueError,
554 f'mutable default {typ} for field '
555 'y is not allowed'):
556 @dataclass
557 class Point:
558 y: typ = non_empty
559
560 # Check subtypes also fail.
561 class Subclass(typ): pass
562
563 with self.assertRaisesRegex(ValueError,
564 f"mutable default .*Subclass'>"
565 ' for field z is not allowed'
566 ):
567 @dataclass
568 class Point:
569 z: typ = Subclass()
570
571 # Because this is a ClassVar, it can be mutable.
572 @dataclass
573 class C:
574 z: ClassVar[typ] = typ()
575
576 # Because this is a ClassVar, it can be mutable.
577 @dataclass
578 class C:
579 x: ClassVar[typ] = Subclass()
580
581
582 def test_deliberately_mutable_defaults(self):
583 # If a mutable default isn't in the known list of
584 # (list, dict, set), then it's okay.
585 class Mutable:
586 def __init__(self):
587 self.l = []
588
589 @dataclass
590 class C:
591 x: Mutable
592
593 # These 2 instances will share this value of x.
594 lst = Mutable()
595 o1 = C(lst)
596 o2 = C(lst)
597 self.assertEqual(o1, o2)
598 o1.x.l.extend([1, 2])
599 self.assertEqual(o1, o2)
600 self.assertEqual(o1.x.l, [1, 2])
601 self.assertIs(o1.x, o2.x)
602
603 def test_no_options(self):
604 # call with dataclass()
605 @dataclass()
606 class C:
607 x: int
608
609 self.assertEqual(C(42).x, 42)
610
611 def test_not_tuple(self):
612 # Make sure we can't be compared to a tuple.
613 @dataclass
614 class Point:
615 x: int
616 y: int
617 self.assertNotEqual(Point(1, 2), (1, 2))
618
619 # And that we can't compare to another unrelated dataclass
620 @dataclass
621 class C:
622 x: int
623 y: int
624 self.assertNotEqual(Point(1, 3), C(1, 3))
625
def test_frozen(self):
    # Assigning to a field of a frozen dataclass raises
    # FrozenInstanceError and leaves the value unchanged.
    @dataclass(frozen=True)
    class C:
        i: int

    frozen_obj = C(10)
    self.assertEqual(frozen_obj.i, 10)
    with self.assertRaises(FrozenInstanceError):
        frozen_obj.i = 5
    self.assertEqual(frozen_obj.i, 10)

    # Check that a derived class is still frozen, even if not
    #  marked so.
    @dataclass
    class D(C):
        pass

    derived_obj = D(20)
    self.assertEqual(derived_obj.i, 20)
    with self.assertRaises(FrozenInstanceError):
        derived_obj.i = 5
    self.assertEqual(derived_obj.i, 20)
648
649 def test_not_tuple(self):
650 # Test that some of the problems with namedtuple don't happen
651 # here.
652 @dataclass
653 class Point3D:
654 x: int
655 y: int
656 z: int
657
658 @dataclass
659 class Date:
660 year: int
661 month: int
662 day: int
663
664 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
665 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
666
667 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200668 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500669 x, y, z = Point3D(4, 5, 6)
670
Eric V. Smith7c99e932018-01-28 19:18:55 -0500671 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500672 # equal.
673 @dataclass
674 class Point3Dv1:
675 x: int = 0
676 y: int = 0
677 z: int = 0
678 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
679
def test_function_annotations(self):
    # Some dummy class and instance to use as a default.
    class F:
        pass
    f = F()

    def validate_class(cls):
        # First, check __annotations__, even though they're not
        #  function annotations.
        class_annotations = {'i': int, 'j': str, 'k': F,
                             'l': float, 'z': complex}
        for name, typ in class_annotations.items():
            self.assertEqual(cls.__annotations__[name], typ)

        # Verify __init__.
        signature = inspect.signature(cls.__init__)
        # Check the return type, should be None
        self.assertIs(signature.return_annotation, None)

        # Check each parameter.
        params = iter(signature.parameters.values())
        # This is testing an internal name, and probably shouldn't be tested.
        self.assertEqual(next(params).name, 'self')

        # (name, annotation, whether the default is checked).  The
        # defaults of 'k' and 'l' are not tested, since they are set
        # to MISSING.
        init_params = [('i', int,   True),
                       ('j', str,   True),
                       ('k', F,     False),
                       ('l', float, False)]
        for name, annotation, check_default in init_params:
            param = next(params)
            self.assertEqual(param.name, name)
            self.assertIs(param.annotation, annotation)
            if check_default:
                self.assertEqual(param.default, inspect.Parameter.empty)
            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)

        # 'z' has init=False, so there is no further parameter.
        self.assertRaises(StopIteration, next, params)

    @dataclass
    class C:
        i: int
        j: str
        k: F = f
        l: float = field(default=None)
        z: complex = field(default=3+4j, init=False)

    validate_class(C)

    # Now repeat with __hash__.
    @dataclass(frozen=True, hash=True)
    class C:
        i: int
        j: str
        k: F = f
        l: float = field(default=None)
        z: complex = field(default=3+4j, init=False)

    validate_class(C)
749
Eric V. Smith03220fd2017-12-29 13:59:58 -0500750 def test_missing_default(self):
751 # Test that MISSING works the same as a default not being
752 # specified.
753 @dataclass
754 class C:
755 x: int=field(default=MISSING)
756 with self.assertRaisesRegex(TypeError,
757 r'__init__\(\) missing 1 required '
758 'positional argument'):
759 C()
760 self.assertNotIn('x', C.__dict__)
761
762 @dataclass
763 class D:
764 x: int
765 with self.assertRaisesRegex(TypeError,
766 r'__init__\(\) missing 1 required '
767 'positional argument'):
768 D()
769 self.assertNotIn('x', D.__dict__)
770
771 def test_missing_default_factory(self):
772 # Test that MISSING works the same as a default factory not
773 # being specified (which is really the same as a default not
774 # being specified, too).
775 @dataclass
776 class C:
777 x: int=field(default_factory=MISSING)
778 with self.assertRaisesRegex(TypeError,
779 r'__init__\(\) missing 1 required '
780 'positional argument'):
781 C()
782 self.assertNotIn('x', C.__dict__)
783
784 @dataclass
785 class D:
786 x: int=field(default=MISSING, default_factory=MISSING)
787 with self.assertRaisesRegex(TypeError,
788 r'__init__\(\) missing 1 required '
789 'positional argument'):
790 D()
791 self.assertNotIn('x', D.__dict__)
792
793 def test_missing_repr(self):
794 self.assertIn('MISSING_TYPE object', repr(MISSING))
795
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500796 def test_dont_include_other_annotations(self):
797 @dataclass
798 class C:
799 i: int
800 def foo(self) -> int:
801 return 4
802 @property
803 def bar(self) -> int:
804 return 5
805 self.assertEqual(list(C.__annotations__), ['i'])
806 self.assertEqual(C(10).foo(), 4)
807 self.assertEqual(C(10).bar, 5)
808
def test_post_init(self):
    # Just make sure it gets called
    @dataclass
    class AlwaysRaises:
        def __post_init__(self):
            raise CustomError()

    self.assertRaises(CustomError, AlwaysRaises)

    @dataclass
    class RaisesOnDefault:
        i: int = 10

        def __post_init__(self):
            if self.i == 10:
                raise CustomError()

    self.assertRaises(CustomError, RaisesOnDefault)
    # post-init gets called, but doesn't raise.  This is just
    #  checking that self is used correctly.
    RaisesOnDefault(5)

    # If there's not an __init__, then post-init won't get called.
    @dataclass(init=False)
    class NoGeneratedInit:
        def __post_init__(self):
            raise CustomError()

    # Creating the class won't raise
    NoGeneratedInit()

    @dataclass
    class Doubles:
        x: int = 0

        def __post_init__(self):
            self.x *= 2

    self.assertEqual(Doubles().x, 0)
    self.assertEqual(Doubles(2).x, 4)

    # Make sure that if we're frozen, post-init can't set
    #  attributes.
    @dataclass(frozen=True)
    class FrozenDoubles:
        x: int = 0

        def __post_init__(self):
            self.x *= 2

    self.assertRaises(FrozenInstanceError, FrozenDoubles)
855
def test_post_init_super(self):
    # Make sure super() post-init isn't called by default.
    class B:
        def __post_init__(self):
            raise CustomError()

    @dataclass
    class Overrides(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(Overrides().x, 5)

    # Now call super(), and it will raise
    @dataclass
    class Chains(B):
        def __post_init__(self):
            super().__post_init__()

    self.assertRaises(CustomError, Chains)

    # Make sure post-init is called, even if not defined in our
    #  class.
    @dataclass
    class Inherits(B):
        pass

    self.assertRaises(CustomError, Inherits)
886
887 def test_post_init_staticmethod(self):
888 flag = False
889 @dataclass
890 class C:
891 x: int
892 y: int
893 @staticmethod
894 def __post_init__():
895 nonlocal flag
896 flag = True
897
898 self.assertFalse(flag)
899 c = C(3, 4)
900 self.assertEqual((c.x, c.y), (3, 4))
901 self.assertTrue(flag)
902
903 def test_post_init_classmethod(self):
904 @dataclass
905 class C:
906 flag = False
907 x: int
908 y: int
909 @classmethod
910 def __post_init__(cls):
911 cls.flag = True
912
913 self.assertFalse(C.flag)
914 c = C(3, 4)
915 self.assertEqual((c.x, c.y), (3, 4))
916 self.assertTrue(C.flag)
917
def test_class_var(self):
    # Make sure ClassVars are ignored in __init__, __repr__, etc.
    @dataclass
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    def check_class_vars(instance, z, w, t):
        # ClassVars stay readable through an instance.
        self.assertEqual(instance.z, z)
        self.assertEqual(instance.w, w)
        self.assertEqual(instance.t, t)

    c = C(5)
    self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
    self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
    check_class_vars(c, 1000, 2000, 3000)

    # A change on the class is visible through an existing instance...
    C.z += 1
    self.assertEqual(c.z, 1001)

    # ...and through one created afterwards.
    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    check_class_vars(c, 1001, 2000, 3000)
942
943 def test_class_var_no_default(self):
944 # If a ClassVar has no default value, it should not be set on the class.
945 @dataclass
946 class C:
947 x: ClassVar[int]
948
949 self.assertNotIn('x', C.__dict__)
950
951 def test_class_var_default_factory(self):
952 # It makes no sense for a ClassVar to have a default factory. When
953 # would it be called? Call it yourself, since it's class-wide.
954 with self.assertRaisesRegex(TypeError,
955 'cannot have a default factory'):
956 @dataclass
957 class C:
958 x: ClassVar[int] = field(default_factory=int)
959
960 self.assertNotIn('x', C.__dict__)
961
962 def test_class_var_with_default(self):
963 # If a ClassVar has a default value, it should be set on the class.
964 @dataclass
965 class C:
966 x: ClassVar[int] = 10
967 self.assertEqual(C.x, 10)
968
969 @dataclass
970 class C:
971 x: ClassVar[int] = field(default=10)
972 self.assertEqual(C.x, 10)
973
def test_class_var_frozen(self):
    # Make sure ClassVars work even if we're frozen.
    @dataclass(frozen=True)
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    def check_class_vars(instance, z, w, t):
        # ClassVars stay readable through an instance.
        self.assertEqual(instance.z, z)
        self.assertEqual(instance.w, w)
        self.assertEqual(instance.t, t)

    c = C(5)
    self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
    self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
    check_class_vars(c, 1000, 2000, 3000)

    # We can still modify the ClassVar, it's only instances that are
    #  frozen.
    C.z += 1
    self.assertEqual(c.z, 1001)

    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    check_class_vars(c, 1001, 2000, 3000)
1000
1001 def test_init_var_no_default(self):
1002 # If an InitVar has no default value, it should not be set on the class.
1003 @dataclass
1004 class C:
1005 x: InitVar[int]
1006
1007 self.assertNotIn('x', C.__dict__)
1008
1009 def test_init_var_default_factory(self):
1010 # It makes no sense for an InitVar to have a default factory. When
1011 # would it be called? Call it yourself, since it's class-wide.
1012 with self.assertRaisesRegex(TypeError,
1013 'cannot have a default factory'):
1014 @dataclass
1015 class C:
1016 x: InitVar[int] = field(default_factory=int)
1017
1018 self.assertNotIn('x', C.__dict__)
1019
1020 def test_init_var_with_default(self):
1021 # If an InitVar has a default value, it should be set on the class.
1022 @dataclass
1023 class C:
1024 x: InitVar[int] = 10
1025 self.assertEqual(C.x, 10)
1026
1027 @dataclass
1028 class C:
1029 x: InitVar[int] = field(default=10)
1030 self.assertEqual(C.x, 10)
1031
1032 def test_init_var(self):
1033 @dataclass
1034 class C:
1035 x: int = None
1036 init_param: InitVar[int] = None
1037
1038 def __post_init__(self, init_param):
1039 if self.x is None:
1040 self.x = init_param*2
1041
1042 c = C(init_param=10)
1043 self.assertEqual(c.x, 20)
1044
1045 def test_init_var_inheritance(self):
1046 # Note that this deliberately tests that a dataclass need not
1047 # have a __post_init__ function if it has an InitVar field.
1048 # It could just be used in a derived class, as shown here.
1049 @dataclass
1050 class Base:
1051 x: int
1052 init_base: InitVar[int]
1053
1054 # We can instantiate by passing the InitVar, even though
1055 # it's not used.
1056 b = Base(0, 10)
1057 self.assertEqual(vars(b), {'x': 0})
1058
1059 @dataclass
1060 class C(Base):
1061 y: int
1062 init_derived: InitVar[int]
1063
1064 def __post_init__(self, init_base, init_derived):
1065 self.x = self.x + init_base
1066 self.y = self.y + init_derived
1067
1068 c = C(10, 11, 50, 51)
1069 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1070
def test_default_factory(self):
    # Test a factory that returns a new list.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIsNot(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test a factory that returns a shared list.
    shared_list = []

    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: shared_list)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIs(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test various other field flags.
    # repr
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)

    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash
    @dataclass(hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)

    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init (see also test_default_factory_with_no_init)
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)

    self.assertEqual(astuple(C()), ([],))

    # compare
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)

    self.assertEqual(C(), C([1]))
1127
1128 def test_default_factory_with_no_init(self):
1129 # We need a factory with a side effect.
1130 factory = Mock()
1131
1132 @dataclass
1133 class C:
1134 x: list = field(default_factory=factory, init=False)
1135
1136 # Make sure the default factory is called for each new instance.
1137 C().x
1138 self.assertEqual(factory.call_count, 1)
1139 C().x
1140 self.assertEqual(factory.call_count, 2)
1141
1142 def test_default_factory_not_called_if_value_given(self):
1143 # We need a factory that we can test if it's been called.
1144 factory = Mock()
1145
1146 @dataclass
1147 class C:
1148 x: int = field(default_factory=factory)
1149
1150 # Make sure that if a field has a default factory function,
1151 # it's not called if a value is specified.
1152 C().x
1153 self.assertEqual(factory.call_count, 1)
1154 self.assertEqual(C(10).x, 10)
1155 self.assertEqual(factory.call_count, 1)
1156 C().x
1157 self.assertEqual(factory.call_count, 2)
1158
def x_test_classvar_default_factory(self):
    # The 'x_' prefix keeps the unittest loader from collecting this
    # method.
    # XXX: it's an error for a ClassVar to have a factory function
    @dataclass
    class C:
        x: ClassVar[int] = field(default_factory=int)

    instance = C()
    self.assertIs(instance.x, int)
1166
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001167 def test_is_dataclass(self):
1168 class NotDataClass:
1169 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001170
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001171 self.assertFalse(is_dataclass(0))
1172 self.assertFalse(is_dataclass(int))
1173 self.assertFalse(is_dataclass(NotDataClass))
1174 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001175
1176 @dataclass
1177 class C:
1178 x: int
1179
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001180 @dataclass
1181 class D:
1182 d: C
1183 e: int
1184
1185 c = C(10)
1186 d = D(c, 4)
1187
1188 self.assertTrue(is_dataclass(C))
1189 self.assertTrue(is_dataclass(c))
1190 self.assertFalse(is_dataclass(c.x))
1191 self.assertTrue(is_dataclass(d.d))
1192 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001193
1194 def test_helper_fields_with_class_instance(self):
1195 # Check that we can call fields() on either a class or instance,
1196 # and get back the same thing.
1197 @dataclass
1198 class C:
1199 x: int
1200 y: float
1201
1202 self.assertEqual(fields(C), fields(C(0, 0.0)))
1203
1204 def test_helper_fields_exception(self):
1205 # Check that TypeError is raised if not passed a dataclass or
1206 # instance.
1207 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1208 fields(0)
1209
1210 class C: pass
1211 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1212 fields(C)
1213 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1214 fields(C())
1215
1216 def test_helper_asdict(self):
1217 # Basic tests for asdict(), it should return a new dictionary
1218 @dataclass
1219 class C:
1220 x: int
1221 y: int
1222 c = C(1, 2)
1223
1224 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1225 self.assertEqual(asdict(c), asdict(c))
1226 self.assertIsNot(asdict(c), asdict(c))
1227 c.x = 42
1228 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1229 self.assertIs(type(asdict(c)), dict)
1230
1231 def test_helper_asdict_raises_on_classes(self):
1232 # asdict() should raise on a class object
1233 @dataclass
1234 class C:
1235 x: int
1236 y: int
1237 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1238 asdict(C)
1239 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1240 asdict(int)
1241
1242 def test_helper_asdict_copy_values(self):
1243 @dataclass
1244 class C:
1245 x: int
1246 y: List[int] = field(default_factory=list)
1247 initial = []
1248 c = C(1, initial)
1249 d = asdict(c)
1250 self.assertEqual(d['y'], initial)
1251 self.assertIsNot(d['y'], initial)
1252 c = C(1)
1253 d = asdict(c)
1254 d['y'].append(1)
1255 self.assertEqual(c.y, [])
1256
1257 def test_helper_asdict_nested(self):
1258 @dataclass
1259 class UserId:
1260 token: int
1261 group: int
1262 @dataclass
1263 class User:
1264 name: str
1265 id: UserId
1266 u = User('Joe', UserId(123, 1))
1267 d = asdict(u)
1268 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1269 self.assertIsNot(asdict(u), asdict(u))
1270 u.id.group = 2
1271 self.assertEqual(asdict(u), {'name': 'Joe',
1272 'id': {'token': 123, 'group': 2}})
1273
1274 def test_helper_asdict_builtin_containers(self):
1275 @dataclass
1276 class User:
1277 name: str
1278 id: int
1279 @dataclass
1280 class GroupList:
1281 id: int
1282 users: List[User]
1283 @dataclass
1284 class GroupTuple:
1285 id: int
1286 users: Tuple[User, ...]
1287 @dataclass
1288 class GroupDict:
1289 id: int
1290 users: Dict[str, User]
1291 a = User('Alice', 1)
1292 b = User('Bob', 2)
1293 gl = GroupList(0, [a, b])
1294 gt = GroupTuple(0, (a, b))
1295 gd = GroupDict(0, {'first': a, 'second': b})
1296 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1297 {'name': 'Bob', 'id': 2}]})
1298 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1299 {'name': 'Bob', 'id': 2})})
1300 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1301 'second': {'name': 'Bob', 'id': 2}}})
1302
1303 def test_helper_asdict_builtin_containers(self):
1304 @dataclass
1305 class Child:
1306 d: object
1307
1308 @dataclass
1309 class Parent:
1310 child: Child
1311
1312 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1313 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1314
1315 def test_helper_asdict_factory(self):
1316 @dataclass
1317 class C:
1318 x: int
1319 y: int
1320 c = C(1, 2)
1321 d = asdict(c, dict_factory=OrderedDict)
1322 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1323 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1324 c.x = 42
1325 d = asdict(c, dict_factory=OrderedDict)
1326 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1327 self.assertIs(type(d), OrderedDict)
1328
1329 def test_helper_astuple(self):
1330 # Basic tests for astuple(), it should return a new tuple
1331 @dataclass
1332 class C:
1333 x: int
1334 y: int = 0
1335 c = C(1)
1336
1337 self.assertEqual(astuple(c), (1, 0))
1338 self.assertEqual(astuple(c), astuple(c))
1339 self.assertIsNot(astuple(c), astuple(c))
1340 c.y = 42
1341 self.assertEqual(astuple(c), (1, 42))
1342 self.assertIs(type(astuple(c)), tuple)
1343
1344 def test_helper_astuple_raises_on_classes(self):
1345 # astuple() should raise on a class object
1346 @dataclass
1347 class C:
1348 x: int
1349 y: int
1350 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1351 astuple(C)
1352 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1353 astuple(int)
1354
1355 def test_helper_astuple_copy_values(self):
1356 @dataclass
1357 class C:
1358 x: int
1359 y: List[int] = field(default_factory=list)
1360 initial = []
1361 c = C(1, initial)
1362 t = astuple(c)
1363 self.assertEqual(t[1], initial)
1364 self.assertIsNot(t[1], initial)
1365 c = C(1)
1366 t = astuple(c)
1367 t[1].append(1)
1368 self.assertEqual(c.y, [])
1369
1370 def test_helper_astuple_nested(self):
1371 @dataclass
1372 class UserId:
1373 token: int
1374 group: int
1375 @dataclass
1376 class User:
1377 name: str
1378 id: UserId
1379 u = User('Joe', UserId(123, 1))
1380 t = astuple(u)
1381 self.assertEqual(t, ('Joe', (123, 1)))
1382 self.assertIsNot(astuple(u), astuple(u))
1383 u.id.group = 2
1384 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1385
1386 def test_helper_astuple_builtin_containers(self):
1387 @dataclass
1388 class User:
1389 name: str
1390 id: int
1391 @dataclass
1392 class GroupList:
1393 id: int
1394 users: List[User]
1395 @dataclass
1396 class GroupTuple:
1397 id: int
1398 users: Tuple[User, ...]
1399 @dataclass
1400 class GroupDict:
1401 id: int
1402 users: Dict[str, User]
1403 a = User('Alice', 1)
1404 b = User('Bob', 2)
1405 gl = GroupList(0, [a, b])
1406 gt = GroupTuple(0, (a, b))
1407 gd = GroupDict(0, {'first': a, 'second': b})
1408 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1409 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1410 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1411
1412 def test_helper_astuple_builtin_containers(self):
1413 @dataclass
1414 class Child:
1415 d: object
1416
1417 @dataclass
1418 class Parent:
1419 child: Child
1420
1421 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1422 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1423
1424 def test_helper_astuple_factory(self):
1425 @dataclass
1426 class C:
1427 x: int
1428 y: int
1429 NT = namedtuple('NT', 'x y')
1430 def nt(lst):
1431 return NT(*lst)
1432 c = C(1, 2)
1433 t = astuple(c, tuple_factory=nt)
1434 self.assertEqual(t, NT(1, 2))
1435 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1436 c.x = 42
1437 t = astuple(c, tuple_factory=nt)
1438 self.assertEqual(t, NT(42, 2))
1439 self.assertIs(type(t), NT)
1440
1441 def test_dynamic_class_creation(self):
1442 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1443 }
1444
1445 # Create the class.
1446 cls = type('C', (), cls_dict)
1447
1448 # Make it a dataclass.
1449 cls1 = dataclass(cls)
1450
1451 self.assertEqual(cls1, cls)
1452 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1453
1454 def test_dynamic_class_creation_using_field(self):
1455 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1456 'y': field(default=5),
1457 }
1458
1459 # Create the class.
1460 cls = type('C', (), cls_dict)
1461
1462 # Make it a dataclass.
1463 cls1 = dataclass(cls)
1464
1465 self.assertEqual(cls1, cls)
1466 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1467
1468 def test_init_in_order(self):
1469 @dataclass
1470 class C:
1471 a: int
1472 b: int = field()
1473 c: list = field(default_factory=list, init=False)
1474 d: list = field(default_factory=list)
1475 e: int = field(default=4, init=False)
1476 f: int = 4
1477
1478 calls = []
1479 def setattr(self, name, value):
1480 calls.append((name, value))
1481
1482 C.__setattr__ = setattr
1483 c = C(0, 1)
1484 self.assertEqual(('a', 0), calls[0])
1485 self.assertEqual(('b', 1), calls[1])
1486 self.assertEqual(('c', []), calls[2])
1487 self.assertEqual(('d', []), calls[3])
1488 self.assertNotIn(('e', 4), calls)
1489 self.assertEqual(('f', 4), calls[4])
1490
1491 def test_items_in_dicts(self):
1492 @dataclass
1493 class C:
1494 a: int
1495 b: list = field(default_factory=list, init=False)
1496 c: list = field(default_factory=list)
1497 d: int = field(default=4, init=False)
1498 e: int = 0
1499
1500 c = C(0)
1501 # Class dict
1502 self.assertNotIn('a', C.__dict__)
1503 self.assertNotIn('b', C.__dict__)
1504 self.assertNotIn('c', C.__dict__)
1505 self.assertIn('d', C.__dict__)
1506 self.assertEqual(C.d, 4)
1507 self.assertIn('e', C.__dict__)
1508 self.assertEqual(C.e, 0)
1509 # Instance dict
1510 self.assertIn('a', c.__dict__)
1511 self.assertEqual(c.a, 0)
1512 self.assertIn('b', c.__dict__)
1513 self.assertEqual(c.b, [])
1514 self.assertIn('c', c.__dict__)
1515 self.assertEqual(c.c, [])
1516 self.assertNotIn('d', c.__dict__)
1517 self.assertIn('e', c.__dict__)
1518 self.assertEqual(c.e, 0)
1519
1520 def test_alternate_classmethod_constructor(self):
1521 # Since __post_init__ can't take params, use a classmethod
1522 # alternate constructor. This is mostly an example to show how
1523 # to use this technique.
1524 @dataclass
1525 class C:
1526 x: int
1527 @classmethod
1528 def from_file(cls, filename):
1529 # In a real example, create a new instance
1530 # and populate 'x' from contents of a file.
1531 value_in_file = 20
1532 return cls(value_in_file)
1533
1534 self.assertEqual(C.from_file('filename').x, 20)
1535
1536 def test_field_metadata_default(self):
1537 # Make sure the default metadata is read-only and of
1538 # zero length.
1539 @dataclass
1540 class C:
1541 i: int
1542
1543 self.assertFalse(fields(C)[0].metadata)
1544 self.assertEqual(len(fields(C)[0].metadata), 0)
1545 with self.assertRaisesRegex(TypeError,
1546 'does not support item assignment'):
1547 fields(C)[0].metadata['test'] = 3
1548
1549 def test_field_metadata_mapping(self):
1550 # Make sure only a mapping can be passed as metadata
1551 # zero length.
1552 with self.assertRaises(TypeError):
1553 @dataclass
1554 class C:
1555 i: int = field(metadata=0)
1556
1557 # Make sure an empty dict works
1558 @dataclass
1559 class C:
1560 i: int = field(metadata={})
1561 self.assertFalse(fields(C)[0].metadata)
1562 self.assertEqual(len(fields(C)[0].metadata), 0)
1563 with self.assertRaisesRegex(TypeError,
1564 'does not support item assignment'):
1565 fields(C)[0].metadata['test'] = 3
1566
1567 # Make sure a non-empty dict works.
1568 @dataclass
1569 class C:
1570 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1571 self.assertEqual(len(fields(C)[0].metadata), 3)
1572 self.assertEqual(fields(C)[0].metadata['test'], 10)
1573 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1574 self.assertEqual(fields(C)[0].metadata[3], 'three')
1575 with self.assertRaises(KeyError):
1576 # Non-existent key.
1577 fields(C)[0].metadata['baz']
1578 with self.assertRaisesRegex(TypeError,
1579 'does not support item assignment'):
1580 fields(C)[0].metadata['test'] = 3
1581
1582 def test_field_metadata_custom_mapping(self):
1583 # Try a custom mapping.
1584 class SimpleNameSpace:
1585 def __init__(self, **kw):
1586 self.__dict__.update(kw)
1587
1588 def __getitem__(self, item):
1589 if item == 'xyzzy':
1590 return 'plugh'
1591 return getattr(self, item)
1592
1593 def __len__(self):
1594 return self.__dict__.__len__()
1595
1596 @dataclass
1597 class C:
1598 i: int = field(metadata=SimpleNameSpace(a=10))
1599
1600 self.assertEqual(len(fields(C)[0].metadata), 1)
1601 self.assertEqual(fields(C)[0].metadata['a'], 10)
1602 with self.assertRaises(AttributeError):
1603 fields(C)[0].metadata['b']
1604 # Make sure we're still talking to our custom mapping.
1605 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1606
1607 def test_generic_dataclasses(self):
1608 T = TypeVar('T')
1609
1610 @dataclass
1611 class LabeledBox(Generic[T]):
1612 content: T
1613 label: str = '<unknown>'
1614
1615 box = LabeledBox(42)
1616 self.assertEqual(box.content, 42)
1617 self.assertEqual(box.label, '<unknown>')
1618
1619 # subscripting the resulting class should work, etc.
1620 Alias = List[LabeledBox[int]]
1621
1622 def test_generic_extending(self):
1623 S = TypeVar('S')
1624 T = TypeVar('T')
1625
1626 @dataclass
1627 class Base(Generic[T, S]):
1628 x: T
1629 y: S
1630
1631 @dataclass
1632 class DataDerived(Base[int, T]):
1633 new_field: str
1634 Alias = DataDerived[str]
1635 c = Alias(0, 'test1', 'test2')
1636 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1637
1638 class NonDataDerived(Base[int, T]):
1639 def new_method(self):
1640 return self.y
1641 Alias = NonDataDerived[float]
1642 c = Alias(10, 1.0)
1643 self.assertEqual(c.new_method(), 1.0)
1644
1645 def test_helper_replace(self):
1646 @dataclass(frozen=True)
1647 class C:
1648 x: int
1649 y: int
1650
1651 c = C(1, 2)
1652 c1 = replace(c, x=3)
1653 self.assertEqual(c1.x, 3)
1654 self.assertEqual(c1.y, 2)
1655
1656 def test_helper_replace_frozen(self):
1657 @dataclass(frozen=True)
1658 class C:
1659 x: int
1660 y: int
1661 z: int = field(init=False, default=10)
1662 t: int = field(init=False, default=100)
1663
1664 c = C(1, 2)
1665 c1 = replace(c, x=3)
1666 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1667 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1668
1669
1670 with self.assertRaisesRegex(ValueError, 'init=False'):
1671 replace(c, x=3, z=20, t=50)
1672 with self.assertRaisesRegex(ValueError, 'init=False'):
1673 replace(c, z=20)
1674 replace(c, x=3, z=20, t=50)
1675
1676 # Make sure the result is still frozen.
1677 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1678 c1.x = 3
1679
1680 # Make sure we can't replace an attribute that doesn't exist,
1681 # if we're also replacing one that does exist. Test this
1682 # here, because setting attributes on frozen instances is
1683 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001684 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001685 "keyword argument 'a'"):
1686 c1 = replace(c, x=20, a=5)
1687
1688 def test_helper_replace_invalid_field_name(self):
1689 @dataclass(frozen=True)
1690 class C:
1691 x: int
1692 y: int
1693
1694 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001695 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001696 "keyword argument 'z'"):
1697 c1 = replace(c, z=3)
1698
1699 def test_helper_replace_invalid_object(self):
1700 @dataclass(frozen=True)
1701 class C:
1702 x: int
1703 y: int
1704
1705 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1706 replace(C, x=3)
1707
1708 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1709 replace(0, x=3)
1710
1711 def test_helper_replace_no_init(self):
1712 @dataclass
1713 class C:
1714 x: int
1715 y: int = field(init=False, default=10)
1716
1717 c = C(1)
1718 c.y = 20
1719
1720 # Make sure y gets the default value.
1721 c1 = replace(c, x=5)
1722 self.assertEqual((c1.x, c1.y), (5, 10))
1723
1724 # Trying to replace y is an error.
1725 with self.assertRaisesRegex(ValueError, 'init=False'):
1726 replace(c, x=2, y=30)
1727 with self.assertRaisesRegex(ValueError, 'init=False'):
1728 replace(c, y=30)
1729
1730 def test_dataclassses_pickleable(self):
1731 global P, Q, R
1732 @dataclass
1733 class P:
1734 x: int
1735 y: int = 0
1736 @dataclass
1737 class Q:
1738 x: int
1739 y: int = field(default=0, init=False)
1740 @dataclass
1741 class R:
1742 x: int
1743 y: List[int] = field(default_factory=list)
1744 q = Q(1)
1745 q.y = 2
1746 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1747 for sample in samples:
1748 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1749 with self.subTest(sample=sample, proto=proto):
1750 new_sample = pickle.loads(pickle.dumps(sample, proto))
1751 self.assertEqual(sample.x, new_sample.x)
1752 self.assertEqual(sample.y, new_sample.y)
1753 self.assertIsNot(sample, new_sample)
1754 new_sample.x = 42
1755 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1756 self.assertEqual(new_sample.x, another_new_sample.x)
1757 self.assertEqual(sample.y, another_new_sample.y)
1758
1759 def test_helper_make_dataclass(self):
1760 C = make_dataclass('C',
1761 [('x', int),
1762 ('y', int, field(default=5))],
1763 namespace={'add_one': lambda self: self.x + 1})
1764 c = C(10)
1765 self.assertEqual((c.x, c.y), (10, 5))
1766 self.assertEqual(c.add_one(), 11)
1767
1768
1769 def test_helper_make_dataclass_no_mutate_namespace(self):
1770 # Make sure a provided namespace isn't mutated.
1771 ns = {}
1772 C = make_dataclass('C',
1773 [('x', int),
1774 ('y', int, field(default=5))],
1775 namespace=ns)
1776 self.assertEqual(ns, {})
1777
1778 def test_helper_make_dataclass_base(self):
1779 class Base1:
1780 pass
1781 class Base2:
1782 pass
1783 C = make_dataclass('C',
1784 [('x', int)],
1785 bases=(Base1, Base2))
1786 c = C(2)
1787 self.assertIsInstance(c, C)
1788 self.assertIsInstance(c, Base1)
1789 self.assertIsInstance(c, Base2)
1790
1791 def test_helper_make_dataclass_base_dataclass(self):
1792 @dataclass
1793 class Base1:
1794 x: int
1795 class Base2:
1796 pass
1797 C = make_dataclass('C',
1798 [('y', int)],
1799 bases=(Base1, Base2))
1800 with self.assertRaisesRegex(TypeError, 'required positional'):
1801 c = C(2)
1802 c = C(1, 2)
1803 self.assertIsInstance(c, C)
1804 self.assertIsInstance(c, Base1)
1805 self.assertIsInstance(c, Base2)
1806
1807 self.assertEqual((c.x, c.y), (1, 2))
1808
1809 def test_helper_make_dataclass_init_var(self):
1810 def post_init(self, y):
1811 self.x *= y
1812
1813 C = make_dataclass('C',
1814 [('x', int),
1815 ('y', InitVar[int]),
1816 ],
1817 namespace={'__post_init__': post_init},
1818 )
1819 c = C(2, 3)
1820 self.assertEqual(vars(c), {'x': 6})
1821 self.assertEqual(len(fields(c)), 1)
1822
1823 def test_helper_make_dataclass_class_var(self):
1824 C = make_dataclass('C',
1825 [('x', int),
1826 ('y', ClassVar[int], 10),
1827 ('z', ClassVar[int], field(default=20)),
1828 ])
1829 c = C(1)
1830 self.assertEqual(vars(c), {'x': 1})
1831 self.assertEqual(len(fields(c)), 1)
1832 self.assertEqual(C.y, 10)
1833 self.assertEqual(C.z, 20)
1834
Eric V. Smithd80b4432018-01-06 17:09:58 -05001835 def test_helper_make_dataclass_other_params(self):
1836 C = make_dataclass('C',
1837 [('x', int),
1838 ('y', ClassVar[int], 10),
1839 ('z', ClassVar[int], field(default=20)),
1840 ],
1841 init=False)
1842 # Make sure we have a repr, but no init.
1843 self.assertNotIn('__init__', vars(C))
1844 self.assertIn('__repr__', vars(C))
1845
1846 # Make sure random other params don't work.
1847 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1848 C = make_dataclass('C',
1849 [],
1850 xxinit=False)
1851
Eric V. Smithed7d4292018-01-06 16:14:03 -05001852 def test_helper_make_dataclass_no_types(self):
1853 C = make_dataclass('Point', ['x', 'y', 'z'])
1854 c = C(1, 2, 3)
1855 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1856 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1857 'y': 'typing.Any',
1858 'z': 'typing.Any'})
1859
1860 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1861 c = C(1, 2, 3)
1862 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1863 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1864 'y': int,
1865 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001866
Eric V. Smithea8fc522018-01-27 19:07:40 -05001867
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001868class TestDocString(unittest.TestCase):
1869 def assertDocStrEqual(self, a, b):
1870 # Because 3.6 and 3.7 differ in how inspect.signature work
1871 # (see bpo #32108), for the time being just compare them with
1872 # whitespace stripped.
1873 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1874
1875 def test_existing_docstring_not_overridden(self):
1876 @dataclass
1877 class C:
1878 """Lorem ipsum"""
1879 x: int
1880
1881 self.assertEqual(C.__doc__, "Lorem ipsum")
1882
1883 def test_docstring_no_fields(self):
1884 @dataclass
1885 class C:
1886 pass
1887
1888 self.assertDocStrEqual(C.__doc__, "C()")
1889
1890 def test_docstring_one_field(self):
1891 @dataclass
1892 class C:
1893 x: int
1894
1895 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1896
1897 def test_docstring_two_fields(self):
1898 @dataclass
1899 class C:
1900 x: int
1901 y: int
1902
1903 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1904
1905 def test_docstring_three_fields(self):
1906 @dataclass
1907 class C:
1908 x: int
1909 y: int
1910 z: str
1911
1912 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1913
1914 def test_docstring_one_field_with_default(self):
1915 @dataclass
1916 class C:
1917 x: int = 3
1918
1919 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1920
1921 def test_docstring_one_field_with_default_none(self):
1922 @dataclass
1923 class C:
1924 x: Union[int, type(None)] = None
1925
1926 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1927
1928 def test_docstring_list_field(self):
1929 @dataclass
1930 class C:
1931 x: List[int]
1932
1933 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1934
1935 def test_docstring_list_field_with_default_factory(self):
1936 @dataclass
1937 class C:
1938 x: List[int] = field(default_factory=list)
1939
1940 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1941
1942 def test_docstring_deque_field(self):
1943 @dataclass
1944 class C:
1945 x: deque
1946
1947 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1948
1949 def test_docstring_deque_field_with_default_factory(self):
1950 @dataclass
1951 class C:
1952 x: deque = field(default_factory=deque)
1953
1954 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1955
1956
class TestInit(unittest.TestCase):
    # Tests for how dataclass() interacts with __init__ methods that
    # are inherited or written by hand.

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        # The issue is that we can't override __init__ in our class,
        # but it should be okay to add __init__ to us if our base has
        # an __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ replaced B's, so 'z' was never set.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator has to be applied with '@'.  A bare
        # 'dataclass(init=False)' expression statement only builds a
        # decorator and throws it away, leaving a plain class under
        # test.

        # No __init__ anywhere: the class-level default is visible.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # A hand-written __init__ is kept and used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2021
2022
class TestRepr(unittest.TestCase):
    # Tests for the generated __repr__ and for repr=False.

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first, and the class's qualified name
        # is used.
        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redefining an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Dataclasses nested inside another class.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
        # The default object repr starts with the defining module's
        # name.  Use __name__ instead of a hard-coded module name so
        # the check also holds when this file runs as '__main__' or is
        # imported under another name.
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2092
2093
class TestFrozen(unittest.TestCase):
    def test_overwriting_frozen(self):
        # frozen=True relies on __setattr__ and __delattr__, so a
        # class defining either one itself must be rejected.
        setattr_message = 'Cannot overwrite attribute __setattr__'
        with self.assertRaisesRegex(TypeError, setattr_message):
            @dataclass(frozen=True)
            class C:
                x: int

                def __setattr__(self):
                    pass

        delattr_message = 'Cannot overwrite attribute __delattr__'
        with self.assertRaisesRegex(TypeError, delattr_message):
            @dataclass(frozen=True)
            class C:
                x: int

                def __delattr__(self):
                    pass

        # With frozen=False a hand-written __setattr__ is kept, and
        # the generated __init__ assigns through it.
        @dataclass(frozen=False)
        class C:
            x: int

            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        doubled = C(10)
        self.assertEqual(doubled.x, 20)
2119
2120
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no __eq__ in the class: two separately built
        # instances differ, and an instance equals itself.
        @dataclass(eq=False)
        class C:
            x: int

        self.assertNotEqual(C(0), C(0))
        single = C(3)
        self.assertEqual(single, single)

        # eq=False with a hand-written __eq__: that method is used.
        @dataclass(eq=False)
        class C:
            x: int

            def __eq__(self, other):
                return other == 10

        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A hand-written __eq__ always wins, whatever eq= says.

        # eq left at its default.
        @dataclass
        class C:
            x: int

            def __eq__(self, other):
                return other == 3

        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        # eq=True given explicitly.
        @dataclass(eq=True)
        class C:
            x: int

            def __eq__(self, other):
                return other == 4

        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        # eq=False given explicitly.
        @dataclass(eq=False)
        class C:
            x: int

            def __eq__(self, other):
                return other == 5

        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)
2166
2167
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering must be usable on top of a
        # dataclass that supplies only __lt__.
        @total_ordering
        @dataclass
        class C:
            x: int

            def __lt__(self, other):
                # Deliberately "backward", so a passing assertion
                # proves this method is the one being called.
                return self.x >= other

        zero = C(0)
        self.assertLess(zero, -1)
        self.assertLessEqual(zero, -1)
        self.assertGreater(zero, 1)
        self.assertGreaterEqual(zero, 1)

    def test_no_order(self):
        # With order=False none of the ordering methods are added.
        @dataclass(order=False)
        class C:
            x: int

        for method_name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(method_name, C.__dict__)

        # A hand-written __lt__ is left alone, and the other three
        # are still not added.
        @dataclass(order=False)
        class C:
            x: int

            def __lt__(self, other):
                return False

        for method_name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(method_name, C.__dict__)

    def test_overwriting_order(self):
        # order=True must refuse to replace a hand-written comparison
        # method, and the error must mention functools.total_ordering.
        advice = '.*using functools.total_ordering'

        with self.assertRaisesRegex(
                TypeError, 'Cannot overwrite attribute __lt__' + advice):
            @dataclass(order=True)
            class C:
                x: int

                def __lt__(self):
                    pass

        with self.assertRaisesRegex(
                TypeError, 'Cannot overwrite attribute __le__' + advice):
            @dataclass(order=True)
            class C:
                x: int

                def __le__(self):
                    pass

        with self.assertRaisesRegex(
                TypeError, 'Cannot overwrite attribute __gt__' + advice):
            @dataclass(order=True)
            class C:
                x: int

                def __gt__(self):
                    pass

        with self.assertRaisesRegex(
                TypeError, 'Cannot overwrite attribute __ge__' + advice):
            @dataclass(order=True)
            class C:
                x: int

                def __ge__(self):
                    pass
2243
2244class TestHash(unittest.TestCase):
2245 def test_hash(self):
2246 @dataclass(hash=True)
2247 class C:
2248 x: int
2249 y: str
2250 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
2251
2252 def test_hash_false(self):
2253 @dataclass(hash=False)
2254 class C:
2255 x: int
2256 y: str
2257 self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
2258
2259 def test_hash_none(self):
2260 @dataclass(hash=None)
2261 class C:
2262 x: int
2263 with self.assertRaisesRegex(TypeError,
2264 "unhashable type: 'C'"):
2265 hash(C(1))
2266
2267 def test_hash_rules(self):
2268 def non_bool(value):
2269 # Map to something else that's True, but not a bool.
2270 if value is None:
2271 return None
2272 if value:
2273 return (3,)
2274 return 0
2275
2276 def test(case, hash, eq, frozen, with_hash, result):
2277 with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
2278 if with_hash:
2279 @dataclass(hash=hash, eq=eq, frozen=frozen)
2280 class C:
2281 def __hash__(self):
2282 return 0
2283 else:
2284 @dataclass(hash=hash, eq=eq, frozen=frozen)
2285 class C:
2286 pass
2287
2288 # See if the result matches what's expected.
2289 if result in ('fn', 'fn-x'):
2290 # __hash__ contains the function we generated.
2291 self.assertIn('__hash__', C.__dict__)
2292 self.assertIsNotNone(C.__dict__['__hash__'])
2293
2294 if result == 'fn-x':
2295 # This is the "auto-hash test" case. We
2296 # should overwrite __hash__ iff there's an
2297 # __eq__ and if __hash__=None.
2298
2299 # There are two ways of getting __hash__=None:
2300 # explicitely, and by defining __eq__. If
2301 # __eq__ is defined, python will add __hash__
2302 # when the class is created.
2303 @dataclass(hash=hash, eq=eq, frozen=frozen)
2304 class C:
2305 def __eq__(self, other): pass
2306 __hash__ = None
2307
2308 # Hash should be overwritten (non-None).
2309 self.assertIsNotNone(C.__dict__['__hash__'])
2310
2311 # Same test as above, but we don't provide
2312 # __hash__, it will implicitely set to None.
2313 @dataclass(hash=hash, eq=eq, frozen=frozen)
2314 class C:
2315 def __eq__(self, other): pass
2316
2317 # Hash should be overwritten (non-None).
2318 self.assertIsNotNone(C.__dict__['__hash__'])
2319
2320 elif result == '':
2321 # __hash__ is not present in our class.
2322 if not with_hash:
2323 self.assertNotIn('__hash__', C.__dict__)
2324 elif result == 'none':
2325 # __hash__ is set to None.
2326 self.assertIn('__hash__', C.__dict__)
2327 self.assertIsNone(C.__dict__['__hash__'])
2328 else:
2329 assert False, f'unknown result {result!r}'
2330
2331 # There are 12 cases of:
2332 # hash=True/False/None
2333 # eq=True/False
2334 # frozen=True/False
2335 # And for each of these, a different result if
2336 # __hash__ is defined or not.
2337 for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
2338 (None, False, False, '', ''),
2339 (None, False, True, '', ''),
2340 (None, True, False, 'none', ''),
2341 (None, True, True, 'fn', 'fn-x'),
2342 (False, False, False, '', ''),
2343 (False, False, True, '', ''),
2344 (False, True, False, '', ''),
2345 (False, True, True, '', ''),
2346 (True, False, False, 'fn', 'fn-x'),
2347 (True, False, True, 'fn', 'fn-x'),
2348 (True, True, False, 'fn', 'fn-x'),
2349 (True, True, True, 'fn', 'fn-x'),
2350 ], 1):
2351 test(case, hash, eq, frozen, False, result_no)
2352 test(case, hash, eq, frozen, True, result_yes)
2353
2354 # Test non-bool truth values, too. This is just to
2355 # make sure the data-driven table in the decorator
2356 # handles non-bool values.
2357 test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
2358 test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)
2359
2360
2361 def test_eq_only(self):
2362 # If a class defines __eq__, __hash__ is automatically added
2363 # and set to None. This is normal Python behavior, not
2364 # related to dataclasses. Make sure we don't interfere with
2365 # that (see bpo=32546).
2366
2367 @dataclass
2368 class C:
2369 i: int
2370 def __eq__(self, other):
2371 return self.i == other.i
2372 self.assertEqual(C(1), C(1))
2373 self.assertNotEqual(C(1), C(4))
2374
2375 # And make sure things work in this case if we specify
2376 # hash=True.
2377 @dataclass(hash=True)
2378 class C:
2379 i: int
2380 def __eq__(self, other):
2381 return self.i == other.i
2382 self.assertEqual(C(1), C(1.0))
2383 self.assertEqual(hash(C(1)), hash(C(1.0)))
2384
2385 # And check that the classes __eq__ is being used, despite
2386 # specifying eq=True.
2387 @dataclass(hash=True, eq=True)
2388 class C:
2389 i: int
2390 def __eq__(self, other):
2391 return self.i == 3 and self.i == other.i
2392 self.assertEqual(C(3), C(3))
2393 self.assertNotEqual(C(1), C(1))
2394 self.assertEqual(hash(C(1)), hash(C(1.0)))
2395
Miss Islington (bot)b6b66692018-02-25 08:56:30 -08002396 def test_hash_no_args(self):
2397 # Test dataclasses with no hash= argument. This exists to
2398 # make sure that when hash is changed, the default hashability
2399 # keeps working.
2400
2401 class Base:
2402 def __hash__(self):
2403 return 301
2404
2405 # If frozen or eq is None, then use the default value (do not
2406 # specify any value in the deecorator).
2407 for frozen, eq, base, expected in [
2408 (None, None, object, 'unhashable'),
2409 (None, None, Base, 'unhashable'),
2410 (None, False, object, 'object'),
2411 (None, False, Base, 'base'),
2412 (None, True, object, 'unhashable'),
2413 (None, True, Base, 'unhashable'),
2414 (False, None, object, 'unhashable'),
2415 (False, None, Base, 'unhashable'),
2416 (False, False, object, 'object'),
2417 (False, False, Base, 'base'),
2418 (False, True, object, 'unhashable'),
2419 (False, True, Base, 'unhashable'),
2420 (True, None, object, 'tuple'),
2421 (True, None, Base, 'tuple'),
2422 (True, False, object, 'object'),
2423 (True, False, Base, 'base'),
2424 (True, True, object, 'tuple'),
2425 (True, True, Base, 'tuple'),
2426 ]:
2427
2428 with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
2429 # First, create the class.
2430 if frozen is None and eq is None:
2431 @dataclass
2432 class C(base):
2433 i: int
2434 elif frozen is None:
2435 @dataclass(eq=eq)
2436 class C(base):
2437 i: int
2438 elif eq is None:
2439 @dataclass(frozen=frozen)
2440 class C(base):
2441 i: int
2442 else:
2443 @dataclass(frozen=frozen, eq=eq)
2444 class C(base):
2445 i: int
2446
2447 # Now, make sure it hashes as expected.
2448 if expected == 'unhashable':
2449 c = C(10)
2450 with self.assertRaisesRegex(TypeError, 'unhashable type'):
2451 hash(c)
2452
2453 elif expected == 'base':
2454 self.assertEqual(hash(C(10)), 301)
2455
2456 elif expected == 'object':
2457 # I'm not sure what test to use here. object's
2458 # hash isn't based on id(), so calling hash()
2459 # won't tell us much. So, just check the function
2460 # used is object's.
2461 self.assertIs(C.__hash__, object.__hash__)
2462
2463 elif expected == 'tuple':
2464 self.assertEqual(hash(C(42)), hash((42,)))
2465
2466 else:
2467 assert False, f'unknown value for expected={expected!r}'
2468
Eric V. Smithea8fc522018-01-27 19:07:40 -05002469
# Run every test in this module when the file is executed as a script.
if __name__ == '__main__':
    unittest.main()