blob: 9b5aad25745f7ff2d83f2854b9e21dc90fcfb60b [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
Eric V. Smithe7ba0132018-01-06 12:41:53 -05003 make_dataclass, replace, InitVar, Field, MISSING, is_dataclass,
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050012from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013
14# Just any custom exception we can catch.
class CustomError(Exception):
    """A distinct exception type that tests can raise and catch unambiguously."""
16
17class TestCase(unittest.TestCase):
18 def test_no_fields(self):
19 @dataclass
20 class C:
21 pass
22
23 o = C()
24 self.assertEqual(len(fields(C)), 0)
25
26 def test_one_field_no_default(self):
27 @dataclass
28 class C:
29 x: int
30
31 o = C(42)
32 self.assertEqual(o.x, 42)
33
34 def test_named_init_params(self):
35 @dataclass
36 class C:
37 x: int
38
39 o = C(x=32)
40 self.assertEqual(o.x, 32)
41
42 def test_two_fields_one_default(self):
43 @dataclass
44 class C:
45 x: int
46 y: int = 0
47
48 o = C(3)
49 self.assertEqual((o.x, o.y), (3, 0))
50
51 # Non-defaults following defaults.
52 with self.assertRaisesRegex(TypeError,
53 "non-default argument 'y' follows "
54 "default argument"):
55 @dataclass
56 class C:
57 x: int = 0
58 y: int
59
60 # A derived class adds a non-default field after a default one.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class B:
66 x: int = 0
67
68 @dataclass
69 class C(B):
70 y: int
71
72 # Override a base class field and add a default to
73 # a field which didn't use to have a default.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int
80 y: int
81
82 @dataclass
83 class C(B):
84 x: int = 0
85
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080086 def test_overwrite_hash(self):
87 # Test that declaring this class isn't an error. It should
88 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050089 @dataclass(frozen=True)
90 class C:
91 x: int
92 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080093 return 301
94 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -050095
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080096 # Test that declaring this class isn't an error. It should
97 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800101 def __eq__(self, other):
102 return False
103 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # But this one should generate an exception, because with
106 # unsafe_hash=True, it's an error to have a __hash__ defined.
107 with self.assertRaisesRegex(TypeError,
108 'Cannot overwrite attribute __hash__'):
109 @dataclass(unsafe_hash=True)
110 class C:
111 def __hash__(self):
112 pass
113
114 # Creating this class should not generate an exception,
115 # because even though __hash__ exists before @dataclass is
116 # called, (due to __eq__ being defined), since it's None
117 # that's okay.
118 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500119 class C:
120 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800121 def __eq__(self):
122 pass
123 # The generated hash function works as we'd expect.
124 self.assertEqual(hash(C(10)), hash((10,)))
125
126 # Creating this class should generate an exception, because
127 # __hash__ exists and is not None, which it would be if it had
128 # been auto-generated do due __eq__ being defined.
129 with self.assertRaisesRegex(TypeError,
130 'Cannot overwrite attribute __hash__'):
131 @dataclass(unsafe_hash=True)
132 class C:
133 x: int
134 def __eq__(self):
135 pass
136 def __hash__(self):
137 pass
138
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500139
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500140 def test_overwrite_fields_in_derived_class(self):
141 # Note that x from C1 replaces x in Base, but the order remains
142 # the same as defined in Base.
143 @dataclass
144 class Base:
145 x: Any = 15.0
146 y: int = 0
147
148 @dataclass
149 class C1(Base):
150 z: int = 10
151 x: int = 15
152
153 o = Base()
154 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
155
156 o = C1()
157 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
158
159 o = C1(x=5)
160 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
161
162 def test_field_named_self(self):
163 @dataclass
164 class C:
165 self: str
166 c=C('foo')
167 self.assertEqual(c.self, 'foo')
168
169 # Make sure the first parameter is not named 'self'.
170 sig = inspect.signature(C.__init__)
171 first = next(iter(sig.parameters))
172 self.assertNotEqual('self', first)
173
174 # But we do use 'self' if no field named self.
175 @dataclass
176 class C:
177 selfx: str
178
179 # Make sure the first parameter is named 'self'.
180 sig = inspect.signature(C.__init__)
181 first = next(iter(sig.parameters))
182 self.assertEqual('self', first)
183
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500184 def test_0_field_compare(self):
185 # Ensure that order=False is the default.
186 @dataclass
187 class C0:
188 pass
189
190 @dataclass(order=False)
191 class C1:
192 pass
193
194 for cls in [C0, C1]:
195 with self.subTest(cls=cls):
196 self.assertEqual(cls(), cls())
197 for idx, fn in enumerate([lambda a, b: a < b,
198 lambda a, b: a <= b,
199 lambda a, b: a > b,
200 lambda a, b: a >= b]):
201 with self.subTest(idx=idx):
202 with self.assertRaisesRegex(TypeError,
203 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
204 fn(cls(), cls())
205
206 @dataclass(order=True)
207 class C:
208 pass
209 self.assertLessEqual(C(), C())
210 self.assertGreaterEqual(C(), C())
211
212 def test_1_field_compare(self):
213 # Ensure that order=False is the default.
214 @dataclass
215 class C0:
216 x: int
217
218 @dataclass(order=False)
219 class C1:
220 x: int
221
222 for cls in [C0, C1]:
223 with self.subTest(cls=cls):
224 self.assertEqual(cls(1), cls(1))
225 self.assertNotEqual(cls(0), cls(1))
226 for idx, fn in enumerate([lambda a, b: a < b,
227 lambda a, b: a <= b,
228 lambda a, b: a > b,
229 lambda a, b: a >= b]):
230 with self.subTest(idx=idx):
231 with self.assertRaisesRegex(TypeError,
232 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
233 fn(cls(0), cls(0))
234
235 @dataclass(order=True)
236 class C:
237 x: int
238 self.assertLess(C(0), C(1))
239 self.assertLessEqual(C(0), C(1))
240 self.assertLessEqual(C(1), C(1))
241 self.assertGreater(C(1), C(0))
242 self.assertGreaterEqual(C(1), C(0))
243 self.assertGreaterEqual(C(1), C(1))
244
245 def test_simple_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 x: int
250 y: int
251
252 @dataclass(order=False)
253 class C1:
254 x: int
255 y: int
256
257 for cls in [C0, C1]:
258 with self.subTest(cls=cls):
259 self.assertEqual(cls(0, 0), cls(0, 0))
260 self.assertEqual(cls(1, 2), cls(1, 2))
261 self.assertNotEqual(cls(1, 0), cls(0, 0))
262 self.assertNotEqual(cls(1, 0), cls(1, 1))
263 for idx, fn in enumerate([lambda a, b: a < b,
264 lambda a, b: a <= b,
265 lambda a, b: a > b,
266 lambda a, b: a >= b]):
267 with self.subTest(idx=idx):
268 with self.assertRaisesRegex(TypeError,
269 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
270 fn(cls(0, 0), cls(0, 0))
271
272 @dataclass(order=True)
273 class C:
274 x: int
275 y: int
276
277 for idx, fn in enumerate([lambda a, b: a == b,
278 lambda a, b: a <= b,
279 lambda a, b: a >= b]):
280 with self.subTest(idx=idx):
281 self.assertTrue(fn(C(0, 0), C(0, 0)))
282
283 for idx, fn in enumerate([lambda a, b: a < b,
284 lambda a, b: a <= b,
285 lambda a, b: a != b]):
286 with self.subTest(idx=idx):
287 self.assertTrue(fn(C(0, 0), C(0, 1)))
288 self.assertTrue(fn(C(0, 1), C(1, 0)))
289 self.assertTrue(fn(C(1, 0), C(1, 1)))
290
291 for idx, fn in enumerate([lambda a, b: a > b,
292 lambda a, b: a >= b,
293 lambda a, b: a != b]):
294 with self.subTest(idx=idx):
295 self.assertTrue(fn(C(0, 1), C(0, 0)))
296 self.assertTrue(fn(C(1, 0), C(0, 1)))
297 self.assertTrue(fn(C(1, 1), C(1, 0)))
298
299 def test_compare_subclasses(self):
300 # Comparisons fail for subclasses, even if no fields
301 # are added.
302 @dataclass
303 class B:
304 i: int
305
306 @dataclass
307 class C(B):
308 pass
309
310 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
311 (lambda a, b: a != b, True)]):
312 with self.subTest(idx=idx):
313 self.assertEqual(fn(B(0), C(0)), expected)
314
315 for idx, fn in enumerate([lambda a, b: a < b,
316 lambda a, b: a <= b,
317 lambda a, b: a > b,
318 lambda a, b: a >= b]):
319 with self.subTest(idx=idx):
320 with self.assertRaisesRegex(TypeError,
321 "not supported between instances of 'B' and 'C'"):
322 fn(B(0), C(0))
323
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500324 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500325 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500326 for (eq, order, result ) in [
327 (False, False, 'neither'),
328 (False, True, 'exception'),
329 (True, False, 'eq_only'),
330 (True, True, 'both'),
331 ]:
332 with self.subTest(eq=eq, order=order):
333 if result == 'exception':
334 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
335 @dataclass(eq=eq, order=order)
336 class C:
337 pass
338 else:
339 @dataclass(eq=eq, order=order)
340 class C:
341 pass
342
343 if result == 'neither':
344 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500345 self.assertNotIn('__lt__', C.__dict__)
346 self.assertNotIn('__le__', C.__dict__)
347 self.assertNotIn('__gt__', C.__dict__)
348 self.assertNotIn('__ge__', C.__dict__)
349 elif result == 'both':
350 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500351 self.assertIn('__lt__', C.__dict__)
352 self.assertIn('__le__', C.__dict__)
353 self.assertIn('__gt__', C.__dict__)
354 self.assertIn('__ge__', C.__dict__)
355 elif result == 'eq_only':
356 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500357 self.assertNotIn('__lt__', C.__dict__)
358 self.assertNotIn('__le__', C.__dict__)
359 self.assertNotIn('__gt__', C.__dict__)
360 self.assertNotIn('__ge__', C.__dict__)
361 else:
362 assert False, f'unknown result {result!r}'
363
364 def test_field_no_default(self):
365 @dataclass
366 class C:
367 x: int = field()
368
369 self.assertEqual(C(5).x, 5)
370
371 with self.assertRaisesRegex(TypeError,
372 r"__init__\(\) missing 1 required "
373 "positional argument: 'x'"):
374 C()
375
376 def test_field_default(self):
377 default = object()
378 @dataclass
379 class C:
380 x: object = field(default=default)
381
382 self.assertIs(C.x, default)
383 c = C(10)
384 self.assertEqual(c.x, 10)
385
386 # If we delete the instance attribute, we should then see the
387 # class attribute.
388 del c.x
389 self.assertIs(c.x, default)
390
391 self.assertIs(C().x, default)
392
393 def test_not_in_repr(self):
394 @dataclass
395 class C:
396 x: int = field(repr=False)
397 with self.assertRaises(TypeError):
398 C()
399 c = C(10)
400 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
401
402 @dataclass
403 class C:
404 x: int = field(repr=False)
405 y: int
406 c = C(10, 20)
407 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
408
409 def test_not_in_compare(self):
410 @dataclass
411 class C:
412 x: int = 0
413 y: int = field(compare=False, default=4)
414
415 self.assertEqual(C(), C(0, 20))
416 self.assertEqual(C(1, 10), C(1, 20))
417 self.assertNotEqual(C(3), C(4, 10))
418 self.assertNotEqual(C(3, 10), C(4, 10))
419
420 def test_hash_field_rules(self):
421 # Test all 6 cases of:
422 # hash=True/False/None
423 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800424 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500425 (True, False, 'field' ),
426 (True, True, 'field' ),
427 (False, False, 'absent'),
428 (False, True, 'absent'),
429 (None, False, 'absent'),
430 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800431 ]:
432 with self.subTest(hash=hash_, compare=compare):
433 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800435 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500436
437 if result == 'field':
438 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800439 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500440 elif result == 'absent':
441 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800442 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500443 else:
444 assert False, f'unknown result {result!r}'
445
446 def test_init_false_no_default(self):
447 # If init=False and no default value, then the field won't be
448 # present in the instance.
449 @dataclass
450 class C:
451 x: int = field(init=False)
452
453 self.assertNotIn('x', C().__dict__)
454
455 @dataclass
456 class C:
457 x: int
458 y: int = 0
459 z: int = field(init=False)
460 t: int = 10
461
462 self.assertNotIn('z', C(0).__dict__)
463 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
464
465 def test_class_marker(self):
466 @dataclass
467 class C:
468 x: int
469 y: str = field(init=False, default=None)
470 z: str = field(repr=False)
471
472 the_fields = fields(C)
473 # the_fields is a tuple of 3 items, each value
474 # is in __annotations__.
475 self.assertIsInstance(the_fields, tuple)
476 for f in the_fields:
477 self.assertIs(type(f), Field)
478 self.assertIn(f.name, C.__annotations__)
479
480 self.assertEqual(len(the_fields), 3)
481
482 self.assertEqual(the_fields[0].name, 'x')
483 self.assertEqual(the_fields[0].type, int)
484 self.assertFalse(hasattr(C, 'x'))
485 self.assertTrue (the_fields[0].init)
486 self.assertTrue (the_fields[0].repr)
487 self.assertEqual(the_fields[1].name, 'y')
488 self.assertEqual(the_fields[1].type, str)
489 self.assertIsNone(getattr(C, 'y'))
490 self.assertFalse(the_fields[1].init)
491 self.assertTrue (the_fields[1].repr)
492 self.assertEqual(the_fields[2].name, 'z')
493 self.assertEqual(the_fields[2].type, str)
494 self.assertFalse(hasattr(C, 'z'))
495 self.assertTrue (the_fields[2].init)
496 self.assertFalse(the_fields[2].repr)
497
498 def test_field_order(self):
499 @dataclass
500 class B:
501 a: str = 'B:a'
502 b: str = 'B:b'
503 c: str = 'B:c'
504
505 @dataclass
506 class C(B):
507 b: str = 'C:b'
508
509 self.assertEqual([(f.name, f.default) for f in fields(C)],
510 [('a', 'B:a'),
511 ('b', 'C:b'),
512 ('c', 'B:c')])
513
514 @dataclass
515 class D(B):
516 c: str = 'D:c'
517
518 self.assertEqual([(f.name, f.default) for f in fields(D)],
519 [('a', 'B:a'),
520 ('b', 'B:b'),
521 ('c', 'D:c')])
522
523 @dataclass
524 class E(D):
525 a: str = 'E:a'
526 d: str = 'E:d'
527
528 self.assertEqual([(f.name, f.default) for f in fields(E)],
529 [('a', 'E:a'),
530 ('b', 'B:b'),
531 ('c', 'D:c'),
532 ('d', 'E:d')])
533
534 def test_class_attrs(self):
535 # We only have a class attribute if a default value is
536 # specified, either directly or via a field with a default.
537 default = object()
538 @dataclass
539 class C:
540 x: int
541 y: int = field(repr=False)
542 z: object = default
543 t: int = field(default=100)
544
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertFalse(hasattr(C, 'y'))
547 self.assertIs (C.z, default)
548 self.assertEqual(C.t, 100)
549
550 def test_disallowed_mutable_defaults(self):
551 # For the known types, don't allow mutable default values.
552 for typ, empty, non_empty in [(list, [], [1]),
553 (dict, {}, {0:1}),
554 (set, set(), set([1])),
555 ]:
556 with self.subTest(typ=typ):
557 # Can't use a zero-length value.
558 with self.assertRaisesRegex(ValueError,
559 f'mutable default {typ} for field '
560 'x is not allowed'):
561 @dataclass
562 class Point:
563 x: typ = empty
564
565
566 # Nor a non-zero-length value
567 with self.assertRaisesRegex(ValueError,
568 f'mutable default {typ} for field '
569 'y is not allowed'):
570 @dataclass
571 class Point:
572 y: typ = non_empty
573
574 # Check subtypes also fail.
575 class Subclass(typ): pass
576
577 with self.assertRaisesRegex(ValueError,
578 f"mutable default .*Subclass'>"
579 ' for field z is not allowed'
580 ):
581 @dataclass
582 class Point:
583 z: typ = Subclass()
584
585 # Because this is a ClassVar, it can be mutable.
586 @dataclass
587 class C:
588 z: ClassVar[typ] = typ()
589
590 # Because this is a ClassVar, it can be mutable.
591 @dataclass
592 class C:
593 x: ClassVar[typ] = Subclass()
594
595
596 def test_deliberately_mutable_defaults(self):
597 # If a mutable default isn't in the known list of
598 # (list, dict, set), then it's okay.
599 class Mutable:
600 def __init__(self):
601 self.l = []
602
603 @dataclass
604 class C:
605 x: Mutable
606
607 # These 2 instances will share this value of x.
608 lst = Mutable()
609 o1 = C(lst)
610 o2 = C(lst)
611 self.assertEqual(o1, o2)
612 o1.x.l.extend([1, 2])
613 self.assertEqual(o1, o2)
614 self.assertEqual(o1.x.l, [1, 2])
615 self.assertIs(o1.x, o2.x)
616
617 def test_no_options(self):
618 # call with dataclass()
619 @dataclass()
620 class C:
621 x: int
622
623 self.assertEqual(C(42).x, 42)
624
625 def test_not_tuple(self):
626 # Make sure we can't be compared to a tuple.
627 @dataclass
628 class Point:
629 x: int
630 y: int
631 self.assertNotEqual(Point(1, 2), (1, 2))
632
633 # And that we can't compare to another unrelated dataclass
634 @dataclass
635 class C:
636 x: int
637 y: int
638 self.assertNotEqual(Point(1, 3), C(1, 3))
639
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500640 def test_not_tuple(self):
641 # Test that some of the problems with namedtuple don't happen
642 # here.
643 @dataclass
644 class Point3D:
645 x: int
646 y: int
647 z: int
648
649 @dataclass
650 class Date:
651 year: int
652 month: int
653 day: int
654
655 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
656 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
657
658 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200659 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500660 x, y, z = Point3D(4, 5, 6)
661
Eric V. Smith7c99e932018-01-28 19:18:55 -0500662 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500663 # equal.
664 @dataclass
665 class Point3Dv1:
666 x: int = 0
667 y: int = 0
668 z: int = 0
669 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
670
671 def test_function_annotations(self):
672 # Some dummy class and instance to use as a default.
673 class F:
674 pass
675 f = F()
676
677 def validate_class(cls):
678 # First, check __annotations__, even though they're not
679 # function annotations.
680 self.assertEqual(cls.__annotations__['i'], int)
681 self.assertEqual(cls.__annotations__['j'], str)
682 self.assertEqual(cls.__annotations__['k'], F)
683 self.assertEqual(cls.__annotations__['l'], float)
684 self.assertEqual(cls.__annotations__['z'], complex)
685
686 # Verify __init__.
687
688 signature = inspect.signature(cls.__init__)
689 # Check the return type, should be None
690 self.assertIs(signature.return_annotation, None)
691
692 # Check each parameter.
693 params = iter(signature.parameters.values())
694 param = next(params)
695 # This is testing an internal name, and probably shouldn't be tested.
696 self.assertEqual(param.name, 'self')
697 param = next(params)
698 self.assertEqual(param.name, 'i')
699 self.assertIs (param.annotation, int)
700 self.assertEqual(param.default, inspect.Parameter.empty)
701 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
702 param = next(params)
703 self.assertEqual(param.name, 'j')
704 self.assertIs (param.annotation, str)
705 self.assertEqual(param.default, inspect.Parameter.empty)
706 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
707 param = next(params)
708 self.assertEqual(param.name, 'k')
709 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500710 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500711 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
712 param = next(params)
713 self.assertEqual(param.name, 'l')
714 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500715 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500716 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
717 self.assertRaises(StopIteration, next, params)
718
719
720 @dataclass
721 class C:
722 i: int
723 j: str
724 k: F = f
725 l: float=field(default=None)
726 z: complex=field(default=3+4j, init=False)
727
728 validate_class(C)
729
730 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800731 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500732 class C:
733 i: int
734 j: str
735 k: F = f
736 l: float=field(default=None)
737 z: complex=field(default=3+4j, init=False)
738
739 validate_class(C)
740
Eric V. Smith03220fd2017-12-29 13:59:58 -0500741 def test_missing_default(self):
742 # Test that MISSING works the same as a default not being
743 # specified.
744 @dataclass
745 class C:
746 x: int=field(default=MISSING)
747 with self.assertRaisesRegex(TypeError,
748 r'__init__\(\) missing 1 required '
749 'positional argument'):
750 C()
751 self.assertNotIn('x', C.__dict__)
752
753 @dataclass
754 class D:
755 x: int
756 with self.assertRaisesRegex(TypeError,
757 r'__init__\(\) missing 1 required '
758 'positional argument'):
759 D()
760 self.assertNotIn('x', D.__dict__)
761
762 def test_missing_default_factory(self):
763 # Test that MISSING works the same as a default factory not
764 # being specified (which is really the same as a default not
765 # being specified, too).
766 @dataclass
767 class C:
768 x: int=field(default_factory=MISSING)
769 with self.assertRaisesRegex(TypeError,
770 r'__init__\(\) missing 1 required '
771 'positional argument'):
772 C()
773 self.assertNotIn('x', C.__dict__)
774
775 @dataclass
776 class D:
777 x: int=field(default=MISSING, default_factory=MISSING)
778 with self.assertRaisesRegex(TypeError,
779 r'__init__\(\) missing 1 required '
780 'positional argument'):
781 D()
782 self.assertNotIn('x', D.__dict__)
783
784 def test_missing_repr(self):
785 self.assertIn('MISSING_TYPE object', repr(MISSING))
786
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500787 def test_dont_include_other_annotations(self):
788 @dataclass
789 class C:
790 i: int
791 def foo(self) -> int:
792 return 4
793 @property
794 def bar(self) -> int:
795 return 5
796 self.assertEqual(list(C.__annotations__), ['i'])
797 self.assertEqual(C(10).foo(), 4)
798 self.assertEqual(C(10).bar, 5)
799
800 def test_post_init(self):
801 # Just make sure it gets called
802 @dataclass
803 class C:
804 def __post_init__(self):
805 raise CustomError()
806 with self.assertRaises(CustomError):
807 C()
808
809 @dataclass
810 class C:
811 i: int = 10
812 def __post_init__(self):
813 if self.i == 10:
814 raise CustomError()
815 with self.assertRaises(CustomError):
816 C()
817 # post-init gets called, but doesn't raise. This is just
818 # checking that self is used correctly.
819 C(5)
820
821 # If there's not an __init__, then post-init won't get called.
822 @dataclass(init=False)
823 class C:
824 def __post_init__(self):
825 raise CustomError()
826 # Creating the class won't raise
827 C()
828
829 @dataclass
830 class C:
831 x: int = 0
832 def __post_init__(self):
833 self.x *= 2
834 self.assertEqual(C().x, 0)
835 self.assertEqual(C(2).x, 4)
836
Mike53f7a7c2017-12-14 14:04:53 +0300837 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500838 # attributes.
839 @dataclass(frozen=True)
840 class C:
841 x: int = 0
842 def __post_init__(self):
843 self.x *= 2
844 with self.assertRaises(FrozenInstanceError):
845 C()
846
847 def test_post_init_super(self):
848 # Make sure super() post-init isn't called by default.
849 class B:
850 def __post_init__(self):
851 raise CustomError()
852
853 @dataclass
854 class C(B):
855 def __post_init__(self):
856 self.x = 5
857
858 self.assertEqual(C().x, 5)
859
860 # Now call super(), and it will raise
861 @dataclass
862 class C(B):
863 def __post_init__(self):
864 super().__post_init__()
865
866 with self.assertRaises(CustomError):
867 C()
868
869 # Make sure post-init is called, even if not defined in our
870 # class.
871 @dataclass
872 class C(B):
873 pass
874
875 with self.assertRaises(CustomError):
876 C()
877
878 def test_post_init_staticmethod(self):
879 flag = False
880 @dataclass
881 class C:
882 x: int
883 y: int
884 @staticmethod
885 def __post_init__():
886 nonlocal flag
887 flag = True
888
889 self.assertFalse(flag)
890 c = C(3, 4)
891 self.assertEqual((c.x, c.y), (3, 4))
892 self.assertTrue(flag)
893
894 def test_post_init_classmethod(self):
895 @dataclass
896 class C:
897 flag = False
898 x: int
899 y: int
900 @classmethod
901 def __post_init__(cls):
902 cls.flag = True
903
904 self.assertFalse(C.flag)
905 c = C(3, 4)
906 self.assertEqual((c.x, c.y), (3, 4))
907 self.assertTrue(C.flag)
908
909 def test_class_var(self):
910 # Make sure ClassVars are ignored in __init__, __repr__, etc.
911 @dataclass
912 class C:
913 x: int
914 y: int = 10
915 z: ClassVar[int] = 1000
916 w: ClassVar[int] = 2000
917 t: ClassVar[int] = 3000
918
919 c = C(5)
920 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
921 self.assertEqual(len(fields(C)), 2) # We have 2 fields
922 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
923 self.assertEqual(c.z, 1000)
924 self.assertEqual(c.w, 2000)
925 self.assertEqual(c.t, 3000)
926 C.z += 1
927 self.assertEqual(c.z, 1001)
928 c = C(20)
929 self.assertEqual((c.x, c.y), (20, 10))
930 self.assertEqual(c.z, 1001)
931 self.assertEqual(c.w, 2000)
932 self.assertEqual(c.t, 3000)
933
934 def test_class_var_no_default(self):
935 # If a ClassVar has no default value, it should not be set on the class.
936 @dataclass
937 class C:
938 x: ClassVar[int]
939
940 self.assertNotIn('x', C.__dict__)
941
942 def test_class_var_default_factory(self):
943 # It makes no sense for a ClassVar to have a default factory. When
944 # would it be called? Call it yourself, since it's class-wide.
945 with self.assertRaisesRegex(TypeError,
946 'cannot have a default factory'):
947 @dataclass
948 class C:
949 x: ClassVar[int] = field(default_factory=int)
950
951 self.assertNotIn('x', C.__dict__)
952
953 def test_class_var_with_default(self):
954 # If a ClassVar has a default value, it should be set on the class.
955 @dataclass
956 class C:
957 x: ClassVar[int] = 10
958 self.assertEqual(C.x, 10)
959
960 @dataclass
961 class C:
962 x: ClassVar[int] = field(default=10)
963 self.assertEqual(C.x, 10)
964
965 def test_class_var_frozen(self):
966 # Make sure ClassVars work even if we're frozen.
967 @dataclass(frozen=True)
968 class C:
969 x: int
970 y: int = 10
971 z: ClassVar[int] = 1000
972 w: ClassVar[int] = 2000
973 t: ClassVar[int] = 3000
974
975 c = C(5)
976 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
977 self.assertEqual(len(fields(C)), 2) # We have 2 fields
978 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
979 self.assertEqual(c.z, 1000)
980 self.assertEqual(c.w, 2000)
981 self.assertEqual(c.t, 3000)
982 # We can still modify the ClassVar, it's only instances that are
983 # frozen.
984 C.z += 1
985 self.assertEqual(c.z, 1001)
986 c = C(20)
987 self.assertEqual((c.x, c.y), (20, 10))
988 self.assertEqual(c.z, 1001)
989 self.assertEqual(c.w, 2000)
990 self.assertEqual(c.t, 3000)
991
992 def test_init_var_no_default(self):
993 # If an InitVar has no default value, it should not be set on the class.
994 @dataclass
995 class C:
996 x: InitVar[int]
997
998 self.assertNotIn('x', C.__dict__)
999
1000 def test_init_var_default_factory(self):
1001 # It makes no sense for an InitVar to have a default factory. When
1002 # would it be called? Call it yourself, since it's class-wide.
1003 with self.assertRaisesRegex(TypeError,
1004 'cannot have a default factory'):
1005 @dataclass
1006 class C:
1007 x: InitVar[int] = field(default_factory=int)
1008
1009 self.assertNotIn('x', C.__dict__)
1010
1011 def test_init_var_with_default(self):
1012 # If an InitVar has a default value, it should be set on the class.
1013 @dataclass
1014 class C:
1015 x: InitVar[int] = 10
1016 self.assertEqual(C.x, 10)
1017
1018 @dataclass
1019 class C:
1020 x: InitVar[int] = field(default=10)
1021 self.assertEqual(C.x, 10)
1022
1023 def test_init_var(self):
1024 @dataclass
1025 class C:
1026 x: int = None
1027 init_param: InitVar[int] = None
1028
1029 def __post_init__(self, init_param):
1030 if self.x is None:
1031 self.x = init_param*2
1032
1033 c = C(init_param=10)
1034 self.assertEqual(c.x, 20)
1035
1036 def test_init_var_inheritance(self):
1037 # Note that this deliberately tests that a dataclass need not
1038 # have a __post_init__ function if it has an InitVar field.
1039 # It could just be used in a derived class, as shown here.
1040 @dataclass
1041 class Base:
1042 x: int
1043 init_base: InitVar[int]
1044
1045 # We can instantiate by passing the InitVar, even though
1046 # it's not used.
1047 b = Base(0, 10)
1048 self.assertEqual(vars(b), {'x': 0})
1049
1050 @dataclass
1051 class C(Base):
1052 y: int
1053 init_derived: InitVar[int]
1054
1055 def __post_init__(self, init_base, init_derived):
1056 self.x = self.x + init_base
1057 self.y = self.y + init_derived
1058
1059 c = C(10, 11, 50, 51)
1060 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1061
1062 def test_default_factory(self):
1063 # Test a factory that returns a new list.
1064 @dataclass
1065 class C:
1066 x: int
1067 y: list = field(default_factory=list)
1068
1069 c0 = C(3)
1070 c1 = C(3)
1071 self.assertEqual(c0.x, 3)
1072 self.assertEqual(c0.y, [])
1073 self.assertEqual(c0, c1)
1074 self.assertIsNot(c0.y, c1.y)
1075 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1076
1077 # Test a factory that returns a shared list.
1078 l = []
1079 @dataclass
1080 class C:
1081 x: int
1082 y: list = field(default_factory=lambda: l)
1083
1084 c0 = C(3)
1085 c1 = C(3)
1086 self.assertEqual(c0.x, 3)
1087 self.assertEqual(c0.y, [])
1088 self.assertEqual(c0, c1)
1089 self.assertIs(c0.y, c1.y)
1090 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1091
1092 # Test various other field flags.
1093 # repr
1094 @dataclass
1095 class C:
1096 x: list = field(default_factory=list, repr=False)
1097 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1098 self.assertEqual(C().x, [])
1099
1100 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001101 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001102 class C:
1103 x: list = field(default_factory=list, hash=False)
1104 self.assertEqual(astuple(C()), ([],))
1105 self.assertEqual(hash(C()), hash(()))
1106
1107 # init (see also test_default_factory_with_no_init)
1108 @dataclass
1109 class C:
1110 x: list = field(default_factory=list, init=False)
1111 self.assertEqual(astuple(C()), ([],))
1112
1113 # compare
1114 @dataclass
1115 class C:
1116 x: list = field(default_factory=list, compare=False)
1117 self.assertEqual(C(), C([1]))
1118
1119 def test_default_factory_with_no_init(self):
1120 # We need a factory with a side effect.
1121 factory = Mock()
1122
1123 @dataclass
1124 class C:
1125 x: list = field(default_factory=factory, init=False)
1126
1127 # Make sure the default factory is called for each new instance.
1128 C().x
1129 self.assertEqual(factory.call_count, 1)
1130 C().x
1131 self.assertEqual(factory.call_count, 2)
1132
1133 def test_default_factory_not_called_if_value_given(self):
1134 # We need a factory that we can test if it's been called.
1135 factory = Mock()
1136
1137 @dataclass
1138 class C:
1139 x: int = field(default_factory=factory)
1140
1141 # Make sure that if a field has a default factory function,
1142 # it's not called if a value is specified.
1143 C().x
1144 self.assertEqual(factory.call_count, 1)
1145 self.assertEqual(C(10).x, 10)
1146 self.assertEqual(factory.call_count, 1)
1147 C().x
1148 self.assertEqual(factory.call_count, 2)
1149
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001150 def test_default_factory_derived(self):
1151 # See bpo-32896.
1152 @dataclass
1153 class Foo:
1154 x: dict = field(default_factory=dict)
1155
1156 @dataclass
1157 class Bar(Foo):
1158 y: int = 1
1159
1160 self.assertEqual(Foo().x, {})
1161 self.assertEqual(Bar().x, {})
1162 self.assertEqual(Bar().y, 1)
1163
1164 @dataclass
1165 class Baz(Foo):
1166 pass
1167 self.assertEqual(Baz().x, {})
1168
1169 def test_intermediate_non_dataclass(self):
1170 # Test that an intermediate class that defines
1171 # annotations does not define fields.
1172
1173 @dataclass
1174 class A:
1175 x: int
1176
1177 class B(A):
1178 y: int
1179
1180 @dataclass
1181 class C(B):
1182 z: int
1183
1184 c = C(1, 3)
1185 self.assertEqual((c.x, c.z), (1, 3))
1186
1187 # .y was not initialized.
1188 with self.assertRaisesRegex(AttributeError,
1189 'object has no attribute'):
1190 c.y
1191
1192 # And if we again derive a non-dataclass, no fields are added.
1193 class D(C):
1194 t: int
1195 d = D(4, 5)
1196 self.assertEqual((d.x, d.z), (4, 5))
1197
1198
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001199 def x_test_classvar_default_factory(self):
1200 # XXX: it's an error for a ClassVar to have a factory function
1201 @dataclass
1202 class C:
1203 x: ClassVar[int] = field(default_factory=int)
1204
1205 self.assertIs(C().x, int)
1206
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001207 def test_is_dataclass(self):
1208 class NotDataClass:
1209 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001210
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001211 self.assertFalse(is_dataclass(0))
1212 self.assertFalse(is_dataclass(int))
1213 self.assertFalse(is_dataclass(NotDataClass))
1214 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001215
1216 @dataclass
1217 class C:
1218 x: int
1219
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001220 @dataclass
1221 class D:
1222 d: C
1223 e: int
1224
1225 c = C(10)
1226 d = D(c, 4)
1227
1228 self.assertTrue(is_dataclass(C))
1229 self.assertTrue(is_dataclass(c))
1230 self.assertFalse(is_dataclass(c.x))
1231 self.assertTrue(is_dataclass(d.d))
1232 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001233
1234 def test_helper_fields_with_class_instance(self):
1235 # Check that we can call fields() on either a class or instance,
1236 # and get back the same thing.
1237 @dataclass
1238 class C:
1239 x: int
1240 y: float
1241
1242 self.assertEqual(fields(C), fields(C(0, 0.0)))
1243
1244 def test_helper_fields_exception(self):
1245 # Check that TypeError is raised if not passed a dataclass or
1246 # instance.
1247 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1248 fields(0)
1249
1250 class C: pass
1251 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1252 fields(C)
1253 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1254 fields(C())
1255
1256 def test_helper_asdict(self):
1257 # Basic tests for asdict(), it should return a new dictionary
1258 @dataclass
1259 class C:
1260 x: int
1261 y: int
1262 c = C(1, 2)
1263
1264 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1265 self.assertEqual(asdict(c), asdict(c))
1266 self.assertIsNot(asdict(c), asdict(c))
1267 c.x = 42
1268 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1269 self.assertIs(type(asdict(c)), dict)
1270
1271 def test_helper_asdict_raises_on_classes(self):
1272 # asdict() should raise on a class object
1273 @dataclass
1274 class C:
1275 x: int
1276 y: int
1277 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1278 asdict(C)
1279 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1280 asdict(int)
1281
1282 def test_helper_asdict_copy_values(self):
1283 @dataclass
1284 class C:
1285 x: int
1286 y: List[int] = field(default_factory=list)
1287 initial = []
1288 c = C(1, initial)
1289 d = asdict(c)
1290 self.assertEqual(d['y'], initial)
1291 self.assertIsNot(d['y'], initial)
1292 c = C(1)
1293 d = asdict(c)
1294 d['y'].append(1)
1295 self.assertEqual(c.y, [])
1296
1297 def test_helper_asdict_nested(self):
1298 @dataclass
1299 class UserId:
1300 token: int
1301 group: int
1302 @dataclass
1303 class User:
1304 name: str
1305 id: UserId
1306 u = User('Joe', UserId(123, 1))
1307 d = asdict(u)
1308 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1309 self.assertIsNot(asdict(u), asdict(u))
1310 u.id.group = 2
1311 self.assertEqual(asdict(u), {'name': 'Joe',
1312 'id': {'token': 123, 'group': 2}})
1313
1314 def test_helper_asdict_builtin_containers(self):
1315 @dataclass
1316 class User:
1317 name: str
1318 id: int
1319 @dataclass
1320 class GroupList:
1321 id: int
1322 users: List[User]
1323 @dataclass
1324 class GroupTuple:
1325 id: int
1326 users: Tuple[User, ...]
1327 @dataclass
1328 class GroupDict:
1329 id: int
1330 users: Dict[str, User]
1331 a = User('Alice', 1)
1332 b = User('Bob', 2)
1333 gl = GroupList(0, [a, b])
1334 gt = GroupTuple(0, (a, b))
1335 gd = GroupDict(0, {'first': a, 'second': b})
1336 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1337 {'name': 'Bob', 'id': 2}]})
1338 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1339 {'name': 'Bob', 'id': 2})})
1340 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1341 'second': {'name': 'Bob', 'id': 2}}})
1342
1343 def test_helper_asdict_builtin_containers(self):
1344 @dataclass
1345 class Child:
1346 d: object
1347
1348 @dataclass
1349 class Parent:
1350 child: Child
1351
1352 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1353 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1354
1355 def test_helper_asdict_factory(self):
1356 @dataclass
1357 class C:
1358 x: int
1359 y: int
1360 c = C(1, 2)
1361 d = asdict(c, dict_factory=OrderedDict)
1362 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1363 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1364 c.x = 42
1365 d = asdict(c, dict_factory=OrderedDict)
1366 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1367 self.assertIs(type(d), OrderedDict)
1368
1369 def test_helper_astuple(self):
1370 # Basic tests for astuple(), it should return a new tuple
1371 @dataclass
1372 class C:
1373 x: int
1374 y: int = 0
1375 c = C(1)
1376
1377 self.assertEqual(astuple(c), (1, 0))
1378 self.assertEqual(astuple(c), astuple(c))
1379 self.assertIsNot(astuple(c), astuple(c))
1380 c.y = 42
1381 self.assertEqual(astuple(c), (1, 42))
1382 self.assertIs(type(astuple(c)), tuple)
1383
1384 def test_helper_astuple_raises_on_classes(self):
1385 # astuple() should raise on a class object
1386 @dataclass
1387 class C:
1388 x: int
1389 y: int
1390 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1391 astuple(C)
1392 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1393 astuple(int)
1394
1395 def test_helper_astuple_copy_values(self):
1396 @dataclass
1397 class C:
1398 x: int
1399 y: List[int] = field(default_factory=list)
1400 initial = []
1401 c = C(1, initial)
1402 t = astuple(c)
1403 self.assertEqual(t[1], initial)
1404 self.assertIsNot(t[1], initial)
1405 c = C(1)
1406 t = astuple(c)
1407 t[1].append(1)
1408 self.assertEqual(c.y, [])
1409
1410 def test_helper_astuple_nested(self):
1411 @dataclass
1412 class UserId:
1413 token: int
1414 group: int
1415 @dataclass
1416 class User:
1417 name: str
1418 id: UserId
1419 u = User('Joe', UserId(123, 1))
1420 t = astuple(u)
1421 self.assertEqual(t, ('Joe', (123, 1)))
1422 self.assertIsNot(astuple(u), astuple(u))
1423 u.id.group = 2
1424 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1425
1426 def test_helper_astuple_builtin_containers(self):
1427 @dataclass
1428 class User:
1429 name: str
1430 id: int
1431 @dataclass
1432 class GroupList:
1433 id: int
1434 users: List[User]
1435 @dataclass
1436 class GroupTuple:
1437 id: int
1438 users: Tuple[User, ...]
1439 @dataclass
1440 class GroupDict:
1441 id: int
1442 users: Dict[str, User]
1443 a = User('Alice', 1)
1444 b = User('Bob', 2)
1445 gl = GroupList(0, [a, b])
1446 gt = GroupTuple(0, (a, b))
1447 gd = GroupDict(0, {'first': a, 'second': b})
1448 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1449 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1450 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1451
1452 def test_helper_astuple_builtin_containers(self):
1453 @dataclass
1454 class Child:
1455 d: object
1456
1457 @dataclass
1458 class Parent:
1459 child: Child
1460
1461 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1462 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1463
1464 def test_helper_astuple_factory(self):
1465 @dataclass
1466 class C:
1467 x: int
1468 y: int
1469 NT = namedtuple('NT', 'x y')
1470 def nt(lst):
1471 return NT(*lst)
1472 c = C(1, 2)
1473 t = astuple(c, tuple_factory=nt)
1474 self.assertEqual(t, NT(1, 2))
1475 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1476 c.x = 42
1477 t = astuple(c, tuple_factory=nt)
1478 self.assertEqual(t, NT(42, 2))
1479 self.assertIs(type(t), NT)
1480
1481 def test_dynamic_class_creation(self):
1482 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1483 }
1484
1485 # Create the class.
1486 cls = type('C', (), cls_dict)
1487
1488 # Make it a dataclass.
1489 cls1 = dataclass(cls)
1490
1491 self.assertEqual(cls1, cls)
1492 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1493
1494 def test_dynamic_class_creation_using_field(self):
1495 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1496 'y': field(default=5),
1497 }
1498
1499 # Create the class.
1500 cls = type('C', (), cls_dict)
1501
1502 # Make it a dataclass.
1503 cls1 = dataclass(cls)
1504
1505 self.assertEqual(cls1, cls)
1506 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1507
1508 def test_init_in_order(self):
1509 @dataclass
1510 class C:
1511 a: int
1512 b: int = field()
1513 c: list = field(default_factory=list, init=False)
1514 d: list = field(default_factory=list)
1515 e: int = field(default=4, init=False)
1516 f: int = 4
1517
1518 calls = []
1519 def setattr(self, name, value):
1520 calls.append((name, value))
1521
1522 C.__setattr__ = setattr
1523 c = C(0, 1)
1524 self.assertEqual(('a', 0), calls[0])
1525 self.assertEqual(('b', 1), calls[1])
1526 self.assertEqual(('c', []), calls[2])
1527 self.assertEqual(('d', []), calls[3])
1528 self.assertNotIn(('e', 4), calls)
1529 self.assertEqual(('f', 4), calls[4])
1530
1531 def test_items_in_dicts(self):
1532 @dataclass
1533 class C:
1534 a: int
1535 b: list = field(default_factory=list, init=False)
1536 c: list = field(default_factory=list)
1537 d: int = field(default=4, init=False)
1538 e: int = 0
1539
1540 c = C(0)
1541 # Class dict
1542 self.assertNotIn('a', C.__dict__)
1543 self.assertNotIn('b', C.__dict__)
1544 self.assertNotIn('c', C.__dict__)
1545 self.assertIn('d', C.__dict__)
1546 self.assertEqual(C.d, 4)
1547 self.assertIn('e', C.__dict__)
1548 self.assertEqual(C.e, 0)
1549 # Instance dict
1550 self.assertIn('a', c.__dict__)
1551 self.assertEqual(c.a, 0)
1552 self.assertIn('b', c.__dict__)
1553 self.assertEqual(c.b, [])
1554 self.assertIn('c', c.__dict__)
1555 self.assertEqual(c.c, [])
1556 self.assertNotIn('d', c.__dict__)
1557 self.assertIn('e', c.__dict__)
1558 self.assertEqual(c.e, 0)
1559
1560 def test_alternate_classmethod_constructor(self):
1561 # Since __post_init__ can't take params, use a classmethod
1562 # alternate constructor. This is mostly an example to show how
1563 # to use this technique.
1564 @dataclass
1565 class C:
1566 x: int
1567 @classmethod
1568 def from_file(cls, filename):
1569 # In a real example, create a new instance
1570 # and populate 'x' from contents of a file.
1571 value_in_file = 20
1572 return cls(value_in_file)
1573
1574 self.assertEqual(C.from_file('filename').x, 20)
1575
1576 def test_field_metadata_default(self):
1577 # Make sure the default metadata is read-only and of
1578 # zero length.
1579 @dataclass
1580 class C:
1581 i: int
1582
1583 self.assertFalse(fields(C)[0].metadata)
1584 self.assertEqual(len(fields(C)[0].metadata), 0)
1585 with self.assertRaisesRegex(TypeError,
1586 'does not support item assignment'):
1587 fields(C)[0].metadata['test'] = 3
1588
1589 def test_field_metadata_mapping(self):
1590 # Make sure only a mapping can be passed as metadata
1591 # zero length.
1592 with self.assertRaises(TypeError):
1593 @dataclass
1594 class C:
1595 i: int = field(metadata=0)
1596
1597 # Make sure an empty dict works
1598 @dataclass
1599 class C:
1600 i: int = field(metadata={})
1601 self.assertFalse(fields(C)[0].metadata)
1602 self.assertEqual(len(fields(C)[0].metadata), 0)
1603 with self.assertRaisesRegex(TypeError,
1604 'does not support item assignment'):
1605 fields(C)[0].metadata['test'] = 3
1606
1607 # Make sure a non-empty dict works.
1608 @dataclass
1609 class C:
1610 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1611 self.assertEqual(len(fields(C)[0].metadata), 3)
1612 self.assertEqual(fields(C)[0].metadata['test'], 10)
1613 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1614 self.assertEqual(fields(C)[0].metadata[3], 'three')
1615 with self.assertRaises(KeyError):
1616 # Non-existent key.
1617 fields(C)[0].metadata['baz']
1618 with self.assertRaisesRegex(TypeError,
1619 'does not support item assignment'):
1620 fields(C)[0].metadata['test'] = 3
1621
1622 def test_field_metadata_custom_mapping(self):
1623 # Try a custom mapping.
1624 class SimpleNameSpace:
1625 def __init__(self, **kw):
1626 self.__dict__.update(kw)
1627
1628 def __getitem__(self, item):
1629 if item == 'xyzzy':
1630 return 'plugh'
1631 return getattr(self, item)
1632
1633 def __len__(self):
1634 return self.__dict__.__len__()
1635
1636 @dataclass
1637 class C:
1638 i: int = field(metadata=SimpleNameSpace(a=10))
1639
1640 self.assertEqual(len(fields(C)[0].metadata), 1)
1641 self.assertEqual(fields(C)[0].metadata['a'], 10)
1642 with self.assertRaises(AttributeError):
1643 fields(C)[0].metadata['b']
1644 # Make sure we're still talking to our custom mapping.
1645 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1646
1647 def test_generic_dataclasses(self):
1648 T = TypeVar('T')
1649
1650 @dataclass
1651 class LabeledBox(Generic[T]):
1652 content: T
1653 label: str = '<unknown>'
1654
1655 box = LabeledBox(42)
1656 self.assertEqual(box.content, 42)
1657 self.assertEqual(box.label, '<unknown>')
1658
1659 # subscripting the resulting class should work, etc.
1660 Alias = List[LabeledBox[int]]
1661
1662 def test_generic_extending(self):
1663 S = TypeVar('S')
1664 T = TypeVar('T')
1665
1666 @dataclass
1667 class Base(Generic[T, S]):
1668 x: T
1669 y: S
1670
1671 @dataclass
1672 class DataDerived(Base[int, T]):
1673 new_field: str
1674 Alias = DataDerived[str]
1675 c = Alias(0, 'test1', 'test2')
1676 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1677
1678 class NonDataDerived(Base[int, T]):
1679 def new_method(self):
1680 return self.y
1681 Alias = NonDataDerived[float]
1682 c = Alias(10, 1.0)
1683 self.assertEqual(c.new_method(), 1.0)
1684
1685 def test_helper_replace(self):
1686 @dataclass(frozen=True)
1687 class C:
1688 x: int
1689 y: int
1690
1691 c = C(1, 2)
1692 c1 = replace(c, x=3)
1693 self.assertEqual(c1.x, 3)
1694 self.assertEqual(c1.y, 2)
1695
1696 def test_helper_replace_frozen(self):
1697 @dataclass(frozen=True)
1698 class C:
1699 x: int
1700 y: int
1701 z: int = field(init=False, default=10)
1702 t: int = field(init=False, default=100)
1703
1704 c = C(1, 2)
1705 c1 = replace(c, x=3)
1706 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1707 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1708
1709
1710 with self.assertRaisesRegex(ValueError, 'init=False'):
1711 replace(c, x=3, z=20, t=50)
1712 with self.assertRaisesRegex(ValueError, 'init=False'):
1713 replace(c, z=20)
1714 replace(c, x=3, z=20, t=50)
1715
1716 # Make sure the result is still frozen.
1717 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1718 c1.x = 3
1719
1720 # Make sure we can't replace an attribute that doesn't exist,
1721 # if we're also replacing one that does exist. Test this
1722 # here, because setting attributes on frozen instances is
1723 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001724 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001725 "keyword argument 'a'"):
1726 c1 = replace(c, x=20, a=5)
1727
1728 def test_helper_replace_invalid_field_name(self):
1729 @dataclass(frozen=True)
1730 class C:
1731 x: int
1732 y: int
1733
1734 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001735 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001736 "keyword argument 'z'"):
1737 c1 = replace(c, z=3)
1738
1739 def test_helper_replace_invalid_object(self):
1740 @dataclass(frozen=True)
1741 class C:
1742 x: int
1743 y: int
1744
1745 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1746 replace(C, x=3)
1747
1748 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1749 replace(0, x=3)
1750
1751 def test_helper_replace_no_init(self):
1752 @dataclass
1753 class C:
1754 x: int
1755 y: int = field(init=False, default=10)
1756
1757 c = C(1)
1758 c.y = 20
1759
1760 # Make sure y gets the default value.
1761 c1 = replace(c, x=5)
1762 self.assertEqual((c1.x, c1.y), (5, 10))
1763
1764 # Trying to replace y is an error.
1765 with self.assertRaisesRegex(ValueError, 'init=False'):
1766 replace(c, x=2, y=30)
1767 with self.assertRaisesRegex(ValueError, 'init=False'):
1768 replace(c, y=30)
1769
1770 def test_dataclassses_pickleable(self):
1771 global P, Q, R
1772 @dataclass
1773 class P:
1774 x: int
1775 y: int = 0
1776 @dataclass
1777 class Q:
1778 x: int
1779 y: int = field(default=0, init=False)
1780 @dataclass
1781 class R:
1782 x: int
1783 y: List[int] = field(default_factory=list)
1784 q = Q(1)
1785 q.y = 2
1786 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1787 for sample in samples:
1788 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1789 with self.subTest(sample=sample, proto=proto):
1790 new_sample = pickle.loads(pickle.dumps(sample, proto))
1791 self.assertEqual(sample.x, new_sample.x)
1792 self.assertEqual(sample.y, new_sample.y)
1793 self.assertIsNot(sample, new_sample)
1794 new_sample.x = 42
1795 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1796 self.assertEqual(new_sample.x, another_new_sample.x)
1797 self.assertEqual(sample.y, another_new_sample.y)
1798
1799 def test_helper_make_dataclass(self):
1800 C = make_dataclass('C',
1801 [('x', int),
1802 ('y', int, field(default=5))],
1803 namespace={'add_one': lambda self: self.x + 1})
1804 c = C(10)
1805 self.assertEqual((c.x, c.y), (10, 5))
1806 self.assertEqual(c.add_one(), 11)
1807
1808
1809 def test_helper_make_dataclass_no_mutate_namespace(self):
1810 # Make sure a provided namespace isn't mutated.
1811 ns = {}
1812 C = make_dataclass('C',
1813 [('x', int),
1814 ('y', int, field(default=5))],
1815 namespace=ns)
1816 self.assertEqual(ns, {})
1817
1818 def test_helper_make_dataclass_base(self):
1819 class Base1:
1820 pass
1821 class Base2:
1822 pass
1823 C = make_dataclass('C',
1824 [('x', int)],
1825 bases=(Base1, Base2))
1826 c = C(2)
1827 self.assertIsInstance(c, C)
1828 self.assertIsInstance(c, Base1)
1829 self.assertIsInstance(c, Base2)
1830
1831 def test_helper_make_dataclass_base_dataclass(self):
1832 @dataclass
1833 class Base1:
1834 x: int
1835 class Base2:
1836 pass
1837 C = make_dataclass('C',
1838 [('y', int)],
1839 bases=(Base1, Base2))
1840 with self.assertRaisesRegex(TypeError, 'required positional'):
1841 c = C(2)
1842 c = C(1, 2)
1843 self.assertIsInstance(c, C)
1844 self.assertIsInstance(c, Base1)
1845 self.assertIsInstance(c, Base2)
1846
1847 self.assertEqual((c.x, c.y), (1, 2))
1848
1849 def test_helper_make_dataclass_init_var(self):
1850 def post_init(self, y):
1851 self.x *= y
1852
1853 C = make_dataclass('C',
1854 [('x', int),
1855 ('y', InitVar[int]),
1856 ],
1857 namespace={'__post_init__': post_init},
1858 )
1859 c = C(2, 3)
1860 self.assertEqual(vars(c), {'x': 6})
1861 self.assertEqual(len(fields(c)), 1)
1862
1863 def test_helper_make_dataclass_class_var(self):
1864 C = make_dataclass('C',
1865 [('x', int),
1866 ('y', ClassVar[int], 10),
1867 ('z', ClassVar[int], field(default=20)),
1868 ])
1869 c = C(1)
1870 self.assertEqual(vars(c), {'x': 1})
1871 self.assertEqual(len(fields(c)), 1)
1872 self.assertEqual(C.y, 10)
1873 self.assertEqual(C.z, 20)
1874
Eric V. Smithd80b4432018-01-06 17:09:58 -05001875 def test_helper_make_dataclass_other_params(self):
1876 C = make_dataclass('C',
1877 [('x', int),
1878 ('y', ClassVar[int], 10),
1879 ('z', ClassVar[int], field(default=20)),
1880 ],
1881 init=False)
1882 # Make sure we have a repr, but no init.
1883 self.assertNotIn('__init__', vars(C))
1884 self.assertIn('__repr__', vars(C))
1885
1886 # Make sure random other params don't work.
1887 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1888 C = make_dataclass('C',
1889 [],
1890 xxinit=False)
1891
Eric V. Smithed7d4292018-01-06 16:14:03 -05001892 def test_helper_make_dataclass_no_types(self):
1893 C = make_dataclass('Point', ['x', 'y', 'z'])
1894 c = C(1, 2, 3)
1895 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1896 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1897 'y': 'typing.Any',
1898 'z': 'typing.Any'})
1899
1900 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1901 c = C(1, 2, 3)
1902 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1903 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1904 'y': int,
1905 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001906
Eric V. Smithea8fc522018-01-27 19:07:40 -05001907
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001908class TestDocString(unittest.TestCase):
1909 def assertDocStrEqual(self, a, b):
1910 # Because 3.6 and 3.7 differ in how inspect.signature work
1911 # (see bpo #32108), for the time being just compare them with
1912 # whitespace stripped.
1913 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1914
1915 def test_existing_docstring_not_overridden(self):
1916 @dataclass
1917 class C:
1918 """Lorem ipsum"""
1919 x: int
1920
1921 self.assertEqual(C.__doc__, "Lorem ipsum")
1922
1923 def test_docstring_no_fields(self):
1924 @dataclass
1925 class C:
1926 pass
1927
1928 self.assertDocStrEqual(C.__doc__, "C()")
1929
1930 def test_docstring_one_field(self):
1931 @dataclass
1932 class C:
1933 x: int
1934
1935 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1936
1937 def test_docstring_two_fields(self):
1938 @dataclass
1939 class C:
1940 x: int
1941 y: int
1942
1943 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1944
1945 def test_docstring_three_fields(self):
1946 @dataclass
1947 class C:
1948 x: int
1949 y: int
1950 z: str
1951
1952 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1953
1954 def test_docstring_one_field_with_default(self):
1955 @dataclass
1956 class C:
1957 x: int = 3
1958
1959 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1960
1961 def test_docstring_one_field_with_default_none(self):
1962 @dataclass
1963 class C:
1964 x: Union[int, type(None)] = None
1965
1966 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1967
1968 def test_docstring_list_field(self):
1969 @dataclass
1970 class C:
1971 x: List[int]
1972
1973 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1974
1975 def test_docstring_list_field_with_default_factory(self):
1976 @dataclass
1977 class C:
1978 x: List[int] = field(default_factory=list)
1979
1980 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1981
1982 def test_docstring_deque_field(self):
1983 @dataclass
1984 class C:
1985 x: deque
1986
1987 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1988
1989 def test_docstring_deque_field_with_default_factory(self):
1990 @dataclass
1991 class C:
1992 x: deque = field(default_factory=deque)
1993
1994 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1995
1996
class TestInit(unittest.TestCase):
    """Tests for how dataclass() generates, or skips, __init__."""

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring a dataclass over a base that has its own __init__
        # must not raise. The generated __init__ takes precedence, so
        # B.__init__ is not run and 'z' is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is added, so the inherited
        # B.__init__ is the one that runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied (note the '@'). A bare
        # `dataclass(init=False)` expression statement would leave C a
        # plain class and test nothing.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2061
