blob: 736bc490867b0c01a9cd4e56429c8d44258d8f5d [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
Eric V. Smithe7ba0132018-01-06 12:41:53 -05003 make_dataclass, replace, InitVar, Field, MISSING, is_dataclass,
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050012from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013
# Arbitrary exception type for the tests to raise from user-supplied code
# (such as __post_init__) and then catch.
class CustomError(Exception):
    pass
16
17class TestCase(unittest.TestCase):
18 def test_no_fields(self):
19 @dataclass
20 class C:
21 pass
22
23 o = C()
24 self.assertEqual(len(fields(C)), 0)
25
26 def test_one_field_no_default(self):
27 @dataclass
28 class C:
29 x: int
30
31 o = C(42)
32 self.assertEqual(o.x, 42)
33
34 def test_named_init_params(self):
35 @dataclass
36 class C:
37 x: int
38
39 o = C(x=32)
40 self.assertEqual(o.x, 32)
41
42 def test_two_fields_one_default(self):
43 @dataclass
44 class C:
45 x: int
46 y: int = 0
47
48 o = C(3)
49 self.assertEqual((o.x, o.y), (3, 0))
50
51 # Non-defaults following defaults.
52 with self.assertRaisesRegex(TypeError,
53 "non-default argument 'y' follows "
54 "default argument"):
55 @dataclass
56 class C:
57 x: int = 0
58 y: int
59
60 # A derived class adds a non-default field after a default one.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class B:
66 x: int = 0
67
68 @dataclass
69 class C(B):
70 y: int
71
72 # Override a base class field and add a default to
73 # a field which didn't use to have a default.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int
80 y: int
81
82 @dataclass
83 class C(B):
84 x: int = 0
85
Eric V. Smithf0db54a2017-12-04 16:58:55 -050086 def test_overwriting_hash(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -050087 @dataclass(frozen=True)
88 class C:
89 x: int
90 def __hash__(self):
91 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -050092
93 @dataclass(frozen=True,hash=False)
94 class C:
95 x: int
96 def __hash__(self):
97 return 600
98 self.assertEqual(hash(C(0)), 600)
99
Eric V. Smithea8fc522018-01-27 19:07:40 -0500100 @dataclass(frozen=True)
101 class C:
102 x: int
103 def __hash__(self):
104 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500105
106 @dataclass(frozen=True, hash=False)
107 class C:
108 x: int
109 def __hash__(self):
110 return 600
111 self.assertEqual(hash(C(0)), 600)
112
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500113 def test_overwrite_fields_in_derived_class(self):
114 # Note that x from C1 replaces x in Base, but the order remains
115 # the same as defined in Base.
116 @dataclass
117 class Base:
118 x: Any = 15.0
119 y: int = 0
120
121 @dataclass
122 class C1(Base):
123 z: int = 10
124 x: int = 15
125
126 o = Base()
127 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
128
129 o = C1()
130 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
131
132 o = C1(x=5)
133 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
134
135 def test_field_named_self(self):
136 @dataclass
137 class C:
138 self: str
139 c=C('foo')
140 self.assertEqual(c.self, 'foo')
141
142 # Make sure the first parameter is not named 'self'.
143 sig = inspect.signature(C.__init__)
144 first = next(iter(sig.parameters))
145 self.assertNotEqual('self', first)
146
147 # But we do use 'self' if no field named self.
148 @dataclass
149 class C:
150 selfx: str
151
152 # Make sure the first parameter is named 'self'.
153 sig = inspect.signature(C.__init__)
154 first = next(iter(sig.parameters))
155 self.assertEqual('self', first)
156
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500157 def test_0_field_compare(self):
158 # Ensure that order=False is the default.
159 @dataclass
160 class C0:
161 pass
162
163 @dataclass(order=False)
164 class C1:
165 pass
166
167 for cls in [C0, C1]:
168 with self.subTest(cls=cls):
169 self.assertEqual(cls(), cls())
170 for idx, fn in enumerate([lambda a, b: a < b,
171 lambda a, b: a <= b,
172 lambda a, b: a > b,
173 lambda a, b: a >= b]):
174 with self.subTest(idx=idx):
175 with self.assertRaisesRegex(TypeError,
176 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
177 fn(cls(), cls())
178
179 @dataclass(order=True)
180 class C:
181 pass
182 self.assertLessEqual(C(), C())
183 self.assertGreaterEqual(C(), C())
184
185 def test_1_field_compare(self):
186 # Ensure that order=False is the default.
187 @dataclass
188 class C0:
189 x: int
190
191 @dataclass(order=False)
192 class C1:
193 x: int
194
195 for cls in [C0, C1]:
196 with self.subTest(cls=cls):
197 self.assertEqual(cls(1), cls(1))
198 self.assertNotEqual(cls(0), cls(1))
199 for idx, fn in enumerate([lambda a, b: a < b,
200 lambda a, b: a <= b,
201 lambda a, b: a > b,
202 lambda a, b: a >= b]):
203 with self.subTest(idx=idx):
204 with self.assertRaisesRegex(TypeError,
205 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
206 fn(cls(0), cls(0))
207
208 @dataclass(order=True)
209 class C:
210 x: int
211 self.assertLess(C(0), C(1))
212 self.assertLessEqual(C(0), C(1))
213 self.assertLessEqual(C(1), C(1))
214 self.assertGreater(C(1), C(0))
215 self.assertGreaterEqual(C(1), C(0))
216 self.assertGreaterEqual(C(1), C(1))
217
218 def test_simple_compare(self):
219 # Ensure that order=False is the default.
220 @dataclass
221 class C0:
222 x: int
223 y: int
224
225 @dataclass(order=False)
226 class C1:
227 x: int
228 y: int
229
230 for cls in [C0, C1]:
231 with self.subTest(cls=cls):
232 self.assertEqual(cls(0, 0), cls(0, 0))
233 self.assertEqual(cls(1, 2), cls(1, 2))
234 self.assertNotEqual(cls(1, 0), cls(0, 0))
235 self.assertNotEqual(cls(1, 0), cls(1, 1))
236 for idx, fn in enumerate([lambda a, b: a < b,
237 lambda a, b: a <= b,
238 lambda a, b: a > b,
239 lambda a, b: a >= b]):
240 with self.subTest(idx=idx):
241 with self.assertRaisesRegex(TypeError,
242 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
243 fn(cls(0, 0), cls(0, 0))
244
245 @dataclass(order=True)
246 class C:
247 x: int
248 y: int
249
250 for idx, fn in enumerate([lambda a, b: a == b,
251 lambda a, b: a <= b,
252 lambda a, b: a >= b]):
253 with self.subTest(idx=idx):
254 self.assertTrue(fn(C(0, 0), C(0, 0)))
255
256 for idx, fn in enumerate([lambda a, b: a < b,
257 lambda a, b: a <= b,
258 lambda a, b: a != b]):
259 with self.subTest(idx=idx):
260 self.assertTrue(fn(C(0, 0), C(0, 1)))
261 self.assertTrue(fn(C(0, 1), C(1, 0)))
262 self.assertTrue(fn(C(1, 0), C(1, 1)))
263
264 for idx, fn in enumerate([lambda a, b: a > b,
265 lambda a, b: a >= b,
266 lambda a, b: a != b]):
267 with self.subTest(idx=idx):
268 self.assertTrue(fn(C(0, 1), C(0, 0)))
269 self.assertTrue(fn(C(1, 0), C(0, 1)))
270 self.assertTrue(fn(C(1, 1), C(1, 0)))
271
272 def test_compare_subclasses(self):
273 # Comparisons fail for subclasses, even if no fields
274 # are added.
275 @dataclass
276 class B:
277 i: int
278
279 @dataclass
280 class C(B):
281 pass
282
283 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
284 (lambda a, b: a != b, True)]):
285 with self.subTest(idx=idx):
286 self.assertEqual(fn(B(0), C(0)), expected)
287
288 for idx, fn in enumerate([lambda a, b: a < b,
289 lambda a, b: a <= b,
290 lambda a, b: a > b,
291 lambda a, b: a >= b]):
292 with self.subTest(idx=idx):
293 with self.assertRaisesRegex(TypeError,
294 "not supported between instances of 'B' and 'C'"):
295 fn(B(0), C(0))
296
297 def test_0_field_hash(self):
298 @dataclass(hash=True)
299 class C:
300 pass
301 self.assertEqual(hash(C()), hash(()))
302
303 def test_1_field_hash(self):
304 @dataclass(hash=True)
305 class C:
306 x: int
307 self.assertEqual(hash(C(4)), hash((4,)))
308 self.assertEqual(hash(C(42)), hash((42,)))
309
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500310 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500311 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500312 for (eq, order, result ) in [
313 (False, False, 'neither'),
314 (False, True, 'exception'),
315 (True, False, 'eq_only'),
316 (True, True, 'both'),
317 ]:
318 with self.subTest(eq=eq, order=order):
319 if result == 'exception':
320 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
321 @dataclass(eq=eq, order=order)
322 class C:
323 pass
324 else:
325 @dataclass(eq=eq, order=order)
326 class C:
327 pass
328
329 if result == 'neither':
330 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500331 self.assertNotIn('__lt__', C.__dict__)
332 self.assertNotIn('__le__', C.__dict__)
333 self.assertNotIn('__gt__', C.__dict__)
334 self.assertNotIn('__ge__', C.__dict__)
335 elif result == 'both':
336 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500337 self.assertIn('__lt__', C.__dict__)
338 self.assertIn('__le__', C.__dict__)
339 self.assertIn('__gt__', C.__dict__)
340 self.assertIn('__ge__', C.__dict__)
341 elif result == 'eq_only':
342 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500343 self.assertNotIn('__lt__', C.__dict__)
344 self.assertNotIn('__le__', C.__dict__)
345 self.assertNotIn('__gt__', C.__dict__)
346 self.assertNotIn('__ge__', C.__dict__)
347 else:
348 assert False, f'unknown result {result!r}'
349
350 def test_field_no_default(self):
351 @dataclass
352 class C:
353 x: int = field()
354
355 self.assertEqual(C(5).x, 5)
356
357 with self.assertRaisesRegex(TypeError,
358 r"__init__\(\) missing 1 required "
359 "positional argument: 'x'"):
360 C()
361
362 def test_field_default(self):
363 default = object()
364 @dataclass
365 class C:
366 x: object = field(default=default)
367
368 self.assertIs(C.x, default)
369 c = C(10)
370 self.assertEqual(c.x, 10)
371
372 # If we delete the instance attribute, we should then see the
373 # class attribute.
374 del c.x
375 self.assertIs(c.x, default)
376
377 self.assertIs(C().x, default)
378
379 def test_not_in_repr(self):
380 @dataclass
381 class C:
382 x: int = field(repr=False)
383 with self.assertRaises(TypeError):
384 C()
385 c = C(10)
386 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
387
388 @dataclass
389 class C:
390 x: int = field(repr=False)
391 y: int
392 c = C(10, 20)
393 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
394
395 def test_not_in_compare(self):
396 @dataclass
397 class C:
398 x: int = 0
399 y: int = field(compare=False, default=4)
400
401 self.assertEqual(C(), C(0, 20))
402 self.assertEqual(C(1, 10), C(1, 20))
403 self.assertNotEqual(C(3), C(4, 10))
404 self.assertNotEqual(C(3, 10), C(4, 10))
405
406 def test_hash_field_rules(self):
407 # Test all 6 cases of:
408 # hash=True/False/None
409 # compare=True/False
410 for (hash_val, compare, result ) in [
411 (True, False, 'field' ),
412 (True, True, 'field' ),
413 (False, False, 'absent'),
414 (False, True, 'absent'),
415 (None, False, 'absent'),
416 (None, True, 'field' ),
417 ]:
418 with self.subTest(hash_val=hash_val, compare=compare):
419 @dataclass(hash=True)
420 class C:
421 x: int = field(compare=compare, hash=hash_val, default=5)
422
423 if result == 'field':
424 # __hash__ contains the field.
425 self.assertEqual(C(5).__hash__(), hash((5,)))
426 elif result == 'absent':
427 # The field is not present in the hash.
428 self.assertEqual(C(5).__hash__(), hash(()))
429 else:
430 assert False, f'unknown result {result!r}'
431
432 def test_init_false_no_default(self):
433 # If init=False and no default value, then the field won't be
434 # present in the instance.
435 @dataclass
436 class C:
437 x: int = field(init=False)
438
439 self.assertNotIn('x', C().__dict__)
440
441 @dataclass
442 class C:
443 x: int
444 y: int = 0
445 z: int = field(init=False)
446 t: int = 10
447
448 self.assertNotIn('z', C(0).__dict__)
449 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
450
451 def test_class_marker(self):
452 @dataclass
453 class C:
454 x: int
455 y: str = field(init=False, default=None)
456 z: str = field(repr=False)
457
458 the_fields = fields(C)
459 # the_fields is a tuple of 3 items, each value
460 # is in __annotations__.
461 self.assertIsInstance(the_fields, tuple)
462 for f in the_fields:
463 self.assertIs(type(f), Field)
464 self.assertIn(f.name, C.__annotations__)
465
466 self.assertEqual(len(the_fields), 3)
467
468 self.assertEqual(the_fields[0].name, 'x')
469 self.assertEqual(the_fields[0].type, int)
470 self.assertFalse(hasattr(C, 'x'))
471 self.assertTrue (the_fields[0].init)
472 self.assertTrue (the_fields[0].repr)
473 self.assertEqual(the_fields[1].name, 'y')
474 self.assertEqual(the_fields[1].type, str)
475 self.assertIsNone(getattr(C, 'y'))
476 self.assertFalse(the_fields[1].init)
477 self.assertTrue (the_fields[1].repr)
478 self.assertEqual(the_fields[2].name, 'z')
479 self.assertEqual(the_fields[2].type, str)
480 self.assertFalse(hasattr(C, 'z'))
481 self.assertTrue (the_fields[2].init)
482 self.assertFalse(the_fields[2].repr)
483
484 def test_field_order(self):
485 @dataclass
486 class B:
487 a: str = 'B:a'
488 b: str = 'B:b'
489 c: str = 'B:c'
490
491 @dataclass
492 class C(B):
493 b: str = 'C:b'
494
495 self.assertEqual([(f.name, f.default) for f in fields(C)],
496 [('a', 'B:a'),
497 ('b', 'C:b'),
498 ('c', 'B:c')])
499
500 @dataclass
501 class D(B):
502 c: str = 'D:c'
503
504 self.assertEqual([(f.name, f.default) for f in fields(D)],
505 [('a', 'B:a'),
506 ('b', 'B:b'),
507 ('c', 'D:c')])
508
509 @dataclass
510 class E(D):
511 a: str = 'E:a'
512 d: str = 'E:d'
513
514 self.assertEqual([(f.name, f.default) for f in fields(E)],
515 [('a', 'E:a'),
516 ('b', 'B:b'),
517 ('c', 'D:c'),
518 ('d', 'E:d')])
519
520 def test_class_attrs(self):
521 # We only have a class attribute if a default value is
522 # specified, either directly or via a field with a default.
523 default = object()
524 @dataclass
525 class C:
526 x: int
527 y: int = field(repr=False)
528 z: object = default
529 t: int = field(default=100)
530
531 self.assertFalse(hasattr(C, 'x'))
532 self.assertFalse(hasattr(C, 'y'))
533 self.assertIs (C.z, default)
534 self.assertEqual(C.t, 100)
535
536 def test_disallowed_mutable_defaults(self):
537 # For the known types, don't allow mutable default values.
538 for typ, empty, non_empty in [(list, [], [1]),
539 (dict, {}, {0:1}),
540 (set, set(), set([1])),
541 ]:
542 with self.subTest(typ=typ):
543 # Can't use a zero-length value.
544 with self.assertRaisesRegex(ValueError,
545 f'mutable default {typ} for field '
546 'x is not allowed'):
547 @dataclass
548 class Point:
549 x: typ = empty
550
551
552 # Nor a non-zero-length value
553 with self.assertRaisesRegex(ValueError,
554 f'mutable default {typ} for field '
555 'y is not allowed'):
556 @dataclass
557 class Point:
558 y: typ = non_empty
559
560 # Check subtypes also fail.
561 class Subclass(typ): pass
562
563 with self.assertRaisesRegex(ValueError,
564 f"mutable default .*Subclass'>"
565 ' for field z is not allowed'
566 ):
567 @dataclass
568 class Point:
569 z: typ = Subclass()
570
571 # Because this is a ClassVar, it can be mutable.
572 @dataclass
573 class C:
574 z: ClassVar[typ] = typ()
575
576 # Because this is a ClassVar, it can be mutable.
577 @dataclass
578 class C:
579 x: ClassVar[typ] = Subclass()
580
581
582 def test_deliberately_mutable_defaults(self):
583 # If a mutable default isn't in the known list of
584 # (list, dict, set), then it's okay.
585 class Mutable:
586 def __init__(self):
587 self.l = []
588
589 @dataclass
590 class C:
591 x: Mutable
592
593 # These 2 instances will share this value of x.
594 lst = Mutable()
595 o1 = C(lst)
596 o2 = C(lst)
597 self.assertEqual(o1, o2)
598 o1.x.l.extend([1, 2])
599 self.assertEqual(o1, o2)
600 self.assertEqual(o1.x.l, [1, 2])
601 self.assertIs(o1.x, o2.x)
602
603 def test_no_options(self):
604 # call with dataclass()
605 @dataclass()
606 class C:
607 x: int
608
609 self.assertEqual(C(42).x, 42)
610
611 def test_not_tuple(self):
612 # Make sure we can't be compared to a tuple.
613 @dataclass
614 class Point:
615 x: int
616 y: int
617 self.assertNotEqual(Point(1, 2), (1, 2))
618
619 # And that we can't compare to another unrelated dataclass
620 @dataclass
621 class C:
622 x: int
623 y: int
624 self.assertNotEqual(Point(1, 3), C(1, 3))
625
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500626 def test_frozen(self):
627 @dataclass(frozen=True)
628 class C:
629 i: int
630
631 c = C(10)
632 self.assertEqual(c.i, 10)
633 with self.assertRaises(FrozenInstanceError):
634 c.i = 5
635 self.assertEqual(c.i, 10)
636
637 # Check that a derived class is still frozen, even if not
638 # marked so.
639 @dataclass
640 class D(C):
641 pass
642
643 d = D(20)
644 self.assertEqual(d.i, 20)
645 with self.assertRaises(FrozenInstanceError):
646 d.i = 5
647 self.assertEqual(d.i, 20)
648
649 def test_not_tuple(self):
650 # Test that some of the problems with namedtuple don't happen
651 # here.
652 @dataclass
653 class Point3D:
654 x: int
655 y: int
656 z: int
657
658 @dataclass
659 class Date:
660 year: int
661 month: int
662 day: int
663
664 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
665 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
666
667 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200668 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500669 x, y, z = Point3D(4, 5, 6)
670
Eric V. Smith7c99e932018-01-28 19:18:55 -0500671 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500672 # equal.
673 @dataclass
674 class Point3Dv1:
675 x: int = 0
676 y: int = 0
677 z: int = 0
678 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
679
680 def test_function_annotations(self):
681 # Some dummy class and instance to use as a default.
682 class F:
683 pass
684 f = F()
685
686 def validate_class(cls):
687 # First, check __annotations__, even though they're not
688 # function annotations.
689 self.assertEqual(cls.__annotations__['i'], int)
690 self.assertEqual(cls.__annotations__['j'], str)
691 self.assertEqual(cls.__annotations__['k'], F)
692 self.assertEqual(cls.__annotations__['l'], float)
693 self.assertEqual(cls.__annotations__['z'], complex)
694
695 # Verify __init__.
696
697 signature = inspect.signature(cls.__init__)
698 # Check the return type, should be None
699 self.assertIs(signature.return_annotation, None)
700
701 # Check each parameter.
702 params = iter(signature.parameters.values())
703 param = next(params)
704 # This is testing an internal name, and probably shouldn't be tested.
705 self.assertEqual(param.name, 'self')
706 param = next(params)
707 self.assertEqual(param.name, 'i')
708 self.assertIs (param.annotation, int)
709 self.assertEqual(param.default, inspect.Parameter.empty)
710 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
711 param = next(params)
712 self.assertEqual(param.name, 'j')
713 self.assertIs (param.annotation, str)
714 self.assertEqual(param.default, inspect.Parameter.empty)
715 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
716 param = next(params)
717 self.assertEqual(param.name, 'k')
718 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500719 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
721 param = next(params)
722 self.assertEqual(param.name, 'l')
723 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500724 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500725 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
726 self.assertRaises(StopIteration, next, params)
727
728
729 @dataclass
730 class C:
731 i: int
732 j: str
733 k: F = f
734 l: float=field(default=None)
735 z: complex=field(default=3+4j, init=False)
736
737 validate_class(C)
738
739 # Now repeat with __hash__.
740 @dataclass(frozen=True, hash=True)
741 class C:
742 i: int
743 j: str
744 k: F = f
745 l: float=field(default=None)
746 z: complex=field(default=3+4j, init=False)
747
748 validate_class(C)
749
Eric V. Smith03220fd2017-12-29 13:59:58 -0500750 def test_missing_default(self):
751 # Test that MISSING works the same as a default not being
752 # specified.
753 @dataclass
754 class C:
755 x: int=field(default=MISSING)
756 with self.assertRaisesRegex(TypeError,
757 r'__init__\(\) missing 1 required '
758 'positional argument'):
759 C()
760 self.assertNotIn('x', C.__dict__)
761
762 @dataclass
763 class D:
764 x: int
765 with self.assertRaisesRegex(TypeError,
766 r'__init__\(\) missing 1 required '
767 'positional argument'):
768 D()
769 self.assertNotIn('x', D.__dict__)
770
771 def test_missing_default_factory(self):
772 # Test that MISSING works the same as a default factory not
773 # being specified (which is really the same as a default not
774 # being specified, too).
775 @dataclass
776 class C:
777 x: int=field(default_factory=MISSING)
778 with self.assertRaisesRegex(TypeError,
779 r'__init__\(\) missing 1 required '
780 'positional argument'):
781 C()
782 self.assertNotIn('x', C.__dict__)
783
784 @dataclass
785 class D:
786 x: int=field(default=MISSING, default_factory=MISSING)
787 with self.assertRaisesRegex(TypeError,
788 r'__init__\(\) missing 1 required '
789 'positional argument'):
790 D()
791 self.assertNotIn('x', D.__dict__)
792
793 def test_missing_repr(self):
794 self.assertIn('MISSING_TYPE object', repr(MISSING))
795
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500796 def test_dont_include_other_annotations(self):
797 @dataclass
798 class C:
799 i: int
800 def foo(self) -> int:
801 return 4
802 @property
803 def bar(self) -> int:
804 return 5
805 self.assertEqual(list(C.__annotations__), ['i'])
806 self.assertEqual(C(10).foo(), 4)
807 self.assertEqual(C(10).bar, 5)
808
809 def test_post_init(self):
810 # Just make sure it gets called
811 @dataclass
812 class C:
813 def __post_init__(self):
814 raise CustomError()
815 with self.assertRaises(CustomError):
816 C()
817
818 @dataclass
819 class C:
820 i: int = 10
821 def __post_init__(self):
822 if self.i == 10:
823 raise CustomError()
824 with self.assertRaises(CustomError):
825 C()
826 # post-init gets called, but doesn't raise. This is just
827 # checking that self is used correctly.
828 C(5)
829
830 # If there's not an __init__, then post-init won't get called.
831 @dataclass(init=False)
832 class C:
833 def __post_init__(self):
834 raise CustomError()
835 # Creating the class won't raise
836 C()
837
838 @dataclass
839 class C:
840 x: int = 0
841 def __post_init__(self):
842 self.x *= 2
843 self.assertEqual(C().x, 0)
844 self.assertEqual(C(2).x, 4)
845
Mike53f7a7c2017-12-14 14:04:53 +0300846 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 # attributes.
848 @dataclass(frozen=True)
849 class C:
850 x: int = 0
851 def __post_init__(self):
852 self.x *= 2
853 with self.assertRaises(FrozenInstanceError):
854 C()
855
856 def test_post_init_super(self):
857 # Make sure super() post-init isn't called by default.
858 class B:
859 def __post_init__(self):
860 raise CustomError()
861
862 @dataclass
863 class C(B):
864 def __post_init__(self):
865 self.x = 5
866
867 self.assertEqual(C().x, 5)
868
869 # Now call super(), and it will raise
870 @dataclass
871 class C(B):
872 def __post_init__(self):
873 super().__post_init__()
874
875 with self.assertRaises(CustomError):
876 C()
877
878 # Make sure post-init is called, even if not defined in our
879 # class.
880 @dataclass
881 class C(B):
882 pass
883
884 with self.assertRaises(CustomError):
885 C()
886
887 def test_post_init_staticmethod(self):
888 flag = False
889 @dataclass
890 class C:
891 x: int
892 y: int
893 @staticmethod
894 def __post_init__():
895 nonlocal flag
896 flag = True
897
898 self.assertFalse(flag)
899 c = C(3, 4)
900 self.assertEqual((c.x, c.y), (3, 4))
901 self.assertTrue(flag)
902
903 def test_post_init_classmethod(self):
904 @dataclass
905 class C:
906 flag = False
907 x: int
908 y: int
909 @classmethod
910 def __post_init__(cls):
911 cls.flag = True
912
913 self.assertFalse(C.flag)
914 c = C(3, 4)
915 self.assertEqual((c.x, c.y), (3, 4))
916 self.assertTrue(C.flag)
917
918 def test_class_var(self):
919 # Make sure ClassVars are ignored in __init__, __repr__, etc.
920 @dataclass
921 class C:
922 x: int
923 y: int = 10
924 z: ClassVar[int] = 1000
925 w: ClassVar[int] = 2000
926 t: ClassVar[int] = 3000
927
928 c = C(5)
929 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
930 self.assertEqual(len(fields(C)), 2) # We have 2 fields
931 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
932 self.assertEqual(c.z, 1000)
933 self.assertEqual(c.w, 2000)
934 self.assertEqual(c.t, 3000)
935 C.z += 1
936 self.assertEqual(c.z, 1001)
937 c = C(20)
938 self.assertEqual((c.x, c.y), (20, 10))
939 self.assertEqual(c.z, 1001)
940 self.assertEqual(c.w, 2000)
941 self.assertEqual(c.t, 3000)
942
943 def test_class_var_no_default(self):
944 # If a ClassVar has no default value, it should not be set on the class.
945 @dataclass
946 class C:
947 x: ClassVar[int]
948
949 self.assertNotIn('x', C.__dict__)
950
951 def test_class_var_default_factory(self):
952 # It makes no sense for a ClassVar to have a default factory. When
953 # would it be called? Call it yourself, since it's class-wide.
954 with self.assertRaisesRegex(TypeError,
955 'cannot have a default factory'):
956 @dataclass
957 class C:
958 x: ClassVar[int] = field(default_factory=int)
959
960 self.assertNotIn('x', C.__dict__)
961
962 def test_class_var_with_default(self):
963 # If a ClassVar has a default value, it should be set on the class.
964 @dataclass
965 class C:
966 x: ClassVar[int] = 10
967 self.assertEqual(C.x, 10)
968
969 @dataclass
970 class C:
971 x: ClassVar[int] = field(default=10)
972 self.assertEqual(C.x, 10)
973
974 def test_class_var_frozen(self):
975 # Make sure ClassVars work even if we're frozen.
976 @dataclass(frozen=True)
977 class C:
978 x: int
979 y: int = 10
980 z: ClassVar[int] = 1000
981 w: ClassVar[int] = 2000
982 t: ClassVar[int] = 3000
983
984 c = C(5)
985 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
986 self.assertEqual(len(fields(C)), 2) # We have 2 fields
987 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
988 self.assertEqual(c.z, 1000)
989 self.assertEqual(c.w, 2000)
990 self.assertEqual(c.t, 3000)
991 # We can still modify the ClassVar, it's only instances that are
992 # frozen.
993 C.z += 1
994 self.assertEqual(c.z, 1001)
995 c = C(20)
996 self.assertEqual((c.x, c.y), (20, 10))
997 self.assertEqual(c.z, 1001)
998 self.assertEqual(c.w, 2000)
999 self.assertEqual(c.t, 3000)
1000
1001 def test_init_var_no_default(self):
1002 # If an InitVar has no default value, it should not be set on the class.
1003 @dataclass
1004 class C:
1005 x: InitVar[int]
1006
1007 self.assertNotIn('x', C.__dict__)
1008
1009 def test_init_var_default_factory(self):
1010 # It makes no sense for an InitVar to have a default factory. When
1011 # would it be called? Call it yourself, since it's class-wide.
1012 with self.assertRaisesRegex(TypeError,
1013 'cannot have a default factory'):
1014 @dataclass
1015 class C:
1016 x: InitVar[int] = field(default_factory=int)
1017
1018 self.assertNotIn('x', C.__dict__)
1019
1020 def test_init_var_with_default(self):
1021 # If an InitVar has a default value, it should be set on the class.
1022 @dataclass
1023 class C:
1024 x: InitVar[int] = 10
1025 self.assertEqual(C.x, 10)
1026
1027 @dataclass
1028 class C:
1029 x: InitVar[int] = field(default=10)
1030 self.assertEqual(C.x, 10)
1031
1032 def test_init_var(self):
1033 @dataclass
1034 class C:
1035 x: int = None
1036 init_param: InitVar[int] = None
1037
1038 def __post_init__(self, init_param):
1039 if self.x is None:
1040 self.x = init_param*2
1041
1042 c = C(init_param=10)
1043 self.assertEqual(c.x, 20)
1044
1045 def test_init_var_inheritance(self):
1046 # Note that this deliberately tests that a dataclass need not
1047 # have a __post_init__ function if it has an InitVar field.
1048 # It could just be used in a derived class, as shown here.
1049 @dataclass
1050 class Base:
1051 x: int
1052 init_base: InitVar[int]
1053
1054 # We can instantiate by passing the InitVar, even though
1055 # it's not used.
1056 b = Base(0, 10)
1057 self.assertEqual(vars(b), {'x': 0})
1058
1059 @dataclass
1060 class C(Base):
1061 y: int
1062 init_derived: InitVar[int]
1063
1064 def __post_init__(self, init_base, init_derived):
1065 self.x = self.x + init_base
1066 self.y = self.y + init_derived
1067
1068 c = C(10, 11, 50, 51)
1069 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1070
1071 def test_default_factory(self):
1072 # Test a factory that returns a new list.
1073 @dataclass
1074 class C:
1075 x: int
1076 y: list = field(default_factory=list)
1077
1078 c0 = C(3)
1079 c1 = C(3)
1080 self.assertEqual(c0.x, 3)
1081 self.assertEqual(c0.y, [])
1082 self.assertEqual(c0, c1)
1083 self.assertIsNot(c0.y, c1.y)
1084 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1085
1086 # Test a factory that returns a shared list.
1087 l = []
1088 @dataclass
1089 class C:
1090 x: int
1091 y: list = field(default_factory=lambda: l)
1092
1093 c0 = C(3)
1094 c1 = C(3)
1095 self.assertEqual(c0.x, 3)
1096 self.assertEqual(c0.y, [])
1097 self.assertEqual(c0, c1)
1098 self.assertIs(c0.y, c1.y)
1099 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1100
1101 # Test various other field flags.
1102 # repr
1103 @dataclass
1104 class C:
1105 x: list = field(default_factory=list, repr=False)
1106 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1107 self.assertEqual(C().x, [])
1108
1109 # hash
1110 @dataclass(hash=True)
1111 class C:
1112 x: list = field(default_factory=list, hash=False)
1113 self.assertEqual(astuple(C()), ([],))
1114 self.assertEqual(hash(C()), hash(()))
1115
1116 # init (see also test_default_factory_with_no_init)
1117 @dataclass
1118 class C:
1119 x: list = field(default_factory=list, init=False)
1120 self.assertEqual(astuple(C()), ([],))
1121
1122 # compare
1123 @dataclass
1124 class C:
1125 x: list = field(default_factory=list, compare=False)
1126 self.assertEqual(C(), C([1]))
1127
1128 def test_default_factory_with_no_init(self):
1129 # We need a factory with a side effect.
1130 factory = Mock()
1131
1132 @dataclass
1133 class C:
1134 x: list = field(default_factory=factory, init=False)
1135
1136 # Make sure the default factory is called for each new instance.
1137 C().x
1138 self.assertEqual(factory.call_count, 1)
1139 C().x
1140 self.assertEqual(factory.call_count, 2)
1141
1142 def test_default_factory_not_called_if_value_given(self):
1143 # We need a factory that we can test if it's been called.
1144 factory = Mock()
1145
1146 @dataclass
1147 class C:
1148 x: int = field(default_factory=factory)
1149
1150 # Make sure that if a field has a default factory function,
1151 # it's not called if a value is specified.
1152 C().x
1153 self.assertEqual(factory.call_count, 1)
1154 self.assertEqual(C(10).x, 10)
1155 self.assertEqual(factory.call_count, 1)
1156 C().x
1157 self.assertEqual(factory.call_count, 2)
1158
    def x_test_classvar_default_factory(self):
        # Disabled test: the name lacks the "test" prefix, so unittest's
        # default loader never collects or runs it.
        # XXX: it's an error for a ClassVar to have a factory function
        # NOTE(review): the assertion below contradicts
        # test_class_var_default_factory, which expects defining this
        # class to raise TypeError ('cannot have a default factory').
        # This appears to be kept only for reference; consider deleting it.
        @dataclass
        class C:
            x: ClassVar[int] = field(default_factory=int)

        self.assertIs(C().x, int)
1166
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001167 def test_is_dataclass(self):
1168 class NotDataClass:
1169 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001170
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001171 self.assertFalse(is_dataclass(0))
1172 self.assertFalse(is_dataclass(int))
1173 self.assertFalse(is_dataclass(NotDataClass))
1174 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001175
1176 @dataclass
1177 class C:
1178 x: int
1179
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001180 @dataclass
1181 class D:
1182 d: C
1183 e: int
1184
1185 c = C(10)
1186 d = D(c, 4)
1187
1188 self.assertTrue(is_dataclass(C))
1189 self.assertTrue(is_dataclass(c))
1190 self.assertFalse(is_dataclass(c.x))
1191 self.assertTrue(is_dataclass(d.d))
1192 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001193
1194 def test_helper_fields_with_class_instance(self):
1195 # Check that we can call fields() on either a class or instance,
1196 # and get back the same thing.
1197 @dataclass
1198 class C:
1199 x: int
1200 y: float
1201
1202 self.assertEqual(fields(C), fields(C(0, 0.0)))
1203
1204 def test_helper_fields_exception(self):
1205 # Check that TypeError is raised if not passed a dataclass or
1206 # instance.
1207 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1208 fields(0)
1209
1210 class C: pass
1211 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1212 fields(C)
1213 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1214 fields(C())
1215
1216 def test_helper_asdict(self):
1217 # Basic tests for asdict(), it should return a new dictionary
1218 @dataclass
1219 class C:
1220 x: int
1221 y: int
1222 c = C(1, 2)
1223
1224 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1225 self.assertEqual(asdict(c), asdict(c))
1226 self.assertIsNot(asdict(c), asdict(c))
1227 c.x = 42
1228 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1229 self.assertIs(type(asdict(c)), dict)
1230
1231 def test_helper_asdict_raises_on_classes(self):
1232 # asdict() should raise on a class object
1233 @dataclass
1234 class C:
1235 x: int
1236 y: int
1237 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1238 asdict(C)
1239 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1240 asdict(int)
1241
1242 def test_helper_asdict_copy_values(self):
1243 @dataclass
1244 class C:
1245 x: int
1246 y: List[int] = field(default_factory=list)
1247 initial = []
1248 c = C(1, initial)
1249 d = asdict(c)
1250 self.assertEqual(d['y'], initial)
1251 self.assertIsNot(d['y'], initial)
1252 c = C(1)
1253 d = asdict(c)
1254 d['y'].append(1)
1255 self.assertEqual(c.y, [])
1256
1257 def test_helper_asdict_nested(self):
1258 @dataclass
1259 class UserId:
1260 token: int
1261 group: int
1262 @dataclass
1263 class User:
1264 name: str
1265 id: UserId
1266 u = User('Joe', UserId(123, 1))
1267 d = asdict(u)
1268 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1269 self.assertIsNot(asdict(u), asdict(u))
1270 u.id.group = 2
1271 self.assertEqual(asdict(u), {'name': 'Joe',
1272 'id': {'token': 123, 'group': 2}})
1273
1274 def test_helper_asdict_builtin_containers(self):
1275 @dataclass
1276 class User:
1277 name: str
1278 id: int
1279 @dataclass
1280 class GroupList:
1281 id: int
1282 users: List[User]
1283 @dataclass
1284 class GroupTuple:
1285 id: int
1286 users: Tuple[User, ...]
1287 @dataclass
1288 class GroupDict:
1289 id: int
1290 users: Dict[str, User]
1291 a = User('Alice', 1)
1292 b = User('Bob', 2)
1293 gl = GroupList(0, [a, b])
1294 gt = GroupTuple(0, (a, b))
1295 gd = GroupDict(0, {'first': a, 'second': b})
1296 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1297 {'name': 'Bob', 'id': 2}]})
1298 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1299 {'name': 'Bob', 'id': 2})})
1300 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1301 'second': {'name': 'Bob', 'id': 2}}})
1302
1303 def test_helper_asdict_builtin_containers(self):
1304 @dataclass
1305 class Child:
1306 d: object
1307
1308 @dataclass
1309 class Parent:
1310 child: Child
1311
1312 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1313 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1314
1315 def test_helper_asdict_factory(self):
1316 @dataclass
1317 class C:
1318 x: int
1319 y: int
1320 c = C(1, 2)
1321 d = asdict(c, dict_factory=OrderedDict)
1322 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1323 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1324 c.x = 42
1325 d = asdict(c, dict_factory=OrderedDict)
1326 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1327 self.assertIs(type(d), OrderedDict)
1328
1329 def test_helper_astuple(self):
1330 # Basic tests for astuple(), it should return a new tuple
1331 @dataclass
1332 class C:
1333 x: int
1334 y: int = 0
1335 c = C(1)
1336
1337 self.assertEqual(astuple(c), (1, 0))
1338 self.assertEqual(astuple(c), astuple(c))
1339 self.assertIsNot(astuple(c), astuple(c))
1340 c.y = 42
1341 self.assertEqual(astuple(c), (1, 42))
1342 self.assertIs(type(astuple(c)), tuple)
1343
1344 def test_helper_astuple_raises_on_classes(self):
1345 # astuple() should raise on a class object
1346 @dataclass
1347 class C:
1348 x: int
1349 y: int
1350 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1351 astuple(C)
1352 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1353 astuple(int)
1354
1355 def test_helper_astuple_copy_values(self):
1356 @dataclass
1357 class C:
1358 x: int
1359 y: List[int] = field(default_factory=list)
1360 initial = []
1361 c = C(1, initial)
1362 t = astuple(c)
1363 self.assertEqual(t[1], initial)
1364 self.assertIsNot(t[1], initial)
1365 c = C(1)
1366 t = astuple(c)
1367 t[1].append(1)
1368 self.assertEqual(c.y, [])
1369
1370 def test_helper_astuple_nested(self):
1371 @dataclass
1372 class UserId:
1373 token: int
1374 group: int
1375 @dataclass
1376 class User:
1377 name: str
1378 id: UserId
1379 u = User('Joe', UserId(123, 1))
1380 t = astuple(u)
1381 self.assertEqual(t, ('Joe', (123, 1)))
1382 self.assertIsNot(astuple(u), astuple(u))
1383 u.id.group = 2
1384 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1385
1386 def test_helper_astuple_builtin_containers(self):
1387 @dataclass
1388 class User:
1389 name: str
1390 id: int
1391 @dataclass
1392 class GroupList:
1393 id: int
1394 users: List[User]
1395 @dataclass
1396 class GroupTuple:
1397 id: int
1398 users: Tuple[User, ...]
1399 @dataclass
1400 class GroupDict:
1401 id: int
1402 users: Dict[str, User]
1403 a = User('Alice', 1)
1404 b = User('Bob', 2)
1405 gl = GroupList(0, [a, b])
1406 gt = GroupTuple(0, (a, b))
1407 gd = GroupDict(0, {'first': a, 'second': b})
1408 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1409 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1410 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1411
1412 def test_helper_astuple_builtin_containers(self):
1413 @dataclass
1414 class Child:
1415 d: object
1416
1417 @dataclass
1418 class Parent:
1419 child: Child
1420
1421 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1422 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1423
1424 def test_helper_astuple_factory(self):
1425 @dataclass
1426 class C:
1427 x: int
1428 y: int
1429 NT = namedtuple('NT', 'x y')
1430 def nt(lst):
1431 return NT(*lst)
1432 c = C(1, 2)
1433 t = astuple(c, tuple_factory=nt)
1434 self.assertEqual(t, NT(1, 2))
1435 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1436 c.x = 42
1437 t = astuple(c, tuple_factory=nt)
1438 self.assertEqual(t, NT(42, 2))
1439 self.assertIs(type(t), NT)
1440
1441 def test_dynamic_class_creation(self):
1442 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1443 }
1444
1445 # Create the class.
1446 cls = type('C', (), cls_dict)
1447
1448 # Make it a dataclass.
1449 cls1 = dataclass(cls)
1450
1451 self.assertEqual(cls1, cls)
1452 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1453
1454 def test_dynamic_class_creation_using_field(self):
1455 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1456 'y': field(default=5),
1457 }
1458
1459 # Create the class.
1460 cls = type('C', (), cls_dict)
1461
1462 # Make it a dataclass.
1463 cls1 = dataclass(cls)
1464
1465 self.assertEqual(cls1, cls)
1466 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1467
1468 def test_init_in_order(self):
1469 @dataclass
1470 class C:
1471 a: int
1472 b: int = field()
1473 c: list = field(default_factory=list, init=False)
1474 d: list = field(default_factory=list)
1475 e: int = field(default=4, init=False)
1476 f: int = 4
1477
1478 calls = []
1479 def setattr(self, name, value):
1480 calls.append((name, value))
1481
1482 C.__setattr__ = setattr
1483 c = C(0, 1)
1484 self.assertEqual(('a', 0), calls[0])
1485 self.assertEqual(('b', 1), calls[1])
1486 self.assertEqual(('c', []), calls[2])
1487 self.assertEqual(('d', []), calls[3])
1488 self.assertNotIn(('e', 4), calls)
1489 self.assertEqual(('f', 4), calls[4])
1490
1491 def test_items_in_dicts(self):
1492 @dataclass
1493 class C:
1494 a: int
1495 b: list = field(default_factory=list, init=False)
1496 c: list = field(default_factory=list)
1497 d: int = field(default=4, init=False)
1498 e: int = 0
1499
1500 c = C(0)
1501 # Class dict
1502 self.assertNotIn('a', C.__dict__)
1503 self.assertNotIn('b', C.__dict__)
1504 self.assertNotIn('c', C.__dict__)
1505 self.assertIn('d', C.__dict__)
1506 self.assertEqual(C.d, 4)
1507 self.assertIn('e', C.__dict__)
1508 self.assertEqual(C.e, 0)
1509 # Instance dict
1510 self.assertIn('a', c.__dict__)
1511 self.assertEqual(c.a, 0)
1512 self.assertIn('b', c.__dict__)
1513 self.assertEqual(c.b, [])
1514 self.assertIn('c', c.__dict__)
1515 self.assertEqual(c.c, [])
1516 self.assertNotIn('d', c.__dict__)
1517 self.assertIn('e', c.__dict__)
1518 self.assertEqual(c.e, 0)
1519
1520 def test_alternate_classmethod_constructor(self):
1521 # Since __post_init__ can't take params, use a classmethod
1522 # alternate constructor. This is mostly an example to show how
1523 # to use this technique.
1524 @dataclass
1525 class C:
1526 x: int
1527 @classmethod
1528 def from_file(cls, filename):
1529 # In a real example, create a new instance
1530 # and populate 'x' from contents of a file.
1531 value_in_file = 20
1532 return cls(value_in_file)
1533
1534 self.assertEqual(C.from_file('filename').x, 20)
1535
1536 def test_field_metadata_default(self):
1537 # Make sure the default metadata is read-only and of
1538 # zero length.
1539 @dataclass
1540 class C:
1541 i: int
1542
1543 self.assertFalse(fields(C)[0].metadata)
1544 self.assertEqual(len(fields(C)[0].metadata), 0)
1545 with self.assertRaisesRegex(TypeError,
1546 'does not support item assignment'):
1547 fields(C)[0].metadata['test'] = 3
1548
1549 def test_field_metadata_mapping(self):
1550 # Make sure only a mapping can be passed as metadata
1551 # zero length.
1552 with self.assertRaises(TypeError):
1553 @dataclass
1554 class C:
1555 i: int = field(metadata=0)
1556
1557 # Make sure an empty dict works
1558 @dataclass
1559 class C:
1560 i: int = field(metadata={})
1561 self.assertFalse(fields(C)[0].metadata)
1562 self.assertEqual(len(fields(C)[0].metadata), 0)
1563 with self.assertRaisesRegex(TypeError,
1564 'does not support item assignment'):
1565 fields(C)[0].metadata['test'] = 3
1566
1567 # Make sure a non-empty dict works.
1568 @dataclass
1569 class C:
1570 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1571 self.assertEqual(len(fields(C)[0].metadata), 3)
1572 self.assertEqual(fields(C)[0].metadata['test'], 10)
1573 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1574 self.assertEqual(fields(C)[0].metadata[3], 'three')
1575 with self.assertRaises(KeyError):
1576 # Non-existent key.
1577 fields(C)[0].metadata['baz']
1578 with self.assertRaisesRegex(TypeError,
1579 'does not support item assignment'):
1580 fields(C)[0].metadata['test'] = 3
1581
1582 def test_field_metadata_custom_mapping(self):
1583 # Try a custom mapping.
1584 class SimpleNameSpace:
1585 def __init__(self, **kw):
1586 self.__dict__.update(kw)
1587
1588 def __getitem__(self, item):
1589 if item == 'xyzzy':
1590 return 'plugh'
1591 return getattr(self, item)
1592
1593 def __len__(self):
1594 return self.__dict__.__len__()
1595
1596 @dataclass
1597 class C:
1598 i: int = field(metadata=SimpleNameSpace(a=10))
1599
1600 self.assertEqual(len(fields(C)[0].metadata), 1)
1601 self.assertEqual(fields(C)[0].metadata['a'], 10)
1602 with self.assertRaises(AttributeError):
1603 fields(C)[0].metadata['b']
1604 # Make sure we're still talking to our custom mapping.
1605 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1606
1607 def test_generic_dataclasses(self):
1608 T = TypeVar('T')
1609
1610 @dataclass
1611 class LabeledBox(Generic[T]):
1612 content: T
1613 label: str = '<unknown>'
1614
1615 box = LabeledBox(42)
1616 self.assertEqual(box.content, 42)
1617 self.assertEqual(box.label, '<unknown>')
1618
1619 # subscripting the resulting class should work, etc.
1620 Alias = List[LabeledBox[int]]
1621
1622 def test_generic_extending(self):
1623 S = TypeVar('S')
1624 T = TypeVar('T')
1625
1626 @dataclass
1627 class Base(Generic[T, S]):
1628 x: T
1629 y: S
1630
1631 @dataclass
1632 class DataDerived(Base[int, T]):
1633 new_field: str
1634 Alias = DataDerived[str]
1635 c = Alias(0, 'test1', 'test2')
1636 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1637
1638 class NonDataDerived(Base[int, T]):
1639 def new_method(self):
1640 return self.y
1641 Alias = NonDataDerived[float]
1642 c = Alias(10, 1.0)
1643 self.assertEqual(c.new_method(), 1.0)
1644
1645 def test_helper_replace(self):
1646 @dataclass(frozen=True)
1647 class C:
1648 x: int
1649 y: int
1650
1651 c = C(1, 2)
1652 c1 = replace(c, x=3)
1653 self.assertEqual(c1.x, 3)
1654 self.assertEqual(c1.y, 2)
1655
1656 def test_helper_replace_frozen(self):
1657 @dataclass(frozen=True)
1658 class C:
1659 x: int
1660 y: int
1661 z: int = field(init=False, default=10)
1662 t: int = field(init=False, default=100)
1663
1664 c = C(1, 2)
1665 c1 = replace(c, x=3)
1666 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1667 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1668
1669
1670 with self.assertRaisesRegex(ValueError, 'init=False'):
1671 replace(c, x=3, z=20, t=50)
1672 with self.assertRaisesRegex(ValueError, 'init=False'):
1673 replace(c, z=20)
1674 replace(c, x=3, z=20, t=50)
1675
1676 # Make sure the result is still frozen.
1677 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1678 c1.x = 3
1679
1680 # Make sure we can't replace an attribute that doesn't exist,
1681 # if we're also replacing one that does exist. Test this
1682 # here, because setting attributes on frozen instances is
1683 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001684 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001685 "keyword argument 'a'"):
1686 c1 = replace(c, x=20, a=5)
1687
1688 def test_helper_replace_invalid_field_name(self):
1689 @dataclass(frozen=True)
1690 class C:
1691 x: int
1692 y: int
1693
1694 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001695 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001696 "keyword argument 'z'"):
1697 c1 = replace(c, z=3)
1698
1699 def test_helper_replace_invalid_object(self):
1700 @dataclass(frozen=True)
1701 class C:
1702 x: int
1703 y: int
1704
1705 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1706 replace(C, x=3)
1707
1708 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1709 replace(0, x=3)
1710
1711 def test_helper_replace_no_init(self):
1712 @dataclass
1713 class C:
1714 x: int
1715 y: int = field(init=False, default=10)
1716
1717 c = C(1)
1718 c.y = 20
1719
1720 # Make sure y gets the default value.
1721 c1 = replace(c, x=5)
1722 self.assertEqual((c1.x, c1.y), (5, 10))
1723
1724 # Trying to replace y is an error.
1725 with self.assertRaisesRegex(ValueError, 'init=False'):
1726 replace(c, x=2, y=30)
1727 with self.assertRaisesRegex(ValueError, 'init=False'):
1728 replace(c, y=30)
1729
1730 def test_dataclassses_pickleable(self):
1731 global P, Q, R
1732 @dataclass
1733 class P:
1734 x: int
1735 y: int = 0
1736 @dataclass
1737 class Q:
1738 x: int
1739 y: int = field(default=0, init=False)
1740 @dataclass
1741 class R:
1742 x: int
1743 y: List[int] = field(default_factory=list)
1744 q = Q(1)
1745 q.y = 2
1746 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1747 for sample in samples:
1748 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1749 with self.subTest(sample=sample, proto=proto):
1750 new_sample = pickle.loads(pickle.dumps(sample, proto))
1751 self.assertEqual(sample.x, new_sample.x)
1752 self.assertEqual(sample.y, new_sample.y)
1753 self.assertIsNot(sample, new_sample)
1754 new_sample.x = 42
1755 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1756 self.assertEqual(new_sample.x, another_new_sample.x)
1757 self.assertEqual(sample.y, another_new_sample.y)
1758
1759 def test_helper_make_dataclass(self):
1760 C = make_dataclass('C',
1761 [('x', int),
1762 ('y', int, field(default=5))],
1763 namespace={'add_one': lambda self: self.x + 1})
1764 c = C(10)
1765 self.assertEqual((c.x, c.y), (10, 5))
1766 self.assertEqual(c.add_one(), 11)
1767
1768
1769 def test_helper_make_dataclass_no_mutate_namespace(self):
1770 # Make sure a provided namespace isn't mutated.
1771 ns = {}
1772 C = make_dataclass('C',
1773 [('x', int),
1774 ('y', int, field(default=5))],
1775 namespace=ns)
1776 self.assertEqual(ns, {})
1777
1778 def test_helper_make_dataclass_base(self):
1779 class Base1:
1780 pass
1781 class Base2:
1782 pass
1783 C = make_dataclass('C',
1784 [('x', int)],
1785 bases=(Base1, Base2))
1786 c = C(2)
1787 self.assertIsInstance(c, C)
1788 self.assertIsInstance(c, Base1)
1789 self.assertIsInstance(c, Base2)
1790
1791 def test_helper_make_dataclass_base_dataclass(self):
1792 @dataclass
1793 class Base1:
1794 x: int
1795 class Base2:
1796 pass
1797 C = make_dataclass('C',
1798 [('y', int)],
1799 bases=(Base1, Base2))
1800 with self.assertRaisesRegex(TypeError, 'required positional'):
1801 c = C(2)
1802 c = C(1, 2)
1803 self.assertIsInstance(c, C)
1804 self.assertIsInstance(c, Base1)
1805 self.assertIsInstance(c, Base2)
1806
1807 self.assertEqual((c.x, c.y), (1, 2))
1808
1809 def test_helper_make_dataclass_init_var(self):
1810 def post_init(self, y):
1811 self.x *= y
1812
1813 C = make_dataclass('C',
1814 [('x', int),
1815 ('y', InitVar[int]),
1816 ],
1817 namespace={'__post_init__': post_init},
1818 )
1819 c = C(2, 3)
1820 self.assertEqual(vars(c), {'x': 6})
1821 self.assertEqual(len(fields(c)), 1)
1822
1823 def test_helper_make_dataclass_class_var(self):
1824 C = make_dataclass('C',
1825 [('x', int),
1826 ('y', ClassVar[int], 10),
1827 ('z', ClassVar[int], field(default=20)),
1828 ])
1829 c = C(1)
1830 self.assertEqual(vars(c), {'x': 1})
1831 self.assertEqual(len(fields(c)), 1)
1832 self.assertEqual(C.y, 10)
1833 self.assertEqual(C.z, 20)
1834
Eric V. Smithd80b4432018-01-06 17:09:58 -05001835 def test_helper_make_dataclass_other_params(self):
1836 C = make_dataclass('C',
1837 [('x', int),
1838 ('y', ClassVar[int], 10),
1839 ('z', ClassVar[int], field(default=20)),
1840 ],
1841 init=False)
1842 # Make sure we have a repr, but no init.
1843 self.assertNotIn('__init__', vars(C))
1844 self.assertIn('__repr__', vars(C))
1845
1846 # Make sure random other params don't work.
1847 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1848 C = make_dataclass('C',
1849 [],
1850 xxinit=False)
1851
Eric V. Smithed7d4292018-01-06 16:14:03 -05001852 def test_helper_make_dataclass_no_types(self):
1853 C = make_dataclass('Point', ['x', 'y', 'z'])
1854 c = C(1, 2, 3)
1855 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1856 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1857 'y': 'typing.Any',
1858 'z': 'typing.Any'})
1859
1860 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1861 c = C(1, 2, 3)
1862 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1863 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1864 'y': int,
1865 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001866
Eric V. Smithea8fc522018-01-27 19:07:40 -05001867
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001868class TestDocString(unittest.TestCase):
1869 def assertDocStrEqual(self, a, b):
1870 # Because 3.6 and 3.7 differ in how inspect.signature work
1871 # (see bpo #32108), for the time being just compare them with
1872 # whitespace stripped.
1873 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1874
1875 def test_existing_docstring_not_overridden(self):
1876 @dataclass
1877 class C:
1878 """Lorem ipsum"""
1879 x: int
1880
1881 self.assertEqual(C.__doc__, "Lorem ipsum")
1882
1883 def test_docstring_no_fields(self):
1884 @dataclass
1885 class C:
1886 pass
1887
1888 self.assertDocStrEqual(C.__doc__, "C()")
1889
1890 def test_docstring_one_field(self):
1891 @dataclass
1892 class C:
1893 x: int
1894
1895 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1896
1897 def test_docstring_two_fields(self):
1898 @dataclass
1899 class C:
1900 x: int
1901 y: int
1902
1903 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1904
1905 def test_docstring_three_fields(self):
1906 @dataclass
1907 class C:
1908 x: int
1909 y: int
1910 z: str
1911
1912 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1913
1914 def test_docstring_one_field_with_default(self):
1915 @dataclass
1916 class C:
1917 x: int = 3
1918
1919 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1920
1921 def test_docstring_one_field_with_default_none(self):
1922 @dataclass
1923 class C:
1924 x: Union[int, type(None)] = None
1925
1926 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1927
1928 def test_docstring_list_field(self):
1929 @dataclass
1930 class C:
1931 x: List[int]
1932
1933 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1934
1935 def test_docstring_list_field_with_default_factory(self):
1936 @dataclass
1937 class C:
1938 x: List[int] = field(default_factory=list)
1939
1940 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1941
1942 def test_docstring_deque_field(self):
1943 @dataclass
1944 class C:
1945 x: deque
1946
1947 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1948
1949 def test_docstring_deque_field_with_default_factory(self):
1950 @dataclass
1951 class C:
1952 x: deque = field(default_factory=deque)
1953
1954 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1955
1956
class TestInit(unittest.TestCase):
    """Checks how @dataclass interacts with __init__ methods."""

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise.  Our class doesn't define
        # __init__ itself, so it's fine to generate one even though the
        # base has an __init__; the generated one does not call B.__init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False nothing is generated, so the base __init__ is
        # inherited and runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied with "@".  A bare
        # `dataclass(init=False)` statement discards its result, so the
        # checks below would only exercise an ordinary class.
        @dataclass(init=False)
        class C:
            i: int = 0
        # No __init__ is generated, so the class-level default is visible.
        self.assertEqual(C().i, 0)

        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2021
