blob: 867210688f5737f8d8034d26492be73e557e5a1c [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +03009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050014from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040016import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
17import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
18
class CustomError(Exception):
    """Test-only exception type.

    Tests raise and catch it so that the exception they observe cannot be
    confused with one raised by dataclasses itself.
    """

22class TestCase(unittest.TestCase):
23 def test_no_fields(self):
24 @dataclass
25 class C:
26 pass
27
28 o = C()
29 self.assertEqual(len(fields(C)), 0)
30
Eric V. Smith56970b82018-03-22 16:28:48 -040031 def test_no_fields_but_member_variable(self):
32 @dataclass
33 class C:
34 i = 0
35
36 o = C()
37 self.assertEqual(len(fields(C)), 0)
38
Eric V. Smithf0db54a2017-12-04 16:58:55 -050039 def test_one_field_no_default(self):
40 @dataclass
41 class C:
42 x: int
43
44 o = C(42)
45 self.assertEqual(o.x, 42)
46
47 def test_named_init_params(self):
48 @dataclass
49 class C:
50 x: int
51
52 o = C(x=32)
53 self.assertEqual(o.x, 32)
54
55 def test_two_fields_one_default(self):
56 @dataclass
57 class C:
58 x: int
59 y: int = 0
60
61 o = C(3)
62 self.assertEqual((o.x, o.y), (3, 0))
63
64 # Non-defaults following defaults.
65 with self.assertRaisesRegex(TypeError,
66 "non-default argument 'y' follows "
67 "default argument"):
68 @dataclass
69 class C:
70 x: int = 0
71 y: int
72
73 # A derived class adds a non-default field after a default one.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int = 0
80
81 @dataclass
82 class C(B):
83 y: int
84
85 # Override a base class field and add a default to
86 # a field which didn't use to have a default.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class B:
92 x: int
93 y: int
94
95 @dataclass
96 class C(B):
97 x: int = 0
98
Eric V. Smithdbf9cff2018-02-25 21:30:17 -050099 def test_overwrite_hash(self):
100 # Test that declaring this class isn't an error. It should
101 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500102 @dataclass(frozen=True)
103 class C:
104 x: int
105 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500106 return 301
107 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500108
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500109 # Test that declaring this class isn't an error. It should
110 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500111 @dataclass(frozen=True)
112 class C:
113 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500114 def __eq__(self, other):
115 return False
116 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500117
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500118 # But this one should generate an exception, because with
119 # unsafe_hash=True, it's an error to have a __hash__ defined.
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __hash__'):
122 @dataclass(unsafe_hash=True)
123 class C:
124 def __hash__(self):
125 pass
126
127 # Creating this class should not generate an exception,
128 # because even though __hash__ exists before @dataclass is
129 # called, (due to __eq__ being defined), since it's None
130 # that's okay.
131 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500132 class C:
133 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500134 def __eq__(self):
135 pass
136 # The generated hash function works as we'd expect.
137 self.assertEqual(hash(C(10)), hash((10,)))
138
139 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400140 # __hash__ exists and is not None, which it would be if it
141 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 x: int
147 def __eq__(self):
148 pass
149 def __hash__(self):
150 pass
151
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500152 def test_overwrite_fields_in_derived_class(self):
153 # Note that x from C1 replaces x in Base, but the order remains
154 # the same as defined in Base.
155 @dataclass
156 class Base:
157 x: Any = 15.0
158 y: int = 0
159
160 @dataclass
161 class C1(Base):
162 z: int = 10
163 x: int = 15
164
165 o = Base()
166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
167
168 o = C1()
169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
170
171 o = C1(x=5)
172 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
173
174 def test_field_named_self(self):
175 @dataclass
176 class C:
177 self: str
178 c=C('foo')
179 self.assertEqual(c.self, 'foo')
180
181 # Make sure the first parameter is not named 'self'.
182 sig = inspect.signature(C.__init__)
183 first = next(iter(sig.parameters))
184 self.assertNotEqual('self', first)
185
186 # But we do use 'self' if no field named self.
187 @dataclass
188 class C:
189 selfx: str
190
191 # Make sure the first parameter is named 'self'.
192 sig = inspect.signature(C.__init__)
193 first = next(iter(sig.parameters))
194 self.assertEqual('self', first)
195
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300196 def test_field_named_object(self):
197 @dataclass
198 class C:
199 object: str
200 c = C('foo')
201 self.assertEqual(c.object, 'foo')
202
203 def test_field_named_object_frozen(self):
204 @dataclass(frozen=True)
205 class C:
206 object: str
207 c = C('foo')
208 self.assertEqual(c.object, 'foo')
209
210 def test_field_named_like_builtin(self):
211 # Attribute names can shadow built-in names
212 # since code generation is used.
213 # Ensure that this is not happening.
214 exclusions = {'None', 'True', 'False'}
215 builtins_names = sorted(
216 b for b in builtins.__dict__.keys()
217 if not b.startswith('__') and b not in exclusions
218 )
219 attributes = [(name, str) for name in builtins_names]
220 C = make_dataclass('C', attributes)
221
222 c = C(*[name for name in builtins_names])
223
224 for name in builtins_names:
225 self.assertEqual(getattr(c, name), name)
226
227 def test_field_named_like_builtin_frozen(self):
228 # Attribute names can shadow built-in names
229 # since code generation is used.
230 # Ensure that this is not happening
231 # for frozen data classes.
232 exclusions = {'None', 'True', 'False'}
233 builtins_names = sorted(
234 b for b in builtins.__dict__.keys()
235 if not b.startswith('__') and b not in exclusions
236 )
237 attributes = [(name, str) for name in builtins_names]
238 C = make_dataclass('C', attributes, frozen=True)
239
240 c = C(*[name for name in builtins_names])
241
242 for name in builtins_names:
243 self.assertEqual(getattr(c, name), name)
244
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500245 def test_0_field_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 pass
250
251 @dataclass(order=False)
252 class C1:
253 pass
254
255 for cls in [C0, C1]:
256 with self.subTest(cls=cls):
257 self.assertEqual(cls(), cls())
258 for idx, fn in enumerate([lambda a, b: a < b,
259 lambda a, b: a <= b,
260 lambda a, b: a > b,
261 lambda a, b: a >= b]):
262 with self.subTest(idx=idx):
263 with self.assertRaisesRegex(TypeError,
264 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
265 fn(cls(), cls())
266
267 @dataclass(order=True)
268 class C:
269 pass
270 self.assertLessEqual(C(), C())
271 self.assertGreaterEqual(C(), C())
272
273 def test_1_field_compare(self):
274 # Ensure that order=False is the default.
275 @dataclass
276 class C0:
277 x: int
278
279 @dataclass(order=False)
280 class C1:
281 x: int
282
283 for cls in [C0, C1]:
284 with self.subTest(cls=cls):
285 self.assertEqual(cls(1), cls(1))
286 self.assertNotEqual(cls(0), cls(1))
287 for idx, fn in enumerate([lambda a, b: a < b,
288 lambda a, b: a <= b,
289 lambda a, b: a > b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 with self.assertRaisesRegex(TypeError,
293 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
294 fn(cls(0), cls(0))
295
296 @dataclass(order=True)
297 class C:
298 x: int
299 self.assertLess(C(0), C(1))
300 self.assertLessEqual(C(0), C(1))
301 self.assertLessEqual(C(1), C(1))
302 self.assertGreater(C(1), C(0))
303 self.assertGreaterEqual(C(1), C(0))
304 self.assertGreaterEqual(C(1), C(1))
305
306 def test_simple_compare(self):
307 # Ensure that order=False is the default.
308 @dataclass
309 class C0:
310 x: int
311 y: int
312
313 @dataclass(order=False)
314 class C1:
315 x: int
316 y: int
317
318 for cls in [C0, C1]:
319 with self.subTest(cls=cls):
320 self.assertEqual(cls(0, 0), cls(0, 0))
321 self.assertEqual(cls(1, 2), cls(1, 2))
322 self.assertNotEqual(cls(1, 0), cls(0, 0))
323 self.assertNotEqual(cls(1, 0), cls(1, 1))
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
331 fn(cls(0, 0), cls(0, 0))
332
333 @dataclass(order=True)
334 class C:
335 x: int
336 y: int
337
338 for idx, fn in enumerate([lambda a, b: a == b,
339 lambda a, b: a <= b,
340 lambda a, b: a >= b]):
341 with self.subTest(idx=idx):
342 self.assertTrue(fn(C(0, 0), C(0, 0)))
343
344 for idx, fn in enumerate([lambda a, b: a < b,
345 lambda a, b: a <= b,
346 lambda a, b: a != b]):
347 with self.subTest(idx=idx):
348 self.assertTrue(fn(C(0, 0), C(0, 1)))
349 self.assertTrue(fn(C(0, 1), C(1, 0)))
350 self.assertTrue(fn(C(1, 0), C(1, 1)))
351
352 for idx, fn in enumerate([lambda a, b: a > b,
353 lambda a, b: a >= b,
354 lambda a, b: a != b]):
355 with self.subTest(idx=idx):
356 self.assertTrue(fn(C(0, 1), C(0, 0)))
357 self.assertTrue(fn(C(1, 0), C(0, 1)))
358 self.assertTrue(fn(C(1, 1), C(1, 0)))
359
360 def test_compare_subclasses(self):
361 # Comparisons fail for subclasses, even if no fields
362 # are added.
363 @dataclass
364 class B:
365 i: int
366
367 @dataclass
368 class C(B):
369 pass
370
371 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
372 (lambda a, b: a != b, True)]):
373 with self.subTest(idx=idx):
374 self.assertEqual(fn(B(0), C(0)), expected)
375
376 for idx, fn in enumerate([lambda a, b: a < b,
377 lambda a, b: a <= b,
378 lambda a, b: a > b,
379 lambda a, b: a >= b]):
380 with self.subTest(idx=idx):
381 with self.assertRaisesRegex(TypeError,
382 "not supported between instances of 'B' and 'C'"):
383 fn(B(0), C(0))
384
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500385 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500386 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500387 for (eq, order, result ) in [
388 (False, False, 'neither'),
389 (False, True, 'exception'),
390 (True, False, 'eq_only'),
391 (True, True, 'both'),
392 ]:
393 with self.subTest(eq=eq, order=order):
394 if result == 'exception':
395 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
396 @dataclass(eq=eq, order=order)
397 class C:
398 pass
399 else:
400 @dataclass(eq=eq, order=order)
401 class C:
402 pass
403
404 if result == 'neither':
405 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 self.assertNotIn('__lt__', C.__dict__)
407 self.assertNotIn('__le__', C.__dict__)
408 self.assertNotIn('__gt__', C.__dict__)
409 self.assertNotIn('__ge__', C.__dict__)
410 elif result == 'both':
411 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500412 self.assertIn('__lt__', C.__dict__)
413 self.assertIn('__le__', C.__dict__)
414 self.assertIn('__gt__', C.__dict__)
415 self.assertIn('__ge__', C.__dict__)
416 elif result == 'eq_only':
417 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500418 self.assertNotIn('__lt__', C.__dict__)
419 self.assertNotIn('__le__', C.__dict__)
420 self.assertNotIn('__gt__', C.__dict__)
421 self.assertNotIn('__ge__', C.__dict__)
422 else:
423 assert False, f'unknown result {result!r}'
424
425 def test_field_no_default(self):
426 @dataclass
427 class C:
428 x: int = field()
429
430 self.assertEqual(C(5).x, 5)
431
432 with self.assertRaisesRegex(TypeError,
433 r"__init__\(\) missing 1 required "
434 "positional argument: 'x'"):
435 C()
436
437 def test_field_default(self):
438 default = object()
439 @dataclass
440 class C:
441 x: object = field(default=default)
442
443 self.assertIs(C.x, default)
444 c = C(10)
445 self.assertEqual(c.x, 10)
446
447 # If we delete the instance attribute, we should then see the
448 # class attribute.
449 del c.x
450 self.assertIs(c.x, default)
451
452 self.assertIs(C().x, default)
453
454 def test_not_in_repr(self):
455 @dataclass
456 class C:
457 x: int = field(repr=False)
458 with self.assertRaises(TypeError):
459 C()
460 c = C(10)
461 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
462
463 @dataclass
464 class C:
465 x: int = field(repr=False)
466 y: int
467 c = C(10, 20)
468 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
469
470 def test_not_in_compare(self):
471 @dataclass
472 class C:
473 x: int = 0
474 y: int = field(compare=False, default=4)
475
476 self.assertEqual(C(), C(0, 20))
477 self.assertEqual(C(1, 10), C(1, 20))
478 self.assertNotEqual(C(3), C(4, 10))
479 self.assertNotEqual(C(3, 10), C(4, 10))
480
481 def test_hash_field_rules(self):
482 # Test all 6 cases of:
483 # hash=True/False/None
484 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500485 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500486 (True, False, 'field' ),
487 (True, True, 'field' ),
488 (False, False, 'absent'),
489 (False, True, 'absent'),
490 (None, False, 'absent'),
491 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500492 ]:
493 with self.subTest(hash=hash_, compare=compare):
494 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500495 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500496 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500497
498 if result == 'field':
499 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500500 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500501 elif result == 'absent':
502 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500503 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500504 else:
505 assert False, f'unknown result {result!r}'
506
507 def test_init_false_no_default(self):
508 # If init=False and no default value, then the field won't be
509 # present in the instance.
510 @dataclass
511 class C:
512 x: int = field(init=False)
513
514 self.assertNotIn('x', C().__dict__)
515
516 @dataclass
517 class C:
518 x: int
519 y: int = 0
520 z: int = field(init=False)
521 t: int = 10
522
523 self.assertNotIn('z', C(0).__dict__)
524 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
525
526 def test_class_marker(self):
527 @dataclass
528 class C:
529 x: int
530 y: str = field(init=False, default=None)
531 z: str = field(repr=False)
532
533 the_fields = fields(C)
534 # the_fields is a tuple of 3 items, each value
535 # is in __annotations__.
536 self.assertIsInstance(the_fields, tuple)
537 for f in the_fields:
538 self.assertIs(type(f), Field)
539 self.assertIn(f.name, C.__annotations__)
540
541 self.assertEqual(len(the_fields), 3)
542
543 self.assertEqual(the_fields[0].name, 'x')
544 self.assertEqual(the_fields[0].type, int)
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertTrue (the_fields[0].init)
547 self.assertTrue (the_fields[0].repr)
548 self.assertEqual(the_fields[1].name, 'y')
549 self.assertEqual(the_fields[1].type, str)
550 self.assertIsNone(getattr(C, 'y'))
551 self.assertFalse(the_fields[1].init)
552 self.assertTrue (the_fields[1].repr)
553 self.assertEqual(the_fields[2].name, 'z')
554 self.assertEqual(the_fields[2].type, str)
555 self.assertFalse(hasattr(C, 'z'))
556 self.assertTrue (the_fields[2].init)
557 self.assertFalse(the_fields[2].repr)
558
559 def test_field_order(self):
560 @dataclass
561 class B:
562 a: str = 'B:a'
563 b: str = 'B:b'
564 c: str = 'B:c'
565
566 @dataclass
567 class C(B):
568 b: str = 'C:b'
569
570 self.assertEqual([(f.name, f.default) for f in fields(C)],
571 [('a', 'B:a'),
572 ('b', 'C:b'),
573 ('c', 'B:c')])
574
575 @dataclass
576 class D(B):
577 c: str = 'D:c'
578
579 self.assertEqual([(f.name, f.default) for f in fields(D)],
580 [('a', 'B:a'),
581 ('b', 'B:b'),
582 ('c', 'D:c')])
583
584 @dataclass
585 class E(D):
586 a: str = 'E:a'
587 d: str = 'E:d'
588
589 self.assertEqual([(f.name, f.default) for f in fields(E)],
590 [('a', 'E:a'),
591 ('b', 'B:b'),
592 ('c', 'D:c'),
593 ('d', 'E:d')])
594
595 def test_class_attrs(self):
596 # We only have a class attribute if a default value is
597 # specified, either directly or via a field with a default.
598 default = object()
599 @dataclass
600 class C:
601 x: int
602 y: int = field(repr=False)
603 z: object = default
604 t: int = field(default=100)
605
606 self.assertFalse(hasattr(C, 'x'))
607 self.assertFalse(hasattr(C, 'y'))
608 self.assertIs (C.z, default)
609 self.assertEqual(C.t, 100)
610
611 def test_disallowed_mutable_defaults(self):
612 # For the known types, don't allow mutable default values.
613 for typ, empty, non_empty in [(list, [], [1]),
614 (dict, {}, {0:1}),
615 (set, set(), set([1])),
616 ]:
617 with self.subTest(typ=typ):
618 # Can't use a zero-length value.
619 with self.assertRaisesRegex(ValueError,
620 f'mutable default {typ} for field '
621 'x is not allowed'):
622 @dataclass
623 class Point:
624 x: typ = empty
625
626
627 # Nor a non-zero-length value
628 with self.assertRaisesRegex(ValueError,
629 f'mutable default {typ} for field '
630 'y is not allowed'):
631 @dataclass
632 class Point:
633 y: typ = non_empty
634
635 # Check subtypes also fail.
636 class Subclass(typ): pass
637
638 with self.assertRaisesRegex(ValueError,
639 f"mutable default .*Subclass'>"
640 ' for field z is not allowed'
641 ):
642 @dataclass
643 class Point:
644 z: typ = Subclass()
645
646 # Because this is a ClassVar, it can be mutable.
647 @dataclass
648 class C:
649 z: ClassVar[typ] = typ()
650
651 # Because this is a ClassVar, it can be mutable.
652 @dataclass
653 class C:
654 x: ClassVar[typ] = Subclass()
655
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500656 def test_deliberately_mutable_defaults(self):
657 # If a mutable default isn't in the known list of
658 # (list, dict, set), then it's okay.
659 class Mutable:
660 def __init__(self):
661 self.l = []
662
663 @dataclass
664 class C:
665 x: Mutable
666
667 # These 2 instances will share this value of x.
668 lst = Mutable()
669 o1 = C(lst)
670 o2 = C(lst)
671 self.assertEqual(o1, o2)
672 o1.x.l.extend([1, 2])
673 self.assertEqual(o1, o2)
674 self.assertEqual(o1.x.l, [1, 2])
675 self.assertIs(o1.x, o2.x)
676
677 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400678 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500679 @dataclass()
680 class C:
681 x: int
682
683 self.assertEqual(C(42).x, 42)
684
685 def test_not_tuple(self):
686 # Make sure we can't be compared to a tuple.
687 @dataclass
688 class Point:
689 x: int
690 y: int
691 self.assertNotEqual(Point(1, 2), (1, 2))
692
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400693 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500694 @dataclass
695 class C:
696 x: int
697 y: int
698 self.assertNotEqual(Point(1, 3), C(1, 3))
699
Windson yangbe372d72019-04-23 02:45:34 +0800700 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500701 # Test that some of the problems with namedtuple don't happen
702 # here.
703 @dataclass
704 class Point3D:
705 x: int
706 y: int
707 z: int
708
709 @dataclass
710 class Date:
711 year: int
712 month: int
713 day: int
714
715 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
716 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
717
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400718 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200719 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 x, y, z = Point3D(4, 5, 6)
721
Eric V. Smith7c99e932018-01-28 19:18:55 -0500722 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # equal.
724 @dataclass
725 class Point3Dv1:
726 x: int = 0
727 y: int = 0
728 z: int = 0
729 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
730
731 def test_function_annotations(self):
732 # Some dummy class and instance to use as a default.
733 class F:
734 pass
735 f = F()
736
737 def validate_class(cls):
738 # First, check __annotations__, even though they're not
739 # function annotations.
740 self.assertEqual(cls.__annotations__['i'], int)
741 self.assertEqual(cls.__annotations__['j'], str)
742 self.assertEqual(cls.__annotations__['k'], F)
743 self.assertEqual(cls.__annotations__['l'], float)
744 self.assertEqual(cls.__annotations__['z'], complex)
745
746 # Verify __init__.
747
748 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400749 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500750 self.assertIs(signature.return_annotation, None)
751
752 # Check each parameter.
753 params = iter(signature.parameters.values())
754 param = next(params)
755 # This is testing an internal name, and probably shouldn't be tested.
756 self.assertEqual(param.name, 'self')
757 param = next(params)
758 self.assertEqual(param.name, 'i')
759 self.assertIs (param.annotation, int)
760 self.assertEqual(param.default, inspect.Parameter.empty)
761 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
762 param = next(params)
763 self.assertEqual(param.name, 'j')
764 self.assertIs (param.annotation, str)
765 self.assertEqual(param.default, inspect.Parameter.empty)
766 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
767 param = next(params)
768 self.assertEqual(param.name, 'k')
769 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400770 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
772 param = next(params)
773 self.assertEqual(param.name, 'l')
774 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400775 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500776 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
777 self.assertRaises(StopIteration, next, params)
778
779
780 @dataclass
781 class C:
782 i: int
783 j: str
784 k: F = f
785 l: float=field(default=None)
786 z: complex=field(default=3+4j, init=False)
787
788 validate_class(C)
789
790 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500791 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 class C:
793 i: int
794 j: str
795 k: F = f
796 l: float=field(default=None)
797 z: complex=field(default=3+4j, init=False)
798
799 validate_class(C)
800
Eric V. Smith03220fd2017-12-29 13:59:58 -0500801 def test_missing_default(self):
802 # Test that MISSING works the same as a default not being
803 # specified.
804 @dataclass
805 class C:
806 x: int=field(default=MISSING)
807 with self.assertRaisesRegex(TypeError,
808 r'__init__\(\) missing 1 required '
809 'positional argument'):
810 C()
811 self.assertNotIn('x', C.__dict__)
812
813 @dataclass
814 class D:
815 x: int
816 with self.assertRaisesRegex(TypeError,
817 r'__init__\(\) missing 1 required '
818 'positional argument'):
819 D()
820 self.assertNotIn('x', D.__dict__)
821
822 def test_missing_default_factory(self):
823 # Test that MISSING works the same as a default factory not
824 # being specified (which is really the same as a default not
825 # being specified, too).
826 @dataclass
827 class C:
828 x: int=field(default_factory=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int=field(default=MISSING, default_factory=MISSING)
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_repr(self):
845 self.assertIn('MISSING_TYPE object', repr(MISSING))
846
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 def test_dont_include_other_annotations(self):
848 @dataclass
849 class C:
850 i: int
851 def foo(self) -> int:
852 return 4
853 @property
854 def bar(self) -> int:
855 return 5
856 self.assertEqual(list(C.__annotations__), ['i'])
857 self.assertEqual(C(10).foo(), 4)
858 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400859 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500860
861 def test_post_init(self):
862 # Just make sure it gets called
863 @dataclass
864 class C:
865 def __post_init__(self):
866 raise CustomError()
867 with self.assertRaises(CustomError):
868 C()
869
870 @dataclass
871 class C:
872 i: int = 10
873 def __post_init__(self):
874 if self.i == 10:
875 raise CustomError()
876 with self.assertRaises(CustomError):
877 C()
878 # post-init gets called, but doesn't raise. This is just
879 # checking that self is used correctly.
880 C(5)
881
882 # If there's not an __init__, then post-init won't get called.
883 @dataclass(init=False)
884 class C:
885 def __post_init__(self):
886 raise CustomError()
887 # Creating the class won't raise
888 C()
889
890 @dataclass
891 class C:
892 x: int = 0
893 def __post_init__(self):
894 self.x *= 2
895 self.assertEqual(C().x, 0)
896 self.assertEqual(C(2).x, 4)
897
Mike53f7a7c2017-12-14 14:04:53 +0300898 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500899 # attributes.
900 @dataclass(frozen=True)
901 class C:
902 x: int = 0
903 def __post_init__(self):
904 self.x *= 2
905 with self.assertRaises(FrozenInstanceError):
906 C()
907
908 def test_post_init_super(self):
909 # Make sure super() post-init isn't called by default.
910 class B:
911 def __post_init__(self):
912 raise CustomError()
913
914 @dataclass
915 class C(B):
916 def __post_init__(self):
917 self.x = 5
918
919 self.assertEqual(C().x, 5)
920
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400921 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500922 @dataclass
923 class C(B):
924 def __post_init__(self):
925 super().__post_init__()
926
927 with self.assertRaises(CustomError):
928 C()
929
930 # Make sure post-init is called, even if not defined in our
931 # class.
932 @dataclass
933 class C(B):
934 pass
935
936 with self.assertRaises(CustomError):
937 C()
938
939 def test_post_init_staticmethod(self):
940 flag = False
941 @dataclass
942 class C:
943 x: int
944 y: int
945 @staticmethod
946 def __post_init__():
947 nonlocal flag
948 flag = True
949
950 self.assertFalse(flag)
951 c = C(3, 4)
952 self.assertEqual((c.x, c.y), (3, 4))
953 self.assertTrue(flag)
954
955 def test_post_init_classmethod(self):
956 @dataclass
957 class C:
958 flag = False
959 x: int
960 y: int
961 @classmethod
962 def __post_init__(cls):
963 cls.flag = True
964
965 self.assertFalse(C.flag)
966 c = C(3, 4)
967 self.assertEqual((c.x, c.y), (3, 4))
968 self.assertTrue(C.flag)
969
970 def test_class_var(self):
971 # Make sure ClassVars are ignored in __init__, __repr__, etc.
972 @dataclass
973 class C:
974 x: int
975 y: int = 10
976 z: ClassVar[int] = 1000
977 w: ClassVar[int] = 2000
978 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400979 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500980
981 c = C(5)
982 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400983 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400984 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500985 self.assertEqual(c.z, 1000)
986 self.assertEqual(c.w, 2000)
987 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400988 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500989 C.z += 1
990 self.assertEqual(c.z, 1001)
991 c = C(20)
992 self.assertEqual((c.x, c.y), (20, 10))
993 self.assertEqual(c.z, 1001)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400996 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500997
998 def test_class_var_no_default(self):
999 # If a ClassVar has no default value, it should not be set on the class.
1000 @dataclass
1001 class C:
1002 x: ClassVar[int]
1003
1004 self.assertNotIn('x', C.__dict__)
1005
1006 def test_class_var_default_factory(self):
1007 # It makes no sense for a ClassVar to have a default factory. When
1008 # would it be called? Call it yourself, since it's class-wide.
1009 with self.assertRaisesRegex(TypeError,
1010 'cannot have a default factory'):
1011 @dataclass
1012 class C:
1013 x: ClassVar[int] = field(default_factory=int)
1014
1015 self.assertNotIn('x', C.__dict__)
1016
1017 def test_class_var_with_default(self):
1018 # If a ClassVar has a default value, it should be set on the class.
1019 @dataclass
1020 class C:
1021 x: ClassVar[int] = 10
1022 self.assertEqual(C.x, 10)
1023
1024 @dataclass
1025 class C:
1026 x: ClassVar[int] = field(default=10)
1027 self.assertEqual(C.x, 10)
1028
1029 def test_class_var_frozen(self):
1030 # Make sure ClassVars work even if we're frozen.
1031 @dataclass(frozen=True)
1032 class C:
1033 x: int
1034 y: int = 10
1035 z: ClassVar[int] = 1000
1036 w: ClassVar[int] = 2000
1037 t: ClassVar[int] = 3000
1038
1039 c = C(5)
1040 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1041 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1042 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1043 self.assertEqual(c.z, 1000)
1044 self.assertEqual(c.w, 2000)
1045 self.assertEqual(c.t, 3000)
1046 # We can still modify the ClassVar, it's only instances that are
1047 # frozen.
1048 C.z += 1
1049 self.assertEqual(c.z, 1001)
1050 c = C(20)
1051 self.assertEqual((c.x, c.y), (20, 10))
1052 self.assertEqual(c.z, 1001)
1053 self.assertEqual(c.w, 2000)
1054 self.assertEqual(c.t, 3000)
1055
def test_init_var_no_default(self):
    # An InitVar annotation with no value must not create a class attribute.
    @dataclass
    class C:
        x: InitVar[int]

    namespace = C.__dict__
    self.assertNotIn('x', namespace)
1063
def test_init_var_default_factory(self):
    # dataclasses rejects a default factory on an InitVar with a TypeError.
    # Give the InitVar a plain default, or compute the value inside
    # __post_init__, instead.
    with self.assertRaisesRegex(TypeError,
                                'cannot have a default factory'):
        @dataclass
        class C:
            x: InitVar[int] = field(default_factory=int)
    # The decorator raises, so C is never bound; no further checks on it
    # are possible here.
1074
def test_init_var_with_default(self):
    # An InitVar default, given directly or through field(default=...),
    # is stored as a class attribute.
    @dataclass
    class Direct:
        x: InitVar[int] = 10
    self.assertEqual(Direct.x, 10)

    @dataclass
    class ViaField:
        x: InitVar[int] = field(default=10)
    self.assertEqual(ViaField.x, 10)
1086
def test_init_var(self):
    # InitVar values are handed to __post_init__, which can use them to
    # derive real field values.
    @dataclass
    class C:
        x: int = None
        init_param: InitVar[int] = None

        def __post_init__(self, init_param):
            # Only derive x when the caller did not supply it.
            if self.x is None:
                self.x = init_param * 2

    self.assertEqual(C(init_param=10).x, 20)
1099
def test_init_var_inheritance(self):
    # A dataclass may declare an InitVar without defining __post_init__;
    # a subclass can consume it instead. This is deliberate.
    @dataclass
    class Base:
        x: int
        init_base: InitVar[int]

    # __init__ still accepts the unused InitVar, and it is not stored on
    # the instance.
    base = Base(0, 10)
    self.assertEqual(vars(base), {'x': 0})

    @dataclass
    class C(Base):
        y: int
        init_derived: InitVar[int]

        def __post_init__(self, init_base, init_derived):
            self.x += init_base
            self.y += init_derived

    # Arguments follow declaration order: x, init_base, y, init_derived.
    derived = C(10, 11, 50, 51)
    self.assertEqual(vars(derived), {'x': 21, 'y': 101})
1125
def test_default_factory(self):
    # default_factory supplies a field's value whenever no value is passed
    # to __init__; also checks its interplay with repr/hash/init/compare.

    # Test a factory that returns a new list: each instance gets its own.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    c0 = C(3)
    c1 = C(3)
    self.assertEqual(c0.x, 3)
    self.assertEqual(c0.y, [])
    self.assertEqual(c0, c1)
    self.assertIsNot(c0.y, c1.y)  # distinct list objects per instance
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test a factory that returns a shared list: every instance receives
    # the very same object `l`.
    l = []
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: l)

    c0 = C(3)
    c1 = C(3)
    self.assertEqual(c0.x, 3)
    self.assertEqual(c0.y, [])
    self.assertEqual(c0, c1)
    self.assertIs(c0.y, c1.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test various other field flags.
    # repr: a repr=False field is left out of the generated __repr__.
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash: with hash=False the field does not contribute, so the test
    # expects the same hash as an empty tuple.
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init (see also test_default_factory_with_no_init): the factory still
    # populates a field that __init__ does not accept.
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare: a compare=False field is ignored by __eq__.
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1182
def test_default_factory_with_no_init(self):
    # The factory of an init=False field must run once per new instance;
    # a Mock is used so the calls can be counted.
    counting_factory = Mock()

    @dataclass
    class C:
        x: list = field(default_factory=counting_factory, init=False)

    for expected_calls in (1, 2):
        C().x
        self.assertEqual(counting_factory.call_count, expected_calls)
1196
def test_default_factory_not_called_if_value_given(self):
    # When a value is passed explicitly for a default_factory field, the
    # factory must not be invoked; a Mock records each call.
    spy = Mock()

    @dataclass
    class C:
        x: int = field(default_factory=spy)

    C().x
    self.assertEqual(spy.call_count, 1)
    self.assertEqual(C(10).x, 10)
    self.assertEqual(spy.call_count, 1)  # unchanged: 10 was supplied
    C().x
    self.assertEqual(spy.call_count, 2)
1213
def test_default_factory_derived(self):
    # Regression test for bpo-32896: a default_factory declared on a base
    # dataclass must still apply in derived dataclasses.
    @dataclass
    class Foo:
        x: dict = field(default_factory=dict)

    @dataclass
    class Bar(Foo):
        y: int = 1

    for cls in (Foo, Bar):
        self.assertEqual(cls().x, {})
    self.assertEqual(Bar().y, 1)

    # A subclass that adds no fields of its own also inherits the factory.
    @dataclass
    class Baz(Foo):
        pass
    self.assertEqual(Baz().x, {})
1232
def test_intermediate_non_dataclass(self):
    # Annotations on a plain (non-dataclass) class in the hierarchy do not
    # become fields of dataclasses derived from it.
    @dataclass
    class A:
        x: int

    class B(A):
        y: int

    @dataclass
    class C(B):
        z: int

    # Only x and z are fields, so __init__ takes exactly two arguments.
    obj = C(1, 3)
    self.assertEqual((obj.x, obj.z), (1, 3))

    # B's y is never initialized.
    with self.assertRaisesRegex(AttributeError,
                                'object has no attribute'):
        obj.y

    # A further plain subclass likewise adds no fields.
    class D(C):
        t: int
    leaf = D(4, 5)
    self.assertEqual((leaf.x, leaf.z), (4, 5))
1261
def test_classvar_default_factory(self):
    # Giving a ClassVar a default_factory is rejected with TypeError.
    expected_message = 'cannot have a default factory'
    with self.assertRaisesRegex(TypeError, expected_message):
        @dataclass
        class C:
            x: ClassVar[int] = field(default_factory=int)
1268 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001269
def test_is_dataclass(self):
    # is_dataclass() is true for dataclasses and their instances only.
    class NotDataClass:
        pass

    for candidate in (0, int, NotDataClass, NotDataClass()):
        self.assertFalse(is_dataclass(candidate))

    @dataclass
    class C:
        x: int

    @dataclass
    class D:
        d: C
        e: int

    inner = C(10)
    outer = D(inner, 4)

    self.assertTrue(is_dataclass(C))
    self.assertTrue(is_dataclass(inner))
    self.assertFalse(is_dataclass(inner.x))
    self.assertTrue(is_dataclass(outer.d))
    self.assertFalse(is_dataclass(outer.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001296
def test_helper_fields_with_class_instance(self):
    # fields() gives the same answer for a dataclass and for an instance of it.
    @dataclass
    class C:
        x: int
        y: float

    from_class = fields(C)
    from_instance = fields(C(0, 0.0))
    self.assertEqual(from_class, from_instance)
1306
def test_helper_fields_exception(self):
    # fields() raises TypeError for anything that is neither a dataclass
    # nor a dataclass instance.
    message = 'dataclass type or instance'
    with self.assertRaisesRegex(TypeError, message):
        fields(0)

    class C: pass
    for not_a_dataclass in (C, C()):
        with self.assertRaisesRegex(TypeError, message):
            fields(not_a_dataclass)
1317 fields(C())
1318
def test_helper_asdict(self):
    # asdict() builds a brand-new plain dict from the current field values.
    @dataclass
    class C:
        x: int
        y: int
    obj = C(1, 2)

    self.assertEqual(asdict(obj), {'x': 1, 'y': 2})
    # Equal contents, but a distinct dict on every call.
    self.assertEqual(asdict(obj), asdict(obj))
    self.assertIsNot(asdict(obj), asdict(obj))
    # Later mutations of the instance are reflected.
    obj.x = 42
    self.assertEqual(asdict(obj), {'x': 42, 'y': 2})
    self.assertIs(type(asdict(obj)), dict)
1333
def test_helper_asdict_raises_on_classes(self):
    # asdict() accepts only dataclass instances, never class objects.
    @dataclass
    class C:
        x: int
        y: int
    for not_an_instance in (C, int):
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            asdict(not_an_instance)
1343 asdict(int)
1344
def test_helper_asdict_copy_values(self):
    # Container values are copied into the result, not shared with the instance.
    @dataclass
    class C:
        x: int
        y: List[int] = field(default_factory=list)

    source_list = []
    as_dict = asdict(C(1, source_list))
    self.assertEqual(as_dict['y'], source_list)
    self.assertIsNot(as_dict['y'], source_list)

    # Mutating the returned copy must leave the instance untouched.
    obj = C(1)
    as_dict = asdict(obj)
    as_dict['y'].append(1)
    self.assertEqual(obj.y, [])
1359
def test_helper_asdict_nested(self):
    # Nested dataclass instances are converted to nested dicts.
    @dataclass
    class UserId:
        token: int
        group: int
    @dataclass
    class User:
        name: str
        id: UserId
    user = User('Joe', UserId(123, 1))
    self.assertEqual(asdict(user),
                     {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
    self.assertIsNot(asdict(user), asdict(user))
    # Changes to the inner instance appear in a fresh conversion.
    user.id.group = 2
    self.assertEqual(asdict(user),
                     {'name': 'Joe', 'id': {'token': 123, 'group': 2}})
1376
def test_helper_asdict_builtin_containers(self):
    # Dataclasses inside lists, tuples and dict values are converted, and
    # the container type is preserved.
    @dataclass
    class User:
        name: str
        id: int
    @dataclass
    class GroupList:
        id: int
        users: List[User]
    @dataclass
    class GroupTuple:
        id: int
        users: Tuple[User, ...]
    @dataclass
    class GroupDict:
        id: int
        users: Dict[str, User]
    alice, bob = User('Alice', 1), User('Bob', 2)
    alice_dict = {'name': 'Alice', 'id': 1}
    bob_dict = {'name': 'Bob', 'id': 2}
    self.assertEqual(asdict(GroupList(0, [alice, bob])),
                     {'id': 0, 'users': [alice_dict, bob_dict]})
    self.assertEqual(asdict(GroupTuple(0, (alice, bob))),
                     {'id': 0, 'users': (alice_dict, bob_dict)})
    self.assertEqual(asdict(GroupDict(0, {'first': alice, 'second': bob})),
                     {'id': 0, 'users': {'first': alice_dict, 'second': bob_dict}})
1405
def test_helper_asdict_builtin_object_containers(self):
    # Containers stored in an `object`-typed field come through asdict()
    # with equal contents.
    @dataclass
    class Child:
        d: object

    @dataclass
    class Parent:
        child: Child

    for payload in ([1], {1: 2}):
        self.assertEqual(asdict(Parent(Child(payload))),
                         {'child': {'d': payload}})
1417
def test_helper_asdict_factory(self):
    # dict_factory selects the mapping type that asdict() produces.
    @dataclass
    class C:
        x: int
        y: int
    obj = C(1, 2)
    result = asdict(obj, dict_factory=OrderedDict)
    self.assertEqual(result, OrderedDict([('x', 1), ('y', 2)]))
    self.assertIsNot(result, asdict(obj, dict_factory=OrderedDict))
    obj.x = 42
    result = asdict(obj, dict_factory=OrderedDict)
    self.assertEqual(result, OrderedDict([('x', 42), ('y', 2)]))
    self.assertIs(type(result), OrderedDict)
1431
def test_helper_asdict_namedtuple(self):
    # Namedtuples are rebuilt with their own type, while dataclasses found
    # inside them are turned into dicts.
    T = namedtuple('T', 'a b c')
    @dataclass
    class C:
        x: str
        y: T
    c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))

    expected = {'x': 'outer',
                'y': T(1, {'x': 'inner', 'y': T(11, 12, 13)}, 2)}
    self.assertEqual(asdict(c), expected)

    # With dict_factory=OrderedDict the result still compares equal to
    # plain dicts, so check the concrete types separately.
    ordered = asdict(c, dict_factory=OrderedDict)
    self.assertEqual(ordered, expected)
    self.assertIs(type(ordered), OrderedDict)
    self.assertIs(type(ordered['y'][1]), OrderedDict)
1464
def test_helper_asdict_namedtuple_key(self):
    # A dict field keyed by namedtuples must survive asdict() intact.
    @dataclass
    class C:
        f: dict
    T = namedtuple('T', 'a')

    self.assertEqual(asdict(C({T('an a'): 0})), {'f': {T(a='an a'): 0}})
1477
def test_helper_asdict_namedtuple_derived(self):
    # A namedtuple subclass is rebuilt as that subclass, so its extra
    # methods remain usable on the copy.
    class T(namedtuple('Tbase', 'a')):
        def my_a(self):
            return self.a

    @dataclass
    class C:
        f: T

    original = T(6)
    converted = asdict(C(original))
    self.assertEqual(converted, {'f': T(a=6)})
    # The value was copied rather than reused.
    self.assertIsNot(converted['f'], original)
    self.assertEqual(converted['f'].my_a(), 6)
1495
def test_helper_astuple(self):
    # astuple() returns a new tuple holding the current field values.
    @dataclass
    class C:
        x: int
        y: int = 0
    obj = C(1)

    self.assertEqual(astuple(obj), (1, 0))
    self.assertEqual(astuple(obj), astuple(obj))
    self.assertIsNot(astuple(obj), astuple(obj))
    obj.y = 42
    self.assertEqual(astuple(obj), (1, 42))
    self.assertIs(type(astuple(obj)), tuple)
1510
def test_helper_astuple_raises_on_classes(self):
    # astuple() accepts only dataclass instances, never class objects.
    @dataclass
    class C:
        x: int
        y: int
    for not_an_instance in (C, int):
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            astuple(not_an_instance)
1521
def test_helper_astuple_copy_values(self):
    # Container values are copied into the tuple, not shared with the instance.
    @dataclass
    class C:
        x: int
        y: List[int] = field(default_factory=list)

    source_list = []
    as_tuple = astuple(C(1, source_list))
    self.assertEqual(as_tuple[1], source_list)
    self.assertIsNot(as_tuple[1], source_list)

    # Mutating the copied list must leave the instance untouched.
    obj = C(1)
    as_tuple = astuple(obj)
    as_tuple[1].append(1)
    self.assertEqual(obj.y, [])
1536
def test_helper_astuple_nested(self):
    # Nested dataclass instances become nested tuples.
    @dataclass
    class UserId:
        token: int
        group: int
    @dataclass
    class User:
        name: str
        id: UserId
    user = User('Joe', UserId(123, 1))
    self.assertEqual(astuple(user), ('Joe', (123, 1)))
    self.assertIsNot(astuple(user), astuple(user))
    user.id.group = 2
    self.assertEqual(astuple(user), ('Joe', (123, 2)))
1552
def test_helper_astuple_builtin_containers(self):
    # Dataclasses inside lists, tuples and dict values become tuples, and
    # the surrounding container type is preserved.
    @dataclass
    class User:
        name: str
        id: int
    @dataclass
    class GroupList:
        id: int
        users: List[User]
    @dataclass
    class GroupTuple:
        id: int
        users: Tuple[User, ...]
    @dataclass
    class GroupDict:
        id: int
        users: Dict[str, User]
    alice, bob = User('Alice', 1), User('Bob', 2)
    self.assertEqual(astuple(GroupList(0, [alice, bob])),
                     (0, [('Alice', 1), ('Bob', 2)]))
    self.assertEqual(astuple(GroupTuple(0, (alice, bob))),
                     (0, (('Alice', 1), ('Bob', 2))))
    self.assertEqual(astuple(GroupDict(0, {'first': alice, 'second': bob})),
                     (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1578
def test_helper_astuple_builtin_object_containers(self):
    # Containers stored in an `object`-typed field come through astuple()
    # with equal contents.
    @dataclass
    class Child:
        d: object

    @dataclass
    class Parent:
        child: Child

    for payload in ([1], {1: 2}):
        self.assertEqual(astuple(Parent(Child(payload))), ((payload,),))
1590
def test_helper_astuple_factory(self):
    # tuple_factory is called with the field values; here they are
    # unpacked into a namedtuple.
    @dataclass
    class C:
        x: int
        y: int
    NT = namedtuple('NT', 'x y')

    def nt(values):
        return NT(*values)

    obj = C(1, 2)
    result = astuple(obj, tuple_factory=nt)
    self.assertEqual(result, NT(1, 2))
    self.assertIsNot(result, astuple(obj, tuple_factory=nt))
    obj.x = 42
    result = astuple(obj, tuple_factory=nt)
    self.assertEqual(result, NT(42, 2))
    self.assertIs(type(result), NT)
1607
def test_helper_astuple_namedtuple(self):
    # Dataclasses nested inside namedtuples are converted as well.
    T = namedtuple('T', 'a b c')
    @dataclass
    class C:
        x: str
        y: T
    c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))

    self.assertEqual(astuple(c), ('outer', T(1, ('inner', (11, 12, 13)), 2)))

    # With tuple_factory=list, each dataclass level is produced as a list.
    self.assertEqual(astuple(c, tuple_factory=list),
                     ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1622
def test_dynamic_class_creation(self):
    # dataclass() can be applied to a class built at runtime with type().
    namespace = {'__annotations__': {'x': int, 'y': int}}
    klass = type('C', (), namespace)

    # The decorator returns the very class it was given.
    decorated = dataclass(klass)
    self.assertEqual(decorated, klass)
    self.assertEqual(asdict(klass(1, 2)), {'x': 1, 'y': 2})
1635
def test_dynamic_class_creation_using_field(self):
    # A field() object placed directly in a type()-built namespace supplies
    # that field's default.
    namespace = {'__annotations__': {'x': int, 'y': int},
                 'y': field(default=5)}
    klass = type('C', (), namespace)

    decorated = dataclass(klass)
    self.assertEqual(decorated, klass)
    self.assertEqual(asdict(decorated(1)), {'x': 1, 'y': 5})
1649
def test_init_in_order(self):
    # The generated __init__ assigns fields in declaration order, and never
    # assigns a field that has init=False and a plain default.
    @dataclass
    class C:
        a: int
        b: int = field()
        c: list = field(default_factory=list, init=False)
        d: list = field(default_factory=list)
        e: int = field(default=4, init=False)
        f: int = 4

    calls = []

    def record_setattr(obj, name, value):
        calls.append((name, value))

    # Intercept every attribute assignment made by __init__.
    C.__setattr__ = record_setattr
    C(0, 1)
    leading = [('a', 0), ('b', 1), ('c', []), ('d', [])]
    for position, expected in enumerate(leading):
        self.assertEqual(expected, calls[position])
    self.assertNotIn(('e', 4), calls)
    self.assertEqual(('f', 4), calls[4])
1672
def test_items_in_dicts(self):
    # Only fields with plain defaults appear in the class __dict__; the
    # instance __dict__ holds everything __init__ assigned.
    @dataclass
    class C:
        a: int
        b: list = field(default_factory=list, init=False)
        c: list = field(default_factory=list)
        d: int = field(default=4, init=False)
        e: int = 0

    obj = C(0)
    # Class dict
    for name in ('a', 'b', 'c'):
        self.assertNotIn(name, C.__dict__)
    self.assertIn('d', C.__dict__)
    self.assertEqual(C.d, 4)
    self.assertIn('e', C.__dict__)
    self.assertEqual(C.e, 0)
    # Instance dict
    for name, value in (('a', 0), ('b', []), ('c', [])):
        self.assertIn(name, obj.__dict__)
        self.assertEqual(getattr(obj, name), value)
    self.assertNotIn('d', obj.__dict__)
    self.assertIn('e', obj.__dict__)
    self.assertEqual(obj.e, 0)
1701
def test_alternate_classmethod_constructor(self):
    # __post_init__ only receives the InitVar pseudo-fields, so when an
    # instance has to be built from inputs that are not fields at all
    # (here, a filename), a classmethod alternate constructor is the usual
    # approach. This is mostly an example to show how to use this
    # technique.
    @dataclass
    class C:
        x: int
        @classmethod
        def from_file(cls, filename):
            # In a real example, create a new instance
            # and populate 'x' from contents of a file.
            value_in_file = 20
            return cls(value_in_file)

    self.assertEqual(C.from_file('filename').x, 20)
1717
def test_field_metadata_default(self):
    # Without explicit metadata, a field's metadata is empty and read-only.
    @dataclass
    class C:
        i: int

    metadata = fields(C)[0].metadata
    self.assertFalse(metadata)
    self.assertEqual(len(metadata), 0)
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        metadata['test'] = 3
1730
def test_field_metadata_mapping(self):
    # Make sure only a mapping can be passed as metadata.
    with self.assertRaises(TypeError):
        @dataclass
        class C:
            i: int = field(metadata=0)

    # Make sure an empty dict works.
    d = {}
    @dataclass
    class C:
        i: int = field(metadata=d)
    self.assertFalse(fields(C)[0].metadata)
    self.assertEqual(len(fields(C)[0].metadata), 0)
    # Update should work (see bpo-35960): metadata is a read-only view of
    # the caller's dict rather than a copy, so later changes to `d` show
    # through while direct item assignment is still refused.
    d['foo'] = 1
    self.assertEqual(len(fields(C)[0].metadata), 1)
    self.assertEqual(fields(C)[0].metadata['foo'], 1)
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3

    # Make sure a non-empty dict works.
    d = {'test': 10, 'bar': '42', 3: 'three'}
    @dataclass
    class C:
        i: int = field(metadata=d)
    self.assertEqual(len(fields(C)[0].metadata), 3)
    self.assertEqual(fields(C)[0].metadata['test'], 10)
    self.assertEqual(fields(C)[0].metadata['bar'], '42')
    self.assertEqual(fields(C)[0].metadata[3], 'three')
    # Update should work.
    d['foo'] = 1
    self.assertEqual(len(fields(C)[0].metadata), 4)
    self.assertEqual(fields(C)[0].metadata['foo'], 1)
    with self.assertRaises(KeyError):
        # Non-existent key.
        fields(C)[0].metadata['baz']
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3
1773
def test_field_metadata_custom_mapping(self):
    # A user-defined mapping-like object can be used as metadata, and
    # lookups are forwarded to it.
    class SimpleNameSpace:
        def __init__(self, **kw):
            vars(self).update(kw)

        def __getitem__(self, item):
            return 'plugh' if item == 'xyzzy' else getattr(self, item)

        def __len__(self):
            return len(vars(self))

    @dataclass
    class C:
        i: int = field(metadata=SimpleNameSpace(a=10))

    metadata = fields(C)[0].metadata
    self.assertEqual(len(metadata), 1)
    self.assertEqual(metadata['a'], 10)
    # A missing name surfaces as the AttributeError raised by getattr().
    with self.assertRaises(AttributeError):
        metadata['b']
    # Make sure we're still talking to our custom mapping.
    self.assertEqual(metadata['xyzzy'], 'plugh')
1798
def test_generic_dataclasses(self):
    # A dataclass can also be a typing.Generic.
    T = TypeVar('T')

    @dataclass
    class LabeledBox(Generic[T]):
        content: T
        label: str = '<unknown>'

    box = LabeledBox(42)
    self.assertEqual(box.content, 42)
    self.assertEqual(box.label, '<unknown>')

    # Subscripting the resulting class, e.g. inside another generic, must work.
    _alias = List[LabeledBox[int]]
1813
def test_generic_extending(self):
    # Subclasses of a generic dataclass, with or without @dataclass, can be
    # parameterized and instantiated through the resulting alias.
    S = TypeVar('S')
    T = TypeVar('T')

    @dataclass
    class Base(Generic[T, S]):
        x: T
        y: S

    @dataclass
    class DataDerived(Base[int, T]):
        new_field: str
    data_alias = DataDerived[str]
    instance = data_alias(0, 'test1', 'test2')
    self.assertEqual(astuple(instance), (0, 'test1', 'test2'))

    class NonDataDerived(Base[int, T]):
        def new_method(self):
            return self.y
    plain_alias = NonDataDerived[float]
    instance = plain_alias(10, 1.0)
    self.assertEqual(instance.new_method(), 1.0)
1836
def test_generic_dynamic(self):
    # make_dataclass() works with generic bases and an extra namespace.
    T = TypeVar('T')

    @dataclass
    class Parent(Generic[T]):
        x: T

    child_fields = [('y', T), ('z', Optional[T], None)]
    Child = make_dataclass('Child', child_fields,
                           bases=(Parent[int], Generic[T]),
                           namespace={'other': 42})
    self.assertIs(Child[int](1, 2).z, None)
    self.assertEqual(Child[int](1, 2, 3).z, 3)
    self.assertEqual(Child[int](1, 2, 3).other, 42)
    # Check that type aliases work correctly.
    generic_alias = Child[T]
    self.assertEqual(generic_alias[int](1, 2).x, 1)
    # Check MRO resolution.
    self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1853
def test_dataclassses_pickleable(self):
    # pickle must be able to find the classes by name at module level, so
    # the dataclasses under test are published as globals.
    global P, Q, R
    @dataclass
    class P:
        x: int
        y: int = 0
    @dataclass
    class Q:
        x: int
        y: int = field(default=0, init=False)
    @dataclass
    class R:
        x: int
        y: List[int] = field(default_factory=list)

    # Q.y is init=False, so set it after construction to get a non-default value.
    q_with_y = Q(1)
    q_with_y.y = 2
    samples = [P(1), P(1, 2), Q(1), q_with_y, R(1), R(1, [2, 3, 4])]
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for sample in samples:
        for proto in protocols:
            with self.subTest(sample=sample, proto=proto):
                restored = pickle.loads(pickle.dumps(sample, proto))
                self.assertEqual(sample.x, restored.x)
                self.assertEqual(sample.y, restored.y)
                self.assertIsNot(sample, restored)
                # Round-trip a modified copy as well.
                restored.x = 42
                again = pickle.loads(pickle.dumps(restored, proto))
                self.assertEqual(restored.x, again.x)
                self.assertEqual(sample.y, again.y)
1882
Eric V. Smithea8fc522018-01-27 19:07:40 -05001883
class TestFieldNoAnnotation(unittest.TestCase):
    # A name assigned field() in a dataclass body must also be annotated in
    # that same body; an annotation inherited from a base does not count.

    def _check_rejected(self, *bases):
        # Defining a dataclass (with the given bases) whose body assigns
        # field() to the unannotated name 'f' must raise TypeError.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(*bases):
                f = field()

    def test_field_without_annotation(self):
        self._check_rejected()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: B's annotation must not be picked up for C.
        self._check_rejected(B)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same check, with a base that is a plain annotated class.
        class B:
            f: int

        self._check_rejected(B)
1917
1918
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001919class TestDocString(unittest.TestCase):
1920 def assertDocStrEqual(self, a, b):
1921 # Because 3.6 and 3.7 differ in how inspect.signature work
1922 # (see bpo #32108), for the time being just compare them with
1923 # whitespace stripped.
1924 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1925
1926 def test_existing_docstring_not_overridden(self):
1927 @dataclass
1928 class C:
1929 """Lorem ipsum"""
1930 x: int
1931
1932 self.assertEqual(C.__doc__, "Lorem ipsum")
1933
1934 def test_docstring_no_fields(self):
1935 @dataclass
1936 class C:
1937 pass
1938
1939 self.assertDocStrEqual(C.__doc__, "C()")
1940
1941 def test_docstring_one_field(self):
1942 @dataclass
1943 class C:
1944 x: int
1945
1946 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1947
1948 def test_docstring_two_fields(self):
1949 @dataclass
1950 class C:
1951 x: int
1952 y: int
1953
1954 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1955
1956 def test_docstring_three_fields(self):
1957 @dataclass
1958 class C:
1959 x: int
1960 y: int
1961 z: str
1962
1963 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1964
1965 def test_docstring_one_field_with_default(self):
1966 @dataclass
1967 class C:
1968 x: int = 3
1969
1970 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1971
1972 def test_docstring_one_field_with_default_none(self):
1973 @dataclass
1974 class C:
1975 x: Union[int, type(None)] = None
1976
1977 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1978
1979 def test_docstring_list_field(self):
1980 @dataclass
1981 class C:
1982 x: List[int]
1983
1984 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1985
1986 def test_docstring_list_field_with_default_factory(self):
1987 @dataclass
1988 class C:
1989 x: List[int] = field(default_factory=list)
1990
1991 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1992
1993 def test_docstring_deque_field(self):
1994 @dataclass
1995 class C:
1996 x: deque
1997
1998 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1999
2000 def test_docstring_deque_field_with_default_factory(self):
2001 @dataclass
2002 class C:
2003 x: deque = field(default_factory=deque)
2004
2005 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2006
2007
class TestInit(unittest.TestCase):
    # Tests for how the init= parameter interacts with an __init__ defined
    # in the class body or inherited from a base class.

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        # The issue is that we can't override __init__ in our class,
        # but it should be okay to add __init__ to us if our base has
        # an __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call B.__init__, so 'z' is unset.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # With init=False no __init__ is generated: a class without one
        # falls back to object.__init__, and the value read comes from the
        # class-level default. (The decorator must actually be applied with
        # '@'; calling dataclass(init=False) as a bare statement tests
        # nothing.)
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # A user-defined __init__ is used as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2072
2073
class TestRepr(unittest.TestCase):
    # Tests for the generated __repr__ and for how repr= interacts with a
    # __repr__ written in the class body.

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first, followed by the subclass's own.
        self.assertEqual(repr(C(4)), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        @dataclass
        class D(C):
            x: int = 20
        # Redeclaring x keeps its position but picks up the new default.
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        # Nested classes appear under their full qualified name.
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ of our own: object's default repr applies.
        @dataclass(repr=False)
        class C:
            x: int
        default_prefix = f'{__name__}.TestRepr.test_no_repr.<locals>.C object at'
        self.assertIn(default_prefix, repr(C(3)))

        # repr=False with our own __repr__: ours is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ written in the class body always wins, whatever repr=
        # is set to.
        for decorator in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorator
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2143
2144
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: comparison falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False and a user __eq__: the user's method is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class that defines its own __eq__ keeps it, whatever eq= is.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorate, target in variants:
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    # Uses this iteration's target; C is checked right away.
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2190
2191
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to derive the remaining
        # comparisons from a hand-written __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Intentionally "backwards", so the assertions below can only
                # pass if this method is really the one being consulted.
                return self.x >= other

        zero = C(0)
        self.assertLess(zero, -1)
        self.assertLessEqual(zero, -1)
        self.assertGreater(zero, 1)
        self.assertGreaterEqual(zero, 1)

    def test_no_order(self):
        # order=False adds none of the rich comparison methods.
        @dataclass(order=False)
        class C:
            x: int
        for dunder in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(dunder, C.__dict__)

        # A user-defined __lt__ does not cause the others to be added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for dunder in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(dunder, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace a user-defined comparison method and
        # points the user at functools.total_ordering instead.
        def refuses(dunder):
            return self.assertRaisesRegex(TypeError,
                                          f'Cannot overwrite attribute {dunder}'
                                          '.*using functools.total_ordering')

        with refuses('__lt__'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with refuses('__le__'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with refuses('__gt__'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with refuses('__ge__'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2267
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        # unsafe_hash=True hashes an instance like the tuple of its fields.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # A truthy/falsy stand-in that is not a bool, to make sure the
            # decorator's data-driven table only relies on truthiness.
            if value is None:
                return None
            return (3,) if value else 0

        def build(unsafe_hash, eq, frozen, with_hash):
            # Create the class under test, optionally with its own __hash__.
            decorate = dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
            if with_hash:
                @decorate
                class C:
                    def __hash__(self):
                        return 0
            else:
                @decorate
                class C:
                    pass
            return C

        def check(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result == 'exception':
                    # Only a user-defined __hash__ can make creation fail.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        build(unsafe_hash, eq, frozen, True)
                    return

                C = build(unsafe_hash, eq, frozen, with_hash)
                if result == 'fn':
                    # A generated __hash__ is stored on the class.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                elif result == '':
                    # Nothing was added; only observable without a user __hash__.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)
                elif result == 'none':
                    # __hash__ was explicitly set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                else:
                    assert False, f'unknown result {result!r}'

        # Expected outcome for each unsafe_hash/eq/frozen combination: first
        # for a class without its own __hash__, then for one that has it.
        table = [
            (False, False, False, '',     ''),
            (False, False, True,  '',     ''),
            (False, True,  False, 'none', ''),
            (False, True,  True,  'fn',   ''),
            (True,  False, False, 'fn',   'exception'),
            (True,  False, True,  'fn',   'exception'),
            (True,  True,  False, 'fn',   'exception'),
            (True,  True,  True,  'fn',   'exception'),
        ]
        for case, (unsafe_hash, eq, frozen, without_hash, with_own_hash) in enumerate(table, 1):
            check(case, unsafe_hash, eq, frozen, False, without_hash)
            check(case, unsafe_hash, eq, frozen, True, with_own_hash)

            # The same expectations must hold for non-bool truth values.
            loose = (non_bool(unsafe_hash), non_bool(eq), non_bool(frozen))
            check(case, *loose, False, without_hash)
            check(case, *loose, True, with_own_hash)

    def test_eq_only(self):
        # Defining __eq__ makes Python itself set __hash__ to None. That is
        # ordinary class behavior, and dataclass must not interfere with it
        # (bpo-32546).
        def eq_on_i(self, other):
            return self.i == other.i

        @dataclass
        class C:
            i: int
            __eq__ = eq_on_i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # With unsafe_hash=True a field-based __hash__ is still generated.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            __eq__ = eq_on_i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # The class's own __eq__ wins even when eq=True is passed explicitly.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less dataclass hashes like the empty tuple.
        for decorate in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorate
            class C:
                pass
            self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A one-field dataclass hashes like a 1-tuple of that field.
        for decorate in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorate
            class C:
                x: int
            for value in (4, 42):
                self.assertEqual(hash(C(value)), hash((value,)))

    def test_hash_no_args(self):
        # Pin the hashing behavior obtained without passing any hash-related
        # argument. This catches a renamed decorator parameter or a change
        # in the default hashing rules.
        class Base:
            def __hash__(self):
                return 301

        # None for frozen or eq means "leave that argument out entirely".
        for frozen, eq, base, expected in [
                (None,  None,  object, 'unhashable'),
                (None,  None,  Base,   'unhashable'),
                (None,  False, object, 'object'),
                (None,  False, Base,   'base'),
                (None,  True,  object, 'unhashable'),
                (None,  True,  Base,   'unhashable'),
                (False, None,  object, 'unhashable'),
                (False, None,  Base,   'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base,   'base'),
                (False, True,  object, 'unhashable'),
                (False, True,  Base,   'unhashable'),
                (True,  None,  object, 'tuple'),
                (True,  None,  Base,   'tuple'),
                (True,  False, object, 'object'),
                (True,  False, Base,   'base'),
                (True,  True,  object, 'tuple'),
                (True,  True,  Base,   'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                options = {}
                if frozen is not None:
                    options['frozen'] = frozen
                if eq is not None:
                    options['eq'] = eq
                decorate = dataclass(**options) if options else dataclass

                @decorate
                class C(base):
                    i: int

                if expected == 'unhashable':
                    instance = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(instance)
                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)
                elif expected == 'object':
                    # The value object.__hash__ returns is not predictable
                    # here, so check that object's method is the one in use.
                    self.assertIs(C.__hash__, object.__hash__)
                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))
                else:
                    assert False, f'unknown value for expected={expected!r}'
2486
Eric V. Smithea8fc522018-01-27 19:07:40 -05002487
class TestFrozen(unittest.TestCase):
    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # The rejected assignment left the value untouched.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        # Fields from a frozen base and from the frozen subclass are all
        # read-only.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        for name, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, name, value)
        self.assertEqual((obj.i, obj.j), (0, 10))

    # The two tests below run both with and without an intermediate
    # ordinary (non-dataclass) class between base and subclass.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = C
                if intermediate_class:
                    class I(C): pass
                    parent = I

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = C
                if intermediate_class:
                    class I(C): pass
                    parent = I

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from ordinary classes.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                parent = C
                if intermediate_class:
                    class I(C): pass
                    parent = I

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    obj.i = 5

    def test_non_frozen_normal_derived(self):
        # bpo-32953: an ordinary subclass of a frozen dataclass may gain new
        # attributes, but the inherited fields stay read-only.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        obj = S(3)
        self.assertEqual((obj.x, obj.y), (3, 10))
        obj.cached = True

        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, name, 5)
        self.assertEqual((obj.x, obj.y, obj.cached), (3, 10, True))

    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__, so a
        # user-defined one is an error.
        def refuses(dunder):
            return self.assertRaisesRegex(TypeError,
                                          f'Cannot overwrite attribute {dunder}')

        with refuses('__setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with refuses('__delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen=True the user's __setattr__ is left alone.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # A hashable field value makes the instance hashable.
        hash(C(3))

        # An unhashable field value (here a dict) makes hashing fail.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2636
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002637
class TestSlots(unittest.TestCase):
    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # The slot's member descriptor on the class must not be mistaken
        # for a default value (an earlier bug), so x is still required.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        obj = C(10)
        self.assertEqual(obj.x, 10)
        obj.x = 5
        self.assertEqual(obj.x, 5)

        # __slots__ rules out any other attribute.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            obj.y = 5

    def test_derived_added_field(self):
        # bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        obj = Derived(1, 2)
        self.assertEqual((obj.x, obj.y), (1, 2))

        # Derived declares no __slots__ of its own, so new attributes can
        # be added to its instances.
        obj.z = 10
2679
class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # bpo-33141: __set_name__ must be honoured for field defaults.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Plain descriptor behavior: no dataclass code takes part in
        # initializing the descriptor.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # A default with init=False is the case where this really matters:
        # with init=True the descriptor would be overwritten anyway.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487: __set_name__ also applies to objects that are not
        # descriptors (D below defines no __get__).
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # bpo-33175: a __set_name__ stored only on the default instance,
        # not on its type, must not be called.
        class D:
            pass

        default = D()
        default.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=default, init=False)

        self.assertEqual(default.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # bpo-33175: a __set_name__ defined on the default's type is called
        # exactly once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002750
Eric V. Smith7389fd92018-03-19 21:07:51 -04002751
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002752class TestStringAnnotations(unittest.TestCase):
2753 def test_classvar(self):
2754 # Some expressions recognized as ClassVar really aren't. But
2755 # if you're using string annotations, it's not an exact
2756 # science.
2757 # These tests assume that both "import typing" and "from
2758 # typing import *" have been run in this file.
2759 for typestr in ('ClassVar[int]',
2760 'ClassVar [int]'
2761 ' ClassVar [int]',
2762 'ClassVar',
2763 ' ClassVar ',
2764 'typing.ClassVar[int]',
2765 'typing.ClassVar[str]',
2766 ' typing.ClassVar[str]',
2767 'typing .ClassVar[str]',
2768 'typing. ClassVar[str]',
2769 'typing.ClassVar [str]',
2770 'typing.ClassVar [ str]',
2771
2772 # Not syntactically valid, but these will
2773 # be treated as ClassVars.
2774 'typing.ClassVar.[int]',
2775 'typing.ClassVar+',
2776 ):
2777 with self.subTest(typestr=typestr):
2778 @dataclass
2779 class C:
2780 x: typestr
2781
2782 # x is a ClassVar, so C() takes no args.
2783 C()
2784
2785 # And it won't appear in the class's dict because it doesn't
2786 # have a default.
2787 self.assertNotIn('x', C.__dict__)
2788
2789 def test_isnt_classvar(self):
2790 for typestr in ('CV',
2791 't.ClassVar',
2792 't.ClassVar[int]',
2793 'typing..ClassVar[int]',
2794 'Classvar',
2795 'Classvar[int]',
2796 'typing.ClassVarx[int]',
2797 'typong.ClassVar[int]',
2798 'dataclasses.ClassVar[int]',
2799 'typingxClassVar[str]',
2800 ):
2801 with self.subTest(typestr=typestr):
2802 @dataclass
2803 class C:
2804 x: typestr
2805
2806 # x is not a ClassVar, so C() takes one arg.
2807 self.assertEqual(C(10).x, 10)
2808
2809 def test_initvar(self):
2810 # These tests assume that both "import dataclasses" and "from
2811 # dataclasses import *" have been run in this file.
2812 for typestr in ('InitVar[int]',
2813 'InitVar [int]'
2814 ' InitVar [int]',
2815 'InitVar',
2816 ' InitVar ',
2817 'dataclasses.InitVar[int]',
2818 'dataclasses.InitVar[str]',
2819 ' dataclasses.InitVar[str]',
2820 'dataclasses .InitVar[str]',
2821 'dataclasses. InitVar[str]',
2822 'dataclasses.InitVar [str]',
2823 'dataclasses.InitVar [ str]',
2824
2825 # Not syntactically valid, but these will
2826 # be treated as InitVars.
2827 'dataclasses.InitVar.[int]',
2828 'dataclasses.InitVar+',
2829 ):
2830 with self.subTest(typestr=typestr):
2831 @dataclass
2832 class C:
2833 x: typestr
2834
2835 # x is an InitVar, so doesn't create a member.
2836 with self.assertRaisesRegex(AttributeError,
2837 "object has no attribute 'x'"):
2838 C(1).x
2839
2840 def test_isnt_initvar(self):
2841 for typestr in ('IV',
2842 'dc.InitVar',
2843 'xdataclasses.xInitVar',
2844 'typing.xInitVar[int]',
2845 ):
2846 with self.subTest(typestr=typestr):
2847 @dataclass
2848 class C:
2849 x: typestr
2850
2851 # x is not an InitVar, so there will be a member x.
2852 self.assertEqual(C(10).x, 10)
2853
2854 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002855 from test import dataclass_module_1
2856 from test import dataclass_module_1_str
2857 from test import dataclass_module_2
2858 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002859
2860 for m in (dataclass_module_1, dataclass_module_1_str,
2861 dataclass_module_2, dataclass_module_2_str,
2862 ):
2863 with self.subTest(m=m):
2864 # There's a difference in how the ClassVars are
2865 # interpreted when using string annotations or
2866 # not. See the imported modules for details.
2867 if m.USING_STRINGS:
2868 c = m.CV(10)
2869 else:
2870 c = m.CV()
2871 self.assertEqual(c.cv0, 20)
2872
2873
2874 # There's a difference in how the InitVars are
2875 # interpreted when using string annotations or
2876 # not. See the imported modules for details.
2877 c = m.IV(0, 1, 2, 3, 4)
2878
2879 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2880 with self.subTest(field_name=field_name):
2881 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2882 # Since field_name is an InitVar, it's
2883 # not an instance field.
2884 getattr(c, field_name)
2885
2886 if m.USING_STRINGS:
2887 # iv4 is interpreted as a normal field.
2888 self.assertIn('not_iv4', c.__dict__)
2889 self.assertEqual(c.not_iv4, 4)
2890 else:
2891 # iv4 is interpreted as an InitVar, so it
2892 # won't exist on the instance.
2893 self.assertNotIn('not_iv4', c.__dict__)
2894
2895
Eric V. Smith4e812962018-05-16 11:31:29 -04002896class TestMakeDataclass(unittest.TestCase):
2897 def test_simple(self):
2898 C = make_dataclass('C',
2899 [('x', int),
2900 ('y', int, field(default=5))],
2901 namespace={'add_one': lambda self: self.x + 1})
2902 c = C(10)
2903 self.assertEqual((c.x, c.y), (10, 5))
2904 self.assertEqual(c.add_one(), 11)
2905
2906
2907 def test_no_mutate_namespace(self):
2908 # Make sure a provided namespace isn't mutated.
2909 ns = {}
2910 C = make_dataclass('C',
2911 [('x', int),
2912 ('y', int, field(default=5))],
2913 namespace=ns)
2914 self.assertEqual(ns, {})
2915
2916 def test_base(self):
2917 class Base1:
2918 pass
2919 class Base2:
2920 pass
2921 C = make_dataclass('C',
2922 [('x', int)],
2923 bases=(Base1, Base2))
2924 c = C(2)
2925 self.assertIsInstance(c, C)
2926 self.assertIsInstance(c, Base1)
2927 self.assertIsInstance(c, Base2)
2928
2929 def test_base_dataclass(self):
2930 @dataclass
2931 class Base1:
2932 x: int
2933 class Base2:
2934 pass
2935 C = make_dataclass('C',
2936 [('y', int)],
2937 bases=(Base1, Base2))
2938 with self.assertRaisesRegex(TypeError, 'required positional'):
2939 c = C(2)
2940 c = C(1, 2)
2941 self.assertIsInstance(c, C)
2942 self.assertIsInstance(c, Base1)
2943 self.assertIsInstance(c, Base2)
2944
2945 self.assertEqual((c.x, c.y), (1, 2))
2946
2947 def test_init_var(self):
2948 def post_init(self, y):
2949 self.x *= y
2950
2951 C = make_dataclass('C',
2952 [('x', int),
2953 ('y', InitVar[int]),
2954 ],
2955 namespace={'__post_init__': post_init},
2956 )
2957 c = C(2, 3)
2958 self.assertEqual(vars(c), {'x': 6})
2959 self.assertEqual(len(fields(c)), 1)
2960
2961 def test_class_var(self):
2962 C = make_dataclass('C',
2963 [('x', int),
2964 ('y', ClassVar[int], 10),
2965 ('z', ClassVar[int], field(default=20)),
2966 ])
2967 c = C(1)
2968 self.assertEqual(vars(c), {'x': 1})
2969 self.assertEqual(len(fields(c)), 1)
2970 self.assertEqual(C.y, 10)
2971 self.assertEqual(C.z, 20)
2972
2973 def test_other_params(self):
2974 C = make_dataclass('C',
2975 [('x', int),
2976 ('y', ClassVar[int], 10),
2977 ('z', ClassVar[int], field(default=20)),
2978 ],
2979 init=False)
2980 # Make sure we have a repr, but no init.
2981 self.assertNotIn('__init__', vars(C))
2982 self.assertIn('__repr__', vars(C))
2983
2984 # Make sure random other params don't work.
2985 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
2986 C = make_dataclass('C',
2987 [],
2988 xxinit=False)
2989
2990 def test_no_types(self):
2991 C = make_dataclass('Point', ['x', 'y', 'z'])
2992 c = C(1, 2, 3)
2993 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
2994 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
2995 'y': 'typing.Any',
2996 'z': 'typing.Any'})
2997
2998 C = make_dataclass('Point', ['x', ('y', int), 'z'])
2999 c = C(1, 2, 3)
3000 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
3001 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
3002 'y': int,
3003 'z': 'typing.Any'})
3004
3005 def test_invalid_type_specification(self):
3006 for bad_field in [(),
3007 (1, 2, 3, 4),
3008 ]:
3009 with self.subTest(bad_field=bad_field):
3010 with self.assertRaisesRegex(TypeError, r'Invalid field: '):
3011 make_dataclass('C', ['a', bad_field])
3012
3013 # And test for things with no len().
3014 for bad_field in [float,
3015 lambda x:x,
3016 ]:
3017 with self.subTest(bad_field=bad_field):
3018 with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
3019 make_dataclass('C', ['a', bad_field])
3020
3021 def test_duplicate_field_names(self):
3022 for field in ['a', 'ab']:
3023 with self.subTest(field=field):
3024 with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
3025 make_dataclass('C', [field, 'a', field])
3026
3027 def test_keyword_field_names(self):
3028 for field in ['for', 'async', 'await', 'as']:
3029 with self.subTest(field=field):
3030 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
3031 make_dataclass('C', ['a', field])
3032 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
3033 make_dataclass('C', [field])
3034 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
3035 make_dataclass('C', [field, 'a'])
3036
3037 def test_non_identifier_field_names(self):
3038 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
3039 with self.subTest(field=field):
3040 with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
3041 make_dataclass('C', ['a', field])
3042 with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
3043 make_dataclass('C', [field])
3044 with self.assertRaisesRegex(TypeError, 'must be valid identifers'):
3045 make_dataclass('C', [field, 'a'])
3046
3047 def test_underscore_field_names(self):
3048 # Unlike namedtuple, it's okay if dataclass field names have
3049 # an underscore.
3050 make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
3051
3052 def test_funny_class_names_names(self):
3053 # No reason to prevent weird class names, since
3054 # types.new_class allows them.
3055 for classname in ['()', 'x,y', '*', '2@3', '']:
3056 with self.subTest(classname=classname):
3057 C = make_dataclass(classname, ['a', 'b'])
3058 self.assertEqual(C.__name__, classname)
3059
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The recursive-repr tests near the end of this class compare repr()
    output that embeds each local class's __qualname__ (which starts with
    "TestReplace."), so their expected strings depend on where they live.
    """

    def test(self):
        # replace() copies unchanged fields and overrides the named ones.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # Every rejected call gets its own context manager: once a
        # statement inside an assertRaises block raises, any statement
        # after it in that block would never run.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, t=50)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        # Unknown names are forwarded to __init__, which rejects them.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        # Only dataclass *instances* are accepted: not the class itself,
        # and not arbitrary objects.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        # Replacing a regular field works and leaves the ClassVar alone.
        c2 = replace(c, x=5)
        self.assertEqual(c2.x, 5)
        self.assertEqual(c2.y, 1000)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # The parentheses are escaped so the regex matches the literal
        # "replace()" in the message, not just the word "replace".
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    # The recursive-repr tests below do not call replace(); their expected
    # strings embed the "TestReplace." qualname prefix, so moving them to
    # another TestCase requires updating those strings.

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
Eric V. Smith4e812962018-05-16 11:31:29 -04003270
if __name__ == "__main__":
    # Running this file directly discovers and runs its TestCase classes.
    unittest.main()