2062
2063class TestRepr(unittest.TestCase):
2064 def test_repr(self):
2065 @dataclass
2066 class B:
2067 x: int
2068
2069 @dataclass
2070 class C(B):
2071 y: int = 10
2072
2073 o = C(4)
2074 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
2075
2076 @dataclass
2077 class D(C):
2078 x: int = 20
2079 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
2080
2081 @dataclass
2082 class C:
2083 @dataclass
2084 class D:
2085 i: int
2086 @dataclass
2087 class E:
2088 pass
2089 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
2090 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
2091
2092 def test_no_repr(self):
2093 # Test a class with no __repr__ and repr=False.
2094 @dataclass(repr=False)
2095 class C:
2096 x: int
2097 self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
2098 repr(C(3)))
2099
2100 # Test a class with a __repr__ and repr=False.
2101 @dataclass(repr=False)
2102 class C:
2103 x: int
2104 def __repr__(self):
2105 return 'C-class'
2106 self.assertEqual(repr(C(3)), 'C-class')
2107
2108 def test_overwriting_repr(self):
2109 # If the class has __repr__, use it no matter the value of
2110 # repr=.
2111
2112 @dataclass
2113 class C:
2114 x: int
2115 def __repr__(self):
2116 return 'x'
2117 self.assertEqual(repr(C(0)), 'x')
2118
2119 @dataclass(repr=True)
2120 class C:
2121 x: int
2122 def __repr__(self):
2123 return 'x'
2124 self.assertEqual(repr(C(0)), 'x')
2125
2126 @dataclass(repr=False)
2127 class C:
2128 x: int
2129 def __repr__(self):
2130 return 'x'
2131 self.assertEqual(repr(C(0)), 'x')
2132
2133
class TestFrozen(unittest.TestCase):
    """Tests for frozen= combined with user-defined attribute hooks."""

    def test_overwriting_frozen(self):
        # frozen=True generates __setattr__ and __delattr__, so a class
        # body that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without freezing, a user __setattr__ is kept, and the generated
        # __init__ goes through it.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        doubled = C(10)
        self.assertEqual(doubled.x, 20)