2022
class TestRepr(unittest.TestCase):
    """Checks the generated __repr__ and how repr= interacts with it."""

    def test_repr(self):
        # The generated __repr__ uses the class __qualname__ and lists every
        # field, including inherited and overridden ones.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Dataclasses nested in another class report their full qualname.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False: object's default
        # repr is used.  The module prefix comes from __name__ so the test
        # also passes when this file is run as __main__ or imported under a
        # different module name.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2092
2093
class TestFrozen(unittest.TestCase):
    """Checks that frozen=True refuses to clobber user attribute hooks."""

    def test_overwriting_frozen(self):
        # frozen=True works by generating __setattr__ and __delattr__, so a
        # class that already defines either one must be rejected.
        for method in ('__setattr__', '__delattr__'):
            namespace = {'__annotations__': {'x': int},
                         method: lambda self: None}
            with self.assertRaisesRegex(TypeError,
                                        f'Cannot overwrite attribute {method}'):
                dataclass(frozen=True)(type('C', (), namespace))

        # Without frozen=True a user-defined __setattr__ is kept, and the
        # generated __init__ goes through it.
        @dataclass(frozen=False)
        class C:
            x: int

            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        doubled = C(10)
        self.assertEqual(doubled.x, 20)
2119
2120
class TestEq(unittest.TestCase):
    """Checks the generated __eq__ and how eq= interacts with it."""

    def test_no_eq(self):
        # eq=False and no __eq__ of our own: comparison is by identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int

            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ always wins, whatever eq= is set to
        # (including the default).
        for target, options in ((3, {}), (4, {'eq': True}), (5, {'eq': False})):
            @dataclass(**options)
            class C:
                x: int

                def __eq__(self, other):
                    return other == target

            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2166