2159
2160
class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and the eq= parameter."""

    def test_no_eq(self):
        # eq=False and no user __eq__: comparison falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A __eq__ defined in the class body always wins, whatever the
        # value of eq=.

        # Default eq=.
        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        # Explicit eq=True.
        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        # Explicit eq=False.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)
2206
2207
class TestOrdering(unittest.TestCase):
    """Tests for ordering methods and the order= parameter."""

    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to derive the remaining
        # comparisons from a user __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately "backward", so the results below can only
                # come from this method.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # order=False adds no comparison methods at all.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user __lt__ is kept, and none of its siblings are generated.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any user-defined comparison and
        # suggests functools.total_ordering instead.
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            with self.subTest(method=name):
                with self.assertRaisesRegex(TypeError,
                                            f'Cannot overwrite attribute {name}'
                                            '.*using functools.total_ordering'):
                    namespace = {'__annotations__': {'x': int},
                                 name: lambda self: None}
                    dataclass(order=True)(type('C', (), namespace))
2283
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        # unsafe_hash=True generates a __hash__ equal to hashing the tuple
        # of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        # Table-driven check of which __hash__ the decorator installs for
        # every combination of unsafe_hash/eq/frozen, with and without an
        # explicitly defined __hash__ in the class body.

        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # result is one of:
            #   'fn'        -- __hash__ in the class dict is a generated function
            #   ''          -- __hash__ is not added to the class dict
            #   'none'      -- __hash__ in the class dict is set to None
            #   'exception' -- class creation must raise TypeError
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True. Use a unittest
                    # assertion rather than a bare assert so the check is not
                    # stripped when running under -O.
                    self.assertTrue(with_hash,
                                    "'exception' is only expected when "
                                    "__hash__ is defined in the class")
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # A typo in the table below must fail loudly, even
                    # under -O where a bare assert would vanish.
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False,  '',     ''),
                (False, False, True,   '',     ''),
                (False, True,  False,  'none', ''),
                (False, True,  True,   'fn',   ''),
                (True,  False, False,  'fn',   'exception'),
                (True,  False, True,   'fn',   'exception'),
                (True,  True,  False,  'fn',   'exception'),
                (True,  True,  True,   'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less class hashes like the empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A single-field class hashes like a 1-tuple of that field.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.  Each branch deliberately spells
                # the decorator differently so omitted arguments really are
                # omitted (including the bare @dataclass form).
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's default hash value isn't something we can
                    # predict here, so comparing hash() results wouldn't
                    # tell us much.  Instead, check that the function
                    # used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # Fail loudly on a table typo, even under -O.
                    self.fail(f'unknown value for expected={expected!r}')
2502
Eric V. Smithea8fc522018-01-27 19:07:40 -05002503
class TestFrozen(unittest.TestCase):
    def test_frozen(self):
        # Assigning to a field of a frozen instance raises, and the stored
        # value is left unchanged.
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        # A frozen dataclass derived from another frozen dataclass protects
        # both the inherited field and its own field.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        with self.assertRaises(FrozenInstanceError):
            obj.j = 6
        self.assertEqual(obj.i, 0)
        self.assertEqual(obj.j, 10)

    # The inheritance tests below each run twice: once deriving straight
    # from the dataclass, and once through an intermediate plain
    # (non-dataclass) class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if not intermediate_class:
                    I = C
                else:
                    class I(C):
                        pass

                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if not intermediate_class:
                    I = C
                else:
                    class I(C):
                        pass

                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from an ordinary class, and its own
        # fields are still frozen.
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if not intermediate_class:
                    I = C
                else:
                    class I(C):
                        pass

                @dataclass(frozen=True)
                class D(I):
                    i: int

                obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    obj.i = 5

    def test_non_frozen_normal_derived(self):
        # bpo-32953: a plain (non-dataclass) subclass of a frozen dataclass
        # may set new attributes, but the inherited fields remain frozen.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        inst = S(3)
        self.assertEqual(inst.x, 3)
        self.assertEqual(inst.y, 10)
        inst.cached = True

        # The frozen fields still reject assignment.
        with self.assertRaises(FrozenInstanceError):
            inst.x = 5
        with self.assertRaises(FrozenInstanceError):
            inst.y = 5
        self.assertEqual(inst.x, 3)
        self.assertEqual(inst.y, 10)
        self.assertEqual(inst.cached, True)
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002614
2615
class TestSlots(unittest.TestCase):
    def test_simple(self):
        # A dataclass that declares __slots__ for its single field.
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed
        # to also have a default value (of type types.MemberDescriptorType),
        # so x must still be a required __init__ argument.
        # The pattern is a raw string: '\(' is not a valid string escape and
        # would otherwise emit a DeprecationWarning/SyntaxWarning.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else: the only slot is 'x'.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # Derived does not declare __slots__, so its instances get a
        # __dict__ and an arbitrary new attribute can be set.
        d.z = 10
2656
2657
# When this module is executed directly as a script, run its tests via
# unittest's command-line entry point.
if __name__ == '__main__':
    unittest.main()