2167
class TestOrdering(unittest.TestCase):
    """Checks the ordering methods and how order= interacts with them."""

    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to complete the ordering of
        # a dataclass that defines only __lt__.
        @total_ordering
        @dataclass
        class C:
            x: int

            def __lt__(self, other):
                # Deliberately "backward" so that we can tell this method
                # is the one being used.
                return self.x >= other

        zero = C(0)
        self.assertLess(zero, -1)
        self.assertLessEqual(zero, -1)
        self.assertGreater(zero, 1)
        self.assertGreaterEqual(zero, 1)

    def test_no_order(self):
        # order=False adds no comparison methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user __lt__ is kept, and the other three are still not added.
        @dataclass(order=False)
        class C:
            x: int

            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True generates all four comparisons, so a class that already
        # defines any of them is rejected with a hint to use
        # functools.total_ordering instead.
        for method in ('__lt__', '__le__', '__gt__', '__ge__'):
            namespace = {'__annotations__': {'x': int},
                         method: lambda self: None}
            with self.assertRaisesRegex(TypeError,
                                        f'Cannot overwrite attribute {method}'
                                        '.*using functools.total_ordering'):
                dataclass(order=True)(type('C', (), namespace))
2243
2244class TestHash(unittest.TestCase):
2245 def test_hash(self):
2246 @dataclass(hash=True)
2247 class C:
2248 x: int
2249 y: str
2250 self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))
2251
2252 def test_hash_false(self):
2253 @dataclass(hash=False)
2254 class C:
2255 x: int
2256 y: str
2257 self.assertNotEqual(hash(C(1, 'foo')), hash((1, 'foo')))
2258
2259 def test_hash_none(self):
2260 @dataclass(hash=None)
2261 class C:
2262 x: int
2263 with self.assertRaisesRegex(TypeError,
2264 "unhashable type: 'C'"):
2265 hash(C(1))
2266
2267 def test_hash_rules(self):
2268 def non_bool(value):
2269 # Map to something else that's True, but not a bool.
2270 if value is None:
2271 return None
2272 if value:
2273 return (3,)
2274 return 0
2275
2276 def test(case, hash, eq, frozen, with_hash, result):
2277 with self.subTest(case=case, hash=hash, eq=eq, frozen=frozen):
2278 if with_hash:
2279 @dataclass(hash=hash, eq=eq, frozen=frozen)
2280 class C:
2281 def __hash__(self):
2282 return 0
2283 else:
2284 @dataclass(hash=hash, eq=eq, frozen=frozen)
2285 class C:
2286 pass
2287
2288 # See if the result matches what's expected.
2289 if result in ('fn', 'fn-x'):
2290 # __hash__ contains the function we generated.
2291 self.assertIn('__hash__', C.__dict__)
2292 self.assertIsNotNone(C.__dict__['__hash__'])
2293
2294 if result == 'fn-x':
2295 # This is the "auto-hash test" case. We
2296 # should overwrite __hash__ iff there's an
2297 # __eq__ and if __hash__=None.
2298
2299 # There are two ways of getting __hash__=None:
2300 # explicitely, and by defining __eq__. If
2301 # __eq__ is defined, python will add __hash__
2302 # when the class is created.
2303 @dataclass(hash=hash, eq=eq, frozen=frozen)
2304 class C:
2305 def __eq__(self, other): pass
2306 __hash__ = None
2307
2308 # Hash should be overwritten (non-None).
2309 self.assertIsNotNone(C.__dict__['__hash__'])
2310
2311 # Same test as above, but we don't provide
2312 # __hash__, it will implicitely set to None.
2313 @dataclass(hash=hash, eq=eq, frozen=frozen)
2314 class C:
2315 def __eq__(self, other): pass
2316
2317 # Hash should be overwritten (non-None).
2318 self.assertIsNotNone(C.__dict__['__hash__'])
2319
2320 elif result == '':
2321 # __hash__ is not present in our class.
2322 if not with_hash:
2323 self.assertNotIn('__hash__', C.__dict__)
2324 elif result == 'none':
2325 # __hash__ is set to None.
2326 self.assertIn('__hash__', C.__dict__)
2327 self.assertIsNone(C.__dict__['__hash__'])
2328 else:
2329 assert False, f'unknown result {result!r}'
2330
2331 # There are 12 cases of:
2332 # hash=True/False/None
2333 # eq=True/False
2334 # frozen=True/False
2335 # And for each of these, a different result if
2336 # __hash__ is defined or not.
2337 for case, (hash, eq, frozen, result_no, result_yes) in enumerate([
2338 (None, False, False, '', ''),
2339 (None, False, True, '', ''),
2340 (None, True, False, 'none', ''),
2341 (None, True, True, 'fn', 'fn-x'),
2342 (False, False, False, '', ''),
2343 (False, False, True, '', ''),
2344 (False, True, False, '', ''),
2345 (False, True, True, '', ''),
2346 (True, False, False, 'fn', 'fn-x'),
2347 (True, False, True, 'fn', 'fn-x'),
2348 (True, True, False, 'fn', 'fn-x'),
2349 (True, True, True, 'fn', 'fn-x'),
2350 ], 1):
2351 test(case, hash, eq, frozen, False, result_no)
2352 test(case, hash, eq, frozen, True, result_yes)
2353
2354 # Test non-bool truth values, too. This is just to
2355 # make sure the data-driven table in the decorator
2356 # handles non-bool values.
2357 test(case, non_bool(hash), non_bool(eq), non_bool(frozen), False, result_no)
2358 test(case, non_bool(hash), non_bool(eq), non_bool(frozen), True, result_yes)
2359
2360
2361 def test_eq_only(self):
2362 # If a class defines __eq__, __hash__ is automatically added
2363 # and set to None. This is normal Python behavior, not
2364 # related to dataclasses. Make sure we don't interfere with
2365 # that (see bpo=32546).
2366
2367 @dataclass
2368 class C:
2369 i: int
2370 def __eq__(self, other):
2371 return self.i == other.i
2372 self.assertEqual(C(1), C(1))
2373 self.assertNotEqual(C(1), C(4))
2374
2375 # And make sure things work in this case if we specify
2376 # hash=True.
2377 @dataclass(hash=True)
2378 class C:
2379 i: int
2380 def __eq__(self, other):
2381 return self.i == other.i
2382 self.assertEqual(C(1), C(1.0))
2383 self.assertEqual(hash(C(1)), hash(C(1.0)))
2384
2385 # And check that the classes __eq__ is being used, despite
2386 # specifying eq=True.
2387 @dataclass(hash=True, eq=True)
2388 class C:
2389 i: int
2390 def __eq__(self, other):
2391 return self.i == 3 and self.i == other.i
2392 self.assertEqual(C(3), C(3))
2393 self.assertNotEqual(C(1), C(1))
2394 self.assertEqual(hash(C(1)), hash(C(1.0)))
2395
2396
if __name__ == "__main__":
    # Run directly as a script: hand control to unittest's command-line
    # runner, which loads and runs the test cases in this module.
    unittest.main()