blob: 53e8443c2adf17f1b461704a5e3a25759cf62e16 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +03009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050014from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040016import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
17import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
18
class CustomError(Exception):
    """Arbitrary exception type that the tests raise and then catch."""
21
22class TestCase(unittest.TestCase):
23 def test_no_fields(self):
24 @dataclass
25 class C:
26 pass
27
28 o = C()
29 self.assertEqual(len(fields(C)), 0)
30
Eric V. Smith56970b82018-03-22 16:28:48 -040031 def test_no_fields_but_member_variable(self):
32 @dataclass
33 class C:
34 i = 0
35
36 o = C()
37 self.assertEqual(len(fields(C)), 0)
38
Eric V. Smithf0db54a2017-12-04 16:58:55 -050039 def test_one_field_no_default(self):
40 @dataclass
41 class C:
42 x: int
43
44 o = C(42)
45 self.assertEqual(o.x, 42)
46
47 def test_named_init_params(self):
48 @dataclass
49 class C:
50 x: int
51
52 o = C(x=32)
53 self.assertEqual(o.x, 32)
54
55 def test_two_fields_one_default(self):
56 @dataclass
57 class C:
58 x: int
59 y: int = 0
60
61 o = C(3)
62 self.assertEqual((o.x, o.y), (3, 0))
63
64 # Non-defaults following defaults.
65 with self.assertRaisesRegex(TypeError,
66 "non-default argument 'y' follows "
67 "default argument"):
68 @dataclass
69 class C:
70 x: int = 0
71 y: int
72
73 # A derived class adds a non-default field after a default one.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int = 0
80
81 @dataclass
82 class C(B):
83 y: int
84
85 # Override a base class field and add a default to
86 # a field which didn't use to have a default.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class B:
92 x: int
93 y: int
94
95 @dataclass
96 class C(B):
97 x: int = 0
98
Eric V. Smithdbf9cff2018-02-25 21:30:17 -050099 def test_overwrite_hash(self):
100 # Test that declaring this class isn't an error. It should
101 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500102 @dataclass(frozen=True)
103 class C:
104 x: int
105 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500106 return 301
107 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500108
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500109 # Test that declaring this class isn't an error. It should
110 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500111 @dataclass(frozen=True)
112 class C:
113 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500114 def __eq__(self, other):
115 return False
116 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500117
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500118 # But this one should generate an exception, because with
119 # unsafe_hash=True, it's an error to have a __hash__ defined.
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __hash__'):
122 @dataclass(unsafe_hash=True)
123 class C:
124 def __hash__(self):
125 pass
126
127 # Creating this class should not generate an exception,
128 # because even though __hash__ exists before @dataclass is
129 # called, (due to __eq__ being defined), since it's None
130 # that's okay.
131 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500132 class C:
133 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500134 def __eq__(self):
135 pass
136 # The generated hash function works as we'd expect.
137 self.assertEqual(hash(C(10)), hash((10,)))
138
139 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400140 # __hash__ exists and is not None, which it would be if it
141 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 x: int
147 def __eq__(self):
148 pass
149 def __hash__(self):
150 pass
151
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500152 def test_overwrite_fields_in_derived_class(self):
153 # Note that x from C1 replaces x in Base, but the order remains
154 # the same as defined in Base.
155 @dataclass
156 class Base:
157 x: Any = 15.0
158 y: int = 0
159
160 @dataclass
161 class C1(Base):
162 z: int = 10
163 x: int = 15
164
165 o = Base()
166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
167
168 o = C1()
169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
170
171 o = C1(x=5)
172 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
173
174 def test_field_named_self(self):
175 @dataclass
176 class C:
177 self: str
178 c=C('foo')
179 self.assertEqual(c.self, 'foo')
180
181 # Make sure the first parameter is not named 'self'.
182 sig = inspect.signature(C.__init__)
183 first = next(iter(sig.parameters))
184 self.assertNotEqual('self', first)
185
186 # But we do use 'self' if no field named self.
187 @dataclass
188 class C:
189 selfx: str
190
191 # Make sure the first parameter is named 'self'.
192 sig = inspect.signature(C.__init__)
193 first = next(iter(sig.parameters))
194 self.assertEqual('self', first)
195
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300196 def test_field_named_object(self):
197 @dataclass
198 class C:
199 object: str
200 c = C('foo')
201 self.assertEqual(c.object, 'foo')
202
203 def test_field_named_object_frozen(self):
204 @dataclass(frozen=True)
205 class C:
206 object: str
207 c = C('foo')
208 self.assertEqual(c.object, 'foo')
209
210 def test_field_named_like_builtin(self):
211 # Attribute names can shadow built-in names
212 # since code generation is used.
213 # Ensure that this is not happening.
214 exclusions = {'None', 'True', 'False'}
215 builtins_names = sorted(
216 b for b in builtins.__dict__.keys()
217 if not b.startswith('__') and b not in exclusions
218 )
219 attributes = [(name, str) for name in builtins_names]
220 C = make_dataclass('C', attributes)
221
222 c = C(*[name for name in builtins_names])
223
224 for name in builtins_names:
225 self.assertEqual(getattr(c, name), name)
226
227 def test_field_named_like_builtin_frozen(self):
228 # Attribute names can shadow built-in names
229 # since code generation is used.
230 # Ensure that this is not happening
231 # for frozen data classes.
232 exclusions = {'None', 'True', 'False'}
233 builtins_names = sorted(
234 b for b in builtins.__dict__.keys()
235 if not b.startswith('__') and b not in exclusions
236 )
237 attributes = [(name, str) for name in builtins_names]
238 C = make_dataclass('C', attributes, frozen=True)
239
240 c = C(*[name for name in builtins_names])
241
242 for name in builtins_names:
243 self.assertEqual(getattr(c, name), name)
244
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500245 def test_0_field_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 pass
250
251 @dataclass(order=False)
252 class C1:
253 pass
254
255 for cls in [C0, C1]:
256 with self.subTest(cls=cls):
257 self.assertEqual(cls(), cls())
258 for idx, fn in enumerate([lambda a, b: a < b,
259 lambda a, b: a <= b,
260 lambda a, b: a > b,
261 lambda a, b: a >= b]):
262 with self.subTest(idx=idx):
263 with self.assertRaisesRegex(TypeError,
264 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
265 fn(cls(), cls())
266
267 @dataclass(order=True)
268 class C:
269 pass
270 self.assertLessEqual(C(), C())
271 self.assertGreaterEqual(C(), C())
272
273 def test_1_field_compare(self):
274 # Ensure that order=False is the default.
275 @dataclass
276 class C0:
277 x: int
278
279 @dataclass(order=False)
280 class C1:
281 x: int
282
283 for cls in [C0, C1]:
284 with self.subTest(cls=cls):
285 self.assertEqual(cls(1), cls(1))
286 self.assertNotEqual(cls(0), cls(1))
287 for idx, fn in enumerate([lambda a, b: a < b,
288 lambda a, b: a <= b,
289 lambda a, b: a > b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 with self.assertRaisesRegex(TypeError,
293 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
294 fn(cls(0), cls(0))
295
296 @dataclass(order=True)
297 class C:
298 x: int
299 self.assertLess(C(0), C(1))
300 self.assertLessEqual(C(0), C(1))
301 self.assertLessEqual(C(1), C(1))
302 self.assertGreater(C(1), C(0))
303 self.assertGreaterEqual(C(1), C(0))
304 self.assertGreaterEqual(C(1), C(1))
305
306 def test_simple_compare(self):
307 # Ensure that order=False is the default.
308 @dataclass
309 class C0:
310 x: int
311 y: int
312
313 @dataclass(order=False)
314 class C1:
315 x: int
316 y: int
317
318 for cls in [C0, C1]:
319 with self.subTest(cls=cls):
320 self.assertEqual(cls(0, 0), cls(0, 0))
321 self.assertEqual(cls(1, 2), cls(1, 2))
322 self.assertNotEqual(cls(1, 0), cls(0, 0))
323 self.assertNotEqual(cls(1, 0), cls(1, 1))
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
331 fn(cls(0, 0), cls(0, 0))
332
333 @dataclass(order=True)
334 class C:
335 x: int
336 y: int
337
338 for idx, fn in enumerate([lambda a, b: a == b,
339 lambda a, b: a <= b,
340 lambda a, b: a >= b]):
341 with self.subTest(idx=idx):
342 self.assertTrue(fn(C(0, 0), C(0, 0)))
343
344 for idx, fn in enumerate([lambda a, b: a < b,
345 lambda a, b: a <= b,
346 lambda a, b: a != b]):
347 with self.subTest(idx=idx):
348 self.assertTrue(fn(C(0, 0), C(0, 1)))
349 self.assertTrue(fn(C(0, 1), C(1, 0)))
350 self.assertTrue(fn(C(1, 0), C(1, 1)))
351
352 for idx, fn in enumerate([lambda a, b: a > b,
353 lambda a, b: a >= b,
354 lambda a, b: a != b]):
355 with self.subTest(idx=idx):
356 self.assertTrue(fn(C(0, 1), C(0, 0)))
357 self.assertTrue(fn(C(1, 0), C(0, 1)))
358 self.assertTrue(fn(C(1, 1), C(1, 0)))
359
360 def test_compare_subclasses(self):
361 # Comparisons fail for subclasses, even if no fields
362 # are added.
363 @dataclass
364 class B:
365 i: int
366
367 @dataclass
368 class C(B):
369 pass
370
371 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
372 (lambda a, b: a != b, True)]):
373 with self.subTest(idx=idx):
374 self.assertEqual(fn(B(0), C(0)), expected)
375
376 for idx, fn in enumerate([lambda a, b: a < b,
377 lambda a, b: a <= b,
378 lambda a, b: a > b,
379 lambda a, b: a >= b]):
380 with self.subTest(idx=idx):
381 with self.assertRaisesRegex(TypeError,
382 "not supported between instances of 'B' and 'C'"):
383 fn(B(0), C(0))
384
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500385 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500386 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500387 for (eq, order, result ) in [
388 (False, False, 'neither'),
389 (False, True, 'exception'),
390 (True, False, 'eq_only'),
391 (True, True, 'both'),
392 ]:
393 with self.subTest(eq=eq, order=order):
394 if result == 'exception':
395 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
396 @dataclass(eq=eq, order=order)
397 class C:
398 pass
399 else:
400 @dataclass(eq=eq, order=order)
401 class C:
402 pass
403
404 if result == 'neither':
405 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 self.assertNotIn('__lt__', C.__dict__)
407 self.assertNotIn('__le__', C.__dict__)
408 self.assertNotIn('__gt__', C.__dict__)
409 self.assertNotIn('__ge__', C.__dict__)
410 elif result == 'both':
411 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500412 self.assertIn('__lt__', C.__dict__)
413 self.assertIn('__le__', C.__dict__)
414 self.assertIn('__gt__', C.__dict__)
415 self.assertIn('__ge__', C.__dict__)
416 elif result == 'eq_only':
417 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500418 self.assertNotIn('__lt__', C.__dict__)
419 self.assertNotIn('__le__', C.__dict__)
420 self.assertNotIn('__gt__', C.__dict__)
421 self.assertNotIn('__ge__', C.__dict__)
422 else:
423 assert False, f'unknown result {result!r}'
424
425 def test_field_no_default(self):
426 @dataclass
427 class C:
428 x: int = field()
429
430 self.assertEqual(C(5).x, 5)
431
432 with self.assertRaisesRegex(TypeError,
433 r"__init__\(\) missing 1 required "
434 "positional argument: 'x'"):
435 C()
436
437 def test_field_default(self):
438 default = object()
439 @dataclass
440 class C:
441 x: object = field(default=default)
442
443 self.assertIs(C.x, default)
444 c = C(10)
445 self.assertEqual(c.x, 10)
446
447 # If we delete the instance attribute, we should then see the
448 # class attribute.
449 del c.x
450 self.assertIs(c.x, default)
451
452 self.assertIs(C().x, default)
453
454 def test_not_in_repr(self):
455 @dataclass
456 class C:
457 x: int = field(repr=False)
458 with self.assertRaises(TypeError):
459 C()
460 c = C(10)
461 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
462
463 @dataclass
464 class C:
465 x: int = field(repr=False)
466 y: int
467 c = C(10, 20)
468 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
469
470 def test_not_in_compare(self):
471 @dataclass
472 class C:
473 x: int = 0
474 y: int = field(compare=False, default=4)
475
476 self.assertEqual(C(), C(0, 20))
477 self.assertEqual(C(1, 10), C(1, 20))
478 self.assertNotEqual(C(3), C(4, 10))
479 self.assertNotEqual(C(3, 10), C(4, 10))
480
481 def test_hash_field_rules(self):
482 # Test all 6 cases of:
483 # hash=True/False/None
484 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500485 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500486 (True, False, 'field' ),
487 (True, True, 'field' ),
488 (False, False, 'absent'),
489 (False, True, 'absent'),
490 (None, False, 'absent'),
491 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500492 ]:
493 with self.subTest(hash=hash_, compare=compare):
494 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500495 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500496 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500497
498 if result == 'field':
499 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500500 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500501 elif result == 'absent':
502 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500503 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500504 else:
505 assert False, f'unknown result {result!r}'
506
507 def test_init_false_no_default(self):
508 # If init=False and no default value, then the field won't be
509 # present in the instance.
510 @dataclass
511 class C:
512 x: int = field(init=False)
513
514 self.assertNotIn('x', C().__dict__)
515
516 @dataclass
517 class C:
518 x: int
519 y: int = 0
520 z: int = field(init=False)
521 t: int = 10
522
523 self.assertNotIn('z', C(0).__dict__)
524 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
525
526 def test_class_marker(self):
527 @dataclass
528 class C:
529 x: int
530 y: str = field(init=False, default=None)
531 z: str = field(repr=False)
532
533 the_fields = fields(C)
534 # the_fields is a tuple of 3 items, each value
535 # is in __annotations__.
536 self.assertIsInstance(the_fields, tuple)
537 for f in the_fields:
538 self.assertIs(type(f), Field)
539 self.assertIn(f.name, C.__annotations__)
540
541 self.assertEqual(len(the_fields), 3)
542
543 self.assertEqual(the_fields[0].name, 'x')
544 self.assertEqual(the_fields[0].type, int)
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertTrue (the_fields[0].init)
547 self.assertTrue (the_fields[0].repr)
548 self.assertEqual(the_fields[1].name, 'y')
549 self.assertEqual(the_fields[1].type, str)
550 self.assertIsNone(getattr(C, 'y'))
551 self.assertFalse(the_fields[1].init)
552 self.assertTrue (the_fields[1].repr)
553 self.assertEqual(the_fields[2].name, 'z')
554 self.assertEqual(the_fields[2].type, str)
555 self.assertFalse(hasattr(C, 'z'))
556 self.assertTrue (the_fields[2].init)
557 self.assertFalse(the_fields[2].repr)
558
559 def test_field_order(self):
560 @dataclass
561 class B:
562 a: str = 'B:a'
563 b: str = 'B:b'
564 c: str = 'B:c'
565
566 @dataclass
567 class C(B):
568 b: str = 'C:b'
569
570 self.assertEqual([(f.name, f.default) for f in fields(C)],
571 [('a', 'B:a'),
572 ('b', 'C:b'),
573 ('c', 'B:c')])
574
575 @dataclass
576 class D(B):
577 c: str = 'D:c'
578
579 self.assertEqual([(f.name, f.default) for f in fields(D)],
580 [('a', 'B:a'),
581 ('b', 'B:b'),
582 ('c', 'D:c')])
583
584 @dataclass
585 class E(D):
586 a: str = 'E:a'
587 d: str = 'E:d'
588
589 self.assertEqual([(f.name, f.default) for f in fields(E)],
590 [('a', 'E:a'),
591 ('b', 'B:b'),
592 ('c', 'D:c'),
593 ('d', 'E:d')])
594
595 def test_class_attrs(self):
596 # We only have a class attribute if a default value is
597 # specified, either directly or via a field with a default.
598 default = object()
599 @dataclass
600 class C:
601 x: int
602 y: int = field(repr=False)
603 z: object = default
604 t: int = field(default=100)
605
606 self.assertFalse(hasattr(C, 'x'))
607 self.assertFalse(hasattr(C, 'y'))
608 self.assertIs (C.z, default)
609 self.assertEqual(C.t, 100)
610
611 def test_disallowed_mutable_defaults(self):
612 # For the known types, don't allow mutable default values.
613 for typ, empty, non_empty in [(list, [], [1]),
614 (dict, {}, {0:1}),
615 (set, set(), set([1])),
616 ]:
617 with self.subTest(typ=typ):
618 # Can't use a zero-length value.
619 with self.assertRaisesRegex(ValueError,
620 f'mutable default {typ} for field '
621 'x is not allowed'):
622 @dataclass
623 class Point:
624 x: typ = empty
625
626
627 # Nor a non-zero-length value
628 with self.assertRaisesRegex(ValueError,
629 f'mutable default {typ} for field '
630 'y is not allowed'):
631 @dataclass
632 class Point:
633 y: typ = non_empty
634
635 # Check subtypes also fail.
636 class Subclass(typ): pass
637
638 with self.assertRaisesRegex(ValueError,
639 f"mutable default .*Subclass'>"
640 ' for field z is not allowed'
641 ):
642 @dataclass
643 class Point:
644 z: typ = Subclass()
645
646 # Because this is a ClassVar, it can be mutable.
647 @dataclass
648 class C:
649 z: ClassVar[typ] = typ()
650
651 # Because this is a ClassVar, it can be mutable.
652 @dataclass
653 class C:
654 x: ClassVar[typ] = Subclass()
655
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500656 def test_deliberately_mutable_defaults(self):
657 # If a mutable default isn't in the known list of
658 # (list, dict, set), then it's okay.
659 class Mutable:
660 def __init__(self):
661 self.l = []
662
663 @dataclass
664 class C:
665 x: Mutable
666
667 # These 2 instances will share this value of x.
668 lst = Mutable()
669 o1 = C(lst)
670 o2 = C(lst)
671 self.assertEqual(o1, o2)
672 o1.x.l.extend([1, 2])
673 self.assertEqual(o1, o2)
674 self.assertEqual(o1.x.l, [1, 2])
675 self.assertIs(o1.x, o2.x)
676
677 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400678 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500679 @dataclass()
680 class C:
681 x: int
682
683 self.assertEqual(C(42).x, 42)
684
685 def test_not_tuple(self):
686 # Make sure we can't be compared to a tuple.
687 @dataclass
688 class Point:
689 x: int
690 y: int
691 self.assertNotEqual(Point(1, 2), (1, 2))
692
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400693 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500694 @dataclass
695 class C:
696 x: int
697 y: int
698 self.assertNotEqual(Point(1, 3), C(1, 3))
699
Windson yangbe372d72019-04-23 02:45:34 +0800700 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500701 # Test that some of the problems with namedtuple don't happen
702 # here.
703 @dataclass
704 class Point3D:
705 x: int
706 y: int
707 z: int
708
709 @dataclass
710 class Date:
711 year: int
712 month: int
713 day: int
714
715 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
716 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
717
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400718 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200719 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 x, y, z = Point3D(4, 5, 6)
721
Eric V. Smith7c99e932018-01-28 19:18:55 -0500722 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # equal.
724 @dataclass
725 class Point3Dv1:
726 x: int = 0
727 y: int = 0
728 z: int = 0
729 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
730
731 def test_function_annotations(self):
732 # Some dummy class and instance to use as a default.
733 class F:
734 pass
735 f = F()
736
737 def validate_class(cls):
738 # First, check __annotations__, even though they're not
739 # function annotations.
740 self.assertEqual(cls.__annotations__['i'], int)
741 self.assertEqual(cls.__annotations__['j'], str)
742 self.assertEqual(cls.__annotations__['k'], F)
743 self.assertEqual(cls.__annotations__['l'], float)
744 self.assertEqual(cls.__annotations__['z'], complex)
745
746 # Verify __init__.
747
748 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400749 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500750 self.assertIs(signature.return_annotation, None)
751
752 # Check each parameter.
753 params = iter(signature.parameters.values())
754 param = next(params)
755 # This is testing an internal name, and probably shouldn't be tested.
756 self.assertEqual(param.name, 'self')
757 param = next(params)
758 self.assertEqual(param.name, 'i')
759 self.assertIs (param.annotation, int)
760 self.assertEqual(param.default, inspect.Parameter.empty)
761 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
762 param = next(params)
763 self.assertEqual(param.name, 'j')
764 self.assertIs (param.annotation, str)
765 self.assertEqual(param.default, inspect.Parameter.empty)
766 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
767 param = next(params)
768 self.assertEqual(param.name, 'k')
769 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400770 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
772 param = next(params)
773 self.assertEqual(param.name, 'l')
774 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400775 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500776 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
777 self.assertRaises(StopIteration, next, params)
778
779
780 @dataclass
781 class C:
782 i: int
783 j: str
784 k: F = f
785 l: float=field(default=None)
786 z: complex=field(default=3+4j, init=False)
787
788 validate_class(C)
789
790 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500791 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 class C:
793 i: int
794 j: str
795 k: F = f
796 l: float=field(default=None)
797 z: complex=field(default=3+4j, init=False)
798
799 validate_class(C)
800
Eric V. Smith03220fd2017-12-29 13:59:58 -0500801 def test_missing_default(self):
802 # Test that MISSING works the same as a default not being
803 # specified.
804 @dataclass
805 class C:
806 x: int=field(default=MISSING)
807 with self.assertRaisesRegex(TypeError,
808 r'__init__\(\) missing 1 required '
809 'positional argument'):
810 C()
811 self.assertNotIn('x', C.__dict__)
812
813 @dataclass
814 class D:
815 x: int
816 with self.assertRaisesRegex(TypeError,
817 r'__init__\(\) missing 1 required '
818 'positional argument'):
819 D()
820 self.assertNotIn('x', D.__dict__)
821
822 def test_missing_default_factory(self):
823 # Test that MISSING works the same as a default factory not
824 # being specified (which is really the same as a default not
825 # being specified, too).
826 @dataclass
827 class C:
828 x: int=field(default_factory=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int=field(default=MISSING, default_factory=MISSING)
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_repr(self):
845 self.assertIn('MISSING_TYPE object', repr(MISSING))
846
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 def test_dont_include_other_annotations(self):
848 @dataclass
849 class C:
850 i: int
851 def foo(self) -> int:
852 return 4
853 @property
854 def bar(self) -> int:
855 return 5
856 self.assertEqual(list(C.__annotations__), ['i'])
857 self.assertEqual(C(10).foo(), 4)
858 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400859 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500860
def test_post_init(self):
    """__post_init__ is called by the generated __init__, and only by it."""
    # Just make sure it gets called.
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # post-init gets called, but doesn't raise. This is just
    # checking that self is used correctly.
    C(5)

    # With init=False there is no generated __init__, so post-init won't
    # get called.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    # Creating an instance won't raise: the inherited object.__init__ runs
    # instead of a generated one.
    C()

    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # Make sure that if we're frozen, post-init can't set
    # attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
907
def test_post_init_super(self):
    # A plain (non-dataclass) base whose __post_init__ always raises.
    class B:
        def __post_init__(self):
            raise CustomError()

    # An overriding __post_init__ does not implicitly chain to the base.
    @dataclass
    class Overriding(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(Overriding().x, 5)

    # Chaining explicitly through super() reaches the base and raises.
    @dataclass
    class Chaining(B):
        def __post_init__(self):
            super().__post_init__()

    # An inherited __post_init__ is still invoked by the generated __init__.
    @dataclass
    class Inheriting(B):
        pass

    for cls in (Chaining, Inheriting):
        with self.assertRaises(CustomError):
            cls()
938
939 def test_post_init_staticmethod(self):
940 flag = False
941 @dataclass
942 class C:
943 x: int
944 y: int
945 @staticmethod
946 def __post_init__():
947 nonlocal flag
948 flag = True
949
950 self.assertFalse(flag)
951 c = C(3, 4)
952 self.assertEqual((c.x, c.y), (3, 4))
953 self.assertTrue(flag)
954
955 def test_post_init_classmethod(self):
956 @dataclass
957 class C:
958 flag = False
959 x: int
960 y: int
961 @classmethod
962 def __post_init__(cls):
963 cls.flag = True
964
965 self.assertFalse(C.flag)
966 c = C(3, 4)
967 self.assertEqual((c.x, c.y), (3, 4))
968 self.assertTrue(C.flag)
969
970 def test_class_var(self):
971 # Make sure ClassVars are ignored in __init__, __repr__, etc.
972 @dataclass
973 class C:
974 x: int
975 y: int = 10
976 z: ClassVar[int] = 1000
977 w: ClassVar[int] = 2000
978 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400979 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500980
981 c = C(5)
982 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400983 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400984 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500985 self.assertEqual(c.z, 1000)
986 self.assertEqual(c.w, 2000)
987 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400988 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500989 C.z += 1
990 self.assertEqual(c.z, 1001)
991 c = C(20)
992 self.assertEqual((c.x, c.y), (20, 10))
993 self.assertEqual(c.z, 1001)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400996 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500997
998 def test_class_var_no_default(self):
999 # If a ClassVar has no default value, it should not be set on the class.
1000 @dataclass
1001 class C:
1002 x: ClassVar[int]
1003
1004 self.assertNotIn('x', C.__dict__)
1005
1006 def test_class_var_default_factory(self):
1007 # It makes no sense for a ClassVar to have a default factory. When
1008 # would it be called? Call it yourself, since it's class-wide.
1009 with self.assertRaisesRegex(TypeError,
1010 'cannot have a default factory'):
1011 @dataclass
1012 class C:
1013 x: ClassVar[int] = field(default_factory=int)
1014
1015 self.assertNotIn('x', C.__dict__)
1016
1017 def test_class_var_with_default(self):
1018 # If a ClassVar has a default value, it should be set on the class.
1019 @dataclass
1020 class C:
1021 x: ClassVar[int] = 10
1022 self.assertEqual(C.x, 10)
1023
1024 @dataclass
1025 class C:
1026 x: ClassVar[int] = field(default=10)
1027 self.assertEqual(C.x, 10)
1028
1029 def test_class_var_frozen(self):
1030 # Make sure ClassVars work even if we're frozen.
1031 @dataclass(frozen=True)
1032 class C:
1033 x: int
1034 y: int = 10
1035 z: ClassVar[int] = 1000
1036 w: ClassVar[int] = 2000
1037 t: ClassVar[int] = 3000
1038
1039 c = C(5)
1040 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1041 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1042 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1043 self.assertEqual(c.z, 1000)
1044 self.assertEqual(c.w, 2000)
1045 self.assertEqual(c.t, 3000)
1046 # We can still modify the ClassVar, it's only instances that are
1047 # frozen.
1048 C.z += 1
1049 self.assertEqual(c.z, 1001)
1050 c = C(20)
1051 self.assertEqual((c.x, c.y), (20, 10))
1052 self.assertEqual(c.z, 1001)
1053 self.assertEqual(c.w, 2000)
1054 self.assertEqual(c.t, 3000)
1055
1056 def test_init_var_no_default(self):
1057 # If an InitVar has no default value, it should not be set on the class.
1058 @dataclass
1059 class C:
1060 x: InitVar[int]
1061
1062 self.assertNotIn('x', C.__dict__)
1063
1064 def test_init_var_default_factory(self):
1065 # It makes no sense for an InitVar to have a default factory. When
1066 # would it be called? Call it yourself, since it's class-wide.
1067 with self.assertRaisesRegex(TypeError,
1068 'cannot have a default factory'):
1069 @dataclass
1070 class C:
1071 x: InitVar[int] = field(default_factory=int)
1072
1073 self.assertNotIn('x', C.__dict__)
1074
1075 def test_init_var_with_default(self):
1076 # If an InitVar has a default value, it should be set on the class.
1077 @dataclass
1078 class C:
1079 x: InitVar[int] = 10
1080 self.assertEqual(C.x, 10)
1081
1082 @dataclass
1083 class C:
1084 x: InitVar[int] = field(default=10)
1085 self.assertEqual(C.x, 10)
1086
1087 def test_init_var(self):
1088 @dataclass
1089 class C:
1090 x: int = None
1091 init_param: InitVar[int] = None
1092
1093 def __post_init__(self, init_param):
1094 if self.x is None:
1095 self.x = init_param*2
1096
1097 c = C(init_param=10)
1098 self.assertEqual(c.x, 20)
1099
Augusto Hack01ee12b2019-06-02 23:14:48 -03001100 def test_init_var_preserve_type(self):
1101 self.assertEqual(InitVar[int].type, int)
1102
1103 # Make sure the repr is correct.
1104 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
1105
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001106 def test_init_var_inheritance(self):
1107 # Note that this deliberately tests that a dataclass need not
1108 # have a __post_init__ function if it has an InitVar field.
1109 # It could just be used in a derived class, as shown here.
1110 @dataclass
1111 class Base:
1112 x: int
1113 init_base: InitVar[int]
1114
1115 # We can instantiate by passing the InitVar, even though
1116 # it's not used.
1117 b = Base(0, 10)
1118 self.assertEqual(vars(b), {'x': 0})
1119
1120 @dataclass
1121 class C(Base):
1122 y: int
1123 init_derived: InitVar[int]
1124
1125 def __post_init__(self, init_base, init_derived):
1126 self.x = self.x + init_base
1127 self.y = self.y + init_derived
1128
1129 c = C(10, 11, 50, 51)
1130 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1131
def test_default_factory(self):
    """default_factory builds defaults and cooperates with the field flags."""
    def check_pair(cls, shares_list):
        # Two default-constructed instances compare equal; whether their
        # lists are the same object depends on the factory.
        first, second = cls(3), cls(3)
        self.assertEqual(first.x, 3)
        self.assertEqual(first.y, [])
        self.assertEqual(first, second)
        if shares_list:
            self.assertIs(first.y, second.y)
        else:
            self.assertIsNot(first.y, second.y)
        self.assertEqual(astuple(cls(5, [1])), (5, [1]))

    # A factory returning a new list gives each instance its own list.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)
    check_pair(C, shares_list=False)

    # A factory returning one shared list gives every instance that object.
    shared = []
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: shared)
    check_pair(C, shares_list=True)

    # repr=False: the field is left out of the generated repr.
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash=False: the field is excluded, so the hash equals hash(()).
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init=False: the factory still supplies the value
    # (see also test_default_factory_with_no_init).
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare=False: the field is ignored by equality.
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1188
1189 def test_default_factory_with_no_init(self):
1190 # We need a factory with a side effect.
1191 factory = Mock()
1192
1193 @dataclass
1194 class C:
1195 x: list = field(default_factory=factory, init=False)
1196
1197 # Make sure the default factory is called for each new instance.
1198 C().x
1199 self.assertEqual(factory.call_count, 1)
1200 C().x
1201 self.assertEqual(factory.call_count, 2)
1202
1203 def test_default_factory_not_called_if_value_given(self):
1204 # We need a factory that we can test if it's been called.
1205 factory = Mock()
1206
1207 @dataclass
1208 class C:
1209 x: int = field(default_factory=factory)
1210
1211 # Make sure that if a field has a default factory function,
1212 # it's not called if a value is specified.
1213 C().x
1214 self.assertEqual(factory.call_count, 1)
1215 self.assertEqual(C(10).x, 10)
1216 self.assertEqual(factory.call_count, 1)
1217 C().x
1218 self.assertEqual(factory.call_count, 2)
1219
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001220 def test_default_factory_derived(self):
1221 # See bpo-32896.
1222 @dataclass
1223 class Foo:
1224 x: dict = field(default_factory=dict)
1225
1226 @dataclass
1227 class Bar(Foo):
1228 y: int = 1
1229
1230 self.assertEqual(Foo().x, {})
1231 self.assertEqual(Bar().x, {})
1232 self.assertEqual(Bar().y, 1)
1233
1234 @dataclass
1235 class Baz(Foo):
1236 pass
1237 self.assertEqual(Baz().x, {})
1238
1239 def test_intermediate_non_dataclass(self):
1240 # Test that an intermediate class that defines
1241 # annotations does not define fields.
1242
1243 @dataclass
1244 class A:
1245 x: int
1246
1247 class B(A):
1248 y: int
1249
1250 @dataclass
1251 class C(B):
1252 z: int
1253
1254 c = C(1, 3)
1255 self.assertEqual((c.x, c.z), (1, 3))
1256
1257 # .y was not initialized.
1258 with self.assertRaisesRegex(AttributeError,
1259 'object has no attribute'):
1260 c.y
1261
1262 # And if we again derive a non-dataclass, no fields are added.
1263 class D(C):
1264 t: int
1265 d = D(4, 5)
1266 self.assertEqual((d.x, d.z), (4, 5))
1267
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001268 def test_classvar_default_factory(self):
1269 # It's an error for a ClassVar to have a factory function.
1270 with self.assertRaisesRegex(TypeError,
1271 'cannot have a default factory'):
1272 @dataclass
1273 class C:
1274 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001275
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001276 def test_is_dataclass(self):
1277 class NotDataClass:
1278 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001279
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001280 self.assertFalse(is_dataclass(0))
1281 self.assertFalse(is_dataclass(int))
1282 self.assertFalse(is_dataclass(NotDataClass))
1283 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001284
1285 @dataclass
1286 class C:
1287 x: int
1288
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001289 @dataclass
1290 class D:
1291 d: C
1292 e: int
1293
1294 c = C(10)
1295 d = D(c, 4)
1296
1297 self.assertTrue(is_dataclass(C))
1298 self.assertTrue(is_dataclass(c))
1299 self.assertFalse(is_dataclass(c.x))
1300 self.assertTrue(is_dataclass(d.d))
1301 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001302
1303 def test_helper_fields_with_class_instance(self):
1304 # Check that we can call fields() on either a class or instance,
1305 # and get back the same thing.
1306 @dataclass
1307 class C:
1308 x: int
1309 y: float
1310
1311 self.assertEqual(fields(C), fields(C(0, 0.0)))
1312
1313 def test_helper_fields_exception(self):
1314 # Check that TypeError is raised if not passed a dataclass or
1315 # instance.
1316 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1317 fields(0)
1318
1319 class C: pass
1320 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1321 fields(C)
1322 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1323 fields(C())
1324
1325 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001326 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001327 @dataclass
1328 class C:
1329 x: int
1330 y: int
1331 c = C(1, 2)
1332
1333 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1334 self.assertEqual(asdict(c), asdict(c))
1335 self.assertIsNot(asdict(c), asdict(c))
1336 c.x = 42
1337 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1338 self.assertIs(type(asdict(c)), dict)
1339
1340 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001341 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001342 @dataclass
1343 class C:
1344 x: int
1345 y: int
1346 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1347 asdict(C)
1348 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1349 asdict(int)
1350
1351 def test_helper_asdict_copy_values(self):
1352 @dataclass
1353 class C:
1354 x: int
1355 y: List[int] = field(default_factory=list)
1356 initial = []
1357 c = C(1, initial)
1358 d = asdict(c)
1359 self.assertEqual(d['y'], initial)
1360 self.assertIsNot(d['y'], initial)
1361 c = C(1)
1362 d = asdict(c)
1363 d['y'].append(1)
1364 self.assertEqual(c.y, [])
1365
1366 def test_helper_asdict_nested(self):
1367 @dataclass
1368 class UserId:
1369 token: int
1370 group: int
1371 @dataclass
1372 class User:
1373 name: str
1374 id: UserId
1375 u = User('Joe', UserId(123, 1))
1376 d = asdict(u)
1377 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1378 self.assertIsNot(asdict(u), asdict(u))
1379 u.id.group = 2
1380 self.assertEqual(asdict(u), {'name': 'Joe',
1381 'id': {'token': 123, 'group': 2}})
1382
1383 def test_helper_asdict_builtin_containers(self):
1384 @dataclass
1385 class User:
1386 name: str
1387 id: int
1388 @dataclass
1389 class GroupList:
1390 id: int
1391 users: List[User]
1392 @dataclass
1393 class GroupTuple:
1394 id: int
1395 users: Tuple[User, ...]
1396 @dataclass
1397 class GroupDict:
1398 id: int
1399 users: Dict[str, User]
1400 a = User('Alice', 1)
1401 b = User('Bob', 2)
1402 gl = GroupList(0, [a, b])
1403 gt = GroupTuple(0, (a, b))
1404 gd = GroupDict(0, {'first': a, 'second': b})
1405 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1406 {'name': 'Bob', 'id': 2}]})
1407 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1408 {'name': 'Bob', 'id': 2})})
1409 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1410 'second': {'name': 'Bob', 'id': 2}}})
1411
Windson yangbe372d72019-04-23 02:45:34 +08001412 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001413 @dataclass
1414 class Child:
1415 d: object
1416
1417 @dataclass
1418 class Parent:
1419 child: Child
1420
1421 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1422 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1423
1424 def test_helper_asdict_factory(self):
1425 @dataclass
1426 class C:
1427 x: int
1428 y: int
1429 c = C(1, 2)
1430 d = asdict(c, dict_factory=OrderedDict)
1431 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1432 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1433 c.x = 42
1434 d = asdict(c, dict_factory=OrderedDict)
1435 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1436 self.assertIs(type(d), OrderedDict)
1437
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001438 def test_helper_asdict_namedtuple(self):
1439 T = namedtuple('T', 'a b c')
1440 @dataclass
1441 class C:
1442 x: str
1443 y: T
1444 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1445
1446 d = asdict(c)
1447 self.assertEqual(d, {'x': 'outer',
1448 'y': T(1,
1449 {'x': 'inner',
1450 'y': T(11, 12, 13)},
1451 2),
1452 }
1453 )
1454
1455 # Now with a dict_factory. OrderedDict is convenient, but
1456 # since it compares to dicts, we also need to have separate
1457 # assertIs tests.
1458 d = asdict(c, dict_factory=OrderedDict)
1459 self.assertEqual(d, {'x': 'outer',
1460 'y': T(1,
1461 {'x': 'inner',
1462 'y': T(11, 12, 13)},
1463 2),
1464 }
1465 )
1466
penguindustin96466302019-05-06 14:57:17 -04001467 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001468 self.assertIs(type(d), OrderedDict)
1469 self.assertIs(type(d['y'][1]), OrderedDict)
1470
1471 def test_helper_asdict_namedtuple_key(self):
1472 # Ensure that a field that contains a dict which has a
1473 # namedtuple as a key works with asdict().
1474
1475 @dataclass
1476 class C:
1477 f: dict
1478 T = namedtuple('T', 'a')
1479
1480 c = C({T('an a'): 0})
1481
1482 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1483
1484 def test_helper_asdict_namedtuple_derived(self):
1485 class T(namedtuple('Tbase', 'a')):
1486 def my_a(self):
1487 return self.a
1488
1489 @dataclass
1490 class C:
1491 f: T
1492
1493 t = T(6)
1494 c = C(t)
1495
1496 d = asdict(c)
1497 self.assertEqual(d, {'f': T(a=6)})
1498 # Make sure that t has been copied, not used directly.
1499 self.assertIsNot(d['f'], t)
1500 self.assertEqual(d['f'].my_a(), 6)
1501
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001502 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001503 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001504 @dataclass
1505 class C:
1506 x: int
1507 y: int = 0
1508 c = C(1)
1509
1510 self.assertEqual(astuple(c), (1, 0))
1511 self.assertEqual(astuple(c), astuple(c))
1512 self.assertIsNot(astuple(c), astuple(c))
1513 c.y = 42
1514 self.assertEqual(astuple(c), (1, 42))
1515 self.assertIs(type(astuple(c)), tuple)
1516
1517 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001518 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001519 @dataclass
1520 class C:
1521 x: int
1522 y: int
1523 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1524 astuple(C)
1525 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1526 astuple(int)
1527
1528 def test_helper_astuple_copy_values(self):
1529 @dataclass
1530 class C:
1531 x: int
1532 y: List[int] = field(default_factory=list)
1533 initial = []
1534 c = C(1, initial)
1535 t = astuple(c)
1536 self.assertEqual(t[1], initial)
1537 self.assertIsNot(t[1], initial)
1538 c = C(1)
1539 t = astuple(c)
1540 t[1].append(1)
1541 self.assertEqual(c.y, [])
1542
1543 def test_helper_astuple_nested(self):
1544 @dataclass
1545 class UserId:
1546 token: int
1547 group: int
1548 @dataclass
1549 class User:
1550 name: str
1551 id: UserId
1552 u = User('Joe', UserId(123, 1))
1553 t = astuple(u)
1554 self.assertEqual(t, ('Joe', (123, 1)))
1555 self.assertIsNot(astuple(u), astuple(u))
1556 u.id.group = 2
1557 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1558
1559 def test_helper_astuple_builtin_containers(self):
1560 @dataclass
1561 class User:
1562 name: str
1563 id: int
1564 @dataclass
1565 class GroupList:
1566 id: int
1567 users: List[User]
1568 @dataclass
1569 class GroupTuple:
1570 id: int
1571 users: Tuple[User, ...]
1572 @dataclass
1573 class GroupDict:
1574 id: int
1575 users: Dict[str, User]
1576 a = User('Alice', 1)
1577 b = User('Bob', 2)
1578 gl = GroupList(0, [a, b])
1579 gt = GroupTuple(0, (a, b))
1580 gd = GroupDict(0, {'first': a, 'second': b})
1581 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1582 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1583 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1584
Windson yangbe372d72019-04-23 02:45:34 +08001585 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001586 @dataclass
1587 class Child:
1588 d: object
1589
1590 @dataclass
1591 class Parent:
1592 child: Child
1593
1594 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1595 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1596
1597 def test_helper_astuple_factory(self):
1598 @dataclass
1599 class C:
1600 x: int
1601 y: int
1602 NT = namedtuple('NT', 'x y')
1603 def nt(lst):
1604 return NT(*lst)
1605 c = C(1, 2)
1606 t = astuple(c, tuple_factory=nt)
1607 self.assertEqual(t, NT(1, 2))
1608 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1609 c.x = 42
1610 t = astuple(c, tuple_factory=nt)
1611 self.assertEqual(t, NT(42, 2))
1612 self.assertIs(type(t), NT)
1613
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001614 def test_helper_astuple_namedtuple(self):
1615 T = namedtuple('T', 'a b c')
1616 @dataclass
1617 class C:
1618 x: str
1619 y: T
1620 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1621
1622 t = astuple(c)
1623 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1624
1625 # Now, using a tuple_factory. list is convenient here.
1626 t = astuple(c, tuple_factory=list)
1627 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1628
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001629 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001630 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001631 }
1632
1633 # Create the class.
1634 cls = type('C', (), cls_dict)
1635
1636 # Make it a dataclass.
1637 cls1 = dataclass(cls)
1638
1639 self.assertEqual(cls1, cls)
1640 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1641
1642 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001643 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001644 'y': field(default=5),
1645 }
1646
1647 # Create the class.
1648 cls = type('C', (), cls_dict)
1649
1650 # Make it a dataclass.
1651 cls1 = dataclass(cls)
1652
1653 self.assertEqual(cls1, cls)
1654 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1655
1656 def test_init_in_order(self):
1657 @dataclass
1658 class C:
1659 a: int
1660 b: int = field()
1661 c: list = field(default_factory=list, init=False)
1662 d: list = field(default_factory=list)
1663 e: int = field(default=4, init=False)
1664 f: int = 4
1665
1666 calls = []
1667 def setattr(self, name, value):
1668 calls.append((name, value))
1669
1670 C.__setattr__ = setattr
1671 c = C(0, 1)
1672 self.assertEqual(('a', 0), calls[0])
1673 self.assertEqual(('b', 1), calls[1])
1674 self.assertEqual(('c', []), calls[2])
1675 self.assertEqual(('d', []), calls[3])
1676 self.assertNotIn(('e', 4), calls)
1677 self.assertEqual(('f', 4), calls[4])
1678
1679 def test_items_in_dicts(self):
1680 @dataclass
1681 class C:
1682 a: int
1683 b: list = field(default_factory=list, init=False)
1684 c: list = field(default_factory=list)
1685 d: int = field(default=4, init=False)
1686 e: int = 0
1687
1688 c = C(0)
1689 # Class dict
1690 self.assertNotIn('a', C.__dict__)
1691 self.assertNotIn('b', C.__dict__)
1692 self.assertNotIn('c', C.__dict__)
1693 self.assertIn('d', C.__dict__)
1694 self.assertEqual(C.d, 4)
1695 self.assertIn('e', C.__dict__)
1696 self.assertEqual(C.e, 0)
1697 # Instance dict
1698 self.assertIn('a', c.__dict__)
1699 self.assertEqual(c.a, 0)
1700 self.assertIn('b', c.__dict__)
1701 self.assertEqual(c.b, [])
1702 self.assertIn('c', c.__dict__)
1703 self.assertEqual(c.c, [])
1704 self.assertNotIn('d', c.__dict__)
1705 self.assertIn('e', c.__dict__)
1706 self.assertEqual(c.e, 0)
1707
1708 def test_alternate_classmethod_constructor(self):
1709 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001710 # alternate constructor. This is mostly an example to show
1711 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001712 @dataclass
1713 class C:
1714 x: int
1715 @classmethod
1716 def from_file(cls, filename):
1717 # In a real example, create a new instance
1718 # and populate 'x' from contents of a file.
1719 value_in_file = 20
1720 return cls(value_in_file)
1721
1722 self.assertEqual(C.from_file('filename').x, 20)
1723
1724 def test_field_metadata_default(self):
1725 # Make sure the default metadata is read-only and of
1726 # zero length.
1727 @dataclass
1728 class C:
1729 i: int
1730
1731 self.assertFalse(fields(C)[0].metadata)
1732 self.assertEqual(len(fields(C)[0].metadata), 0)
1733 with self.assertRaisesRegex(TypeError,
1734 'does not support item assignment'):
1735 fields(C)[0].metadata['test'] = 3
1736
1737 def test_field_metadata_mapping(self):
1738 # Make sure only a mapping can be passed as metadata
1739 # zero length.
1740 with self.assertRaises(TypeError):
1741 @dataclass
1742 class C:
1743 i: int = field(metadata=0)
1744
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001745 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001746 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001747 @dataclass
1748 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001749 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001750 self.assertFalse(fields(C)[0].metadata)
1751 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001752 # Update should work (see bpo-35960).
1753 d['foo'] = 1
1754 self.assertEqual(len(fields(C)[0].metadata), 1)
1755 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001756 with self.assertRaisesRegex(TypeError,
1757 'does not support item assignment'):
1758 fields(C)[0].metadata['test'] = 3
1759
1760 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001761 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001762 @dataclass
1763 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001764 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001765 self.assertEqual(len(fields(C)[0].metadata), 3)
1766 self.assertEqual(fields(C)[0].metadata['test'], 10)
1767 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1768 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001769 # Update should work.
1770 d['foo'] = 1
1771 self.assertEqual(len(fields(C)[0].metadata), 4)
1772 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001773 with self.assertRaises(KeyError):
1774 # Non-existent key.
1775 fields(C)[0].metadata['baz']
1776 with self.assertRaisesRegex(TypeError,
1777 'does not support item assignment'):
1778 fields(C)[0].metadata['test'] = 3
1779
1780 def test_field_metadata_custom_mapping(self):
1781 # Try a custom mapping.
1782 class SimpleNameSpace:
1783 def __init__(self, **kw):
1784 self.__dict__.update(kw)
1785
1786 def __getitem__(self, item):
1787 if item == 'xyzzy':
1788 return 'plugh'
1789 return getattr(self, item)
1790
1791 def __len__(self):
1792 return self.__dict__.__len__()
1793
1794 @dataclass
1795 class C:
1796 i: int = field(metadata=SimpleNameSpace(a=10))
1797
1798 self.assertEqual(len(fields(C)[0].metadata), 1)
1799 self.assertEqual(fields(C)[0].metadata['a'], 10)
1800 with self.assertRaises(AttributeError):
1801 fields(C)[0].metadata['b']
1802 # Make sure we're still talking to our custom mapping.
1803 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1804
1805 def test_generic_dataclasses(self):
1806 T = TypeVar('T')
1807
1808 @dataclass
1809 class LabeledBox(Generic[T]):
1810 content: T
1811 label: str = '<unknown>'
1812
1813 box = LabeledBox(42)
1814 self.assertEqual(box.content, 42)
1815 self.assertEqual(box.label, '<unknown>')
1816
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001817 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001818 Alias = List[LabeledBox[int]]
1819
1820 def test_generic_extending(self):
1821 S = TypeVar('S')
1822 T = TypeVar('T')
1823
1824 @dataclass
1825 class Base(Generic[T, S]):
1826 x: T
1827 y: S
1828
1829 @dataclass
1830 class DataDerived(Base[int, T]):
1831 new_field: str
1832 Alias = DataDerived[str]
1833 c = Alias(0, 'test1', 'test2')
1834 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1835
1836 class NonDataDerived(Base[int, T]):
1837 def new_method(self):
1838 return self.y
1839 Alias = NonDataDerived[float]
1840 c = Alias(10, 1.0)
1841 self.assertEqual(c.new_method(), 1.0)
1842
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001843 def test_generic_dynamic(self):
1844 T = TypeVar('T')
1845
1846 @dataclass
1847 class Parent(Generic[T]):
1848 x: T
1849 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1850 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1851 self.assertIs(Child[int](1, 2).z, None)
1852 self.assertEqual(Child[int](1, 2, 3).z, 3)
1853 self.assertEqual(Child[int](1, 2, 3).other, 42)
1854 # Check that type aliases work correctly.
1855 Alias = Child[T]
1856 self.assertEqual(Alias[int](1, 2).x, 1)
1857 # Check MRO resolution.
1858 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1859
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001860 def test_dataclassses_pickleable(self):
1861 global P, Q, R
1862 @dataclass
1863 class P:
1864 x: int
1865 y: int = 0
1866 @dataclass
1867 class Q:
1868 x: int
1869 y: int = field(default=0, init=False)
1870 @dataclass
1871 class R:
1872 x: int
1873 y: List[int] = field(default_factory=list)
1874 q = Q(1)
1875 q.y = 2
1876 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1877 for sample in samples:
1878 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1879 with self.subTest(sample=sample, proto=proto):
1880 new_sample = pickle.loads(pickle.dumps(sample, proto))
1881 self.assertEqual(sample.x, new_sample.x)
1882 self.assertEqual(sample.y, new_sample.y)
1883 self.assertIsNot(sample, new_sample)
1884 new_sample.x = 42
1885 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1886 self.assertEqual(new_sample.x, another_new_sample.x)
1887 self.assertEqual(sample.y, another_new_sample.y)
1888
Eric V. Smithea8fc522018-01-27 19:07:40 -05001889
class TestFieldNoAnnotation(unittest.TestCase):
    def test_field_without_annotation(self):
        # A field() bound to a name without an annotation is rejected.
        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        # An annotation for 'f' on a dataclass base must not be picked up.
        @dataclass
        class Base:
            f: int

        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(Base):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # The same rule applies when the annotated base is an ordinary class.
        class Base:
            f: int

        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(Base):
                f = field()
1923
1924
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001925class TestDocString(unittest.TestCase):
1926 def assertDocStrEqual(self, a, b):
1927 # Because 3.6 and 3.7 differ in how inspect.signature work
1928 # (see bpo #32108), for the time being just compare them with
1929 # whitespace stripped.
1930 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1931
1932 def test_existing_docstring_not_overridden(self):
1933 @dataclass
1934 class C:
1935 """Lorem ipsum"""
1936 x: int
1937
1938 self.assertEqual(C.__doc__, "Lorem ipsum")
1939
1940 def test_docstring_no_fields(self):
1941 @dataclass
1942 class C:
1943 pass
1944
1945 self.assertDocStrEqual(C.__doc__, "C()")
1946
1947 def test_docstring_one_field(self):
1948 @dataclass
1949 class C:
1950 x: int
1951
1952 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1953
1954 def test_docstring_two_fields(self):
1955 @dataclass
1956 class C:
1957 x: int
1958 y: int
1959
1960 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1961
1962 def test_docstring_three_fields(self):
1963 @dataclass
1964 class C:
1965 x: int
1966 y: int
1967 z: str
1968
1969 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1970
1971 def test_docstring_one_field_with_default(self):
1972 @dataclass
1973 class C:
1974 x: int = 3
1975
1976 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1977
1978 def test_docstring_one_field_with_default_none(self):
1979 @dataclass
1980 class C:
1981 x: Union[int, type(None)] = None
1982
1983 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1984
1985 def test_docstring_list_field(self):
1986 @dataclass
1987 class C:
1988 x: List[int]
1989
1990 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1991
1992 def test_docstring_list_field_with_default_factory(self):
1993 @dataclass
1994 class C:
1995 x: List[int] = field(default_factory=list)
1996
1997 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1998
1999 def test_docstring_deque_field(self):
2000 @dataclass
2001 class C:
2002 x: deque
2003
2004 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2005
2006 def test_docstring_deque_field_with_default_factory(self):
2007 @dataclass
2008 class C:
2009 x: deque = field(default_factory=deque)
2010
2011 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2012
2013
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Decorating a subclass of a class that defines __init__ is fine:
        # the generated __init__ goes on the subclass, and it does not call
        # the base __init__ (so 'z' is never set).
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so the base one runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # init=False without a user __init__: the value is read from the
        # class attribute holding the field's default.
        # (Both classes must actually be decorated with '@'; a bare
        # 'dataclass(init=False)' expression would leave C a plain class
        # and make these checks vacuous.)
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # init=False with a user __init__: the user's method is used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class defines __init__, it is kept no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2078
2079
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The generated repr uses the qualified class name and lists every
        # field, inherited ones included, in order.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        self.assertEqual(repr(C(4)),
                         'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20

        self.assertEqual(repr(D()),
                         'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses show their full nesting path.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int

            @dataclass
            class E:
                pass

        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False without a user __repr__ falls back to the default repr.
        @dataclass(repr=False)
        class C:
            x: int

        default_repr = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_repr)

        # repr=False with a user __repr__ keeps the user's method.
        @dataclass(repr=False)
        class C:
            x: int

            def __repr__(self):
                return 'C-class'

        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ written in the class body wins, whatever repr= says.
        for decorator in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorator
            class C:
                x: int

                def __repr__(self):
                    return 'x'

            self.assertEqual(repr(C(0)), 'x')
2149
2150
class TestEq(unittest.TestCase):
    """Tests for eq=False and for user-defined __eq__ taking precedence."""

    def test_no_eq(self):
        # eq=False and no user __eq__: equality falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method decides.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user __eq__ is never replaced, whatever eq= is set to.
        variants = ((3, dataclass),
                    (4, dataclass(eq=True)),
                    (5, dataclass(eq=False)))
        for target, decorate in variants:
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2196
2197
class TestOrdering(unittest.TestCase):
    """Tests for order=False/order=True and user-written comparison methods."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that a user-defined __lt__ is kept and is the one called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)
        # The user's __lt__ must still be there and be the one used: a
        # field-based comparison would report C(0) < C(1) as True.
        self.assertIn('__lt__', C.__dict__)
        self.assertFalse(C(0) < C(1))

    def test_overwriting_order(self):
        # order=True refuses to replace any user-written ordering method
        # and points at functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2273
class TestHash(unittest.TestCase):
    """Tests for when and how dataclass() provides __hash__."""

    def test_unsafe_hash(self):
        # unsafe_hash=True hashes an instance as the tuple of its fields.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else with the same truth value, but not a
            # bool.  None is passed through unchanged.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a dataclass with the given flags (with or without its
            # own __hash__) and check the outcome named by `result`:
            #   'fn'        -- dataclass() generated a __hash__
            #   ''          -- dataclass() left __hash__ alone
            #   'none'      -- dataclass() set __hash__ to None
            #   'exception' -- creating the class raises TypeError
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    if with_hash:
                        # The class's own __hash__ must survive untouched.
                        self.assertEqual(hash(C()), 0)
                    else:
                        # __hash__ is not present in our class.
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',   ''),
                (False, False, True,  '',   ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn', ''),
                (True,  False, False, 'fn', 'exception'),
                (True,  False, True,  'fn', 'exception'),
                (True,  True,  False, 'fn', 'exception'),
                (True,  True,  True,  'fn', 'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo=32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less dataclass hashes like the empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A one-field dataclass hashes like a 1-tuple of its field.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash is identity-based, so there is no
                    # fixed value to compare hash() against.  Just check
                    # that the function in use is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2492
Eric V. Smithea8fc522018-01-27 19:07:40 -05002493
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True instances and frozen/non-frozen inheritance."""

    @staticmethod
    def _maybe_subclass(base, intermediate_class):
        # Return base itself, or a plain (non-dataclass) subclass of it
        # when intermediate_class is true.
        if not intermediate_class:
            return base

        class I(base):
            pass
        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        # Assignment is rejected and the stored value is unchanged.
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        # Inherited and newly declared fields are both read-only.
        for attr, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, attr, new_value)
        self.assertEqual((d.i, d.j), (0, 10))

    # The inheritance checks run twice: deriving straight from the base
    # dataclass, and through an intermediate plain class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._maybe_subclass(C, intermediate_class)
                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = self._maybe_subclass(C, intermediate_class)
                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                parent = self._maybe_subclass(C, intermediate_class)

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                frozen_obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    frozen_obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        # A plain subclass of a frozen dataclass may grow new attributes...
        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        s.cached = True

        # ...but the inherited fields stay frozen.
        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, attr, 5)
        self.assertEqual((s.x, s.y), (3, 10))
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True needs to install both __setattr__ and __delattr__,
        # so a class that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen=True a user __setattr__ is kept and used by __init__.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing succeeds while every field value is hashable...
        hash(C(3))

        # ...and fails as soon as one of them is not.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2642
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002643
class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__ by hand."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # The slot's member descriptor on the class was once mistaken for
        # a default value; x must remain a required argument.
        expected_msg = r"__init__\(\) missing 1 required positional argument: 'x'"
        with self.assertRaisesRegex(TypeError, expected_msg):
            C()

        # x is set by __init__ and can be reassigned afterwards.
        slotted = C(10)
        self.assertEqual(slotted.x, 10)
        slotted.x = 5
        self.assertEqual(slotted.x, 5)

        # Anything outside __slots__ is refused.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            slotted.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__ of its own, so its instances have a
        # __dict__ and accept new attributes.
        derived.z = 10
2684 d.z = 10
2685
class TestDescriptors(unittest.TestCase):
    """Tests that __set_name__ is honoured for objects used as field defaults."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that records (a suffixed copy of) the attribute name
        # it gets bound to.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Plain default value: ordinary class creation calls __set_name__,
        # no dataclass code is involved.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Default wrapped in field(init=False), the only case where this is
        # really meaningful: with init=True the generated __init__ would
        # overwrite the attribute anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors too, so
        # this helper deliberately defines no __get__.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        # __set_name__ is looked up on the type, so one attached only to
        # the instance must be ignored.
        class D:
            pass

        default = D()
        default.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=default, init=False)

        self.assertEqual(default.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        # A __set_name__ defined on the default's type is called exactly once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002756
Eric V. Smith7389fd92018-03-19 21:07:51 -04002757
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002758class TestStringAnnotations(unittest.TestCase):
2759 def test_classvar(self):
2760 # Some expressions recognized as ClassVar really aren't. But
2761 # if you're using string annotations, it's not an exact
2762 # science.
2763 # These tests assume that both "import typing" and "from
2764 # typing import *" have been run in this file.
2765 for typestr in ('ClassVar[int]',
2766 'ClassVar [int]'
2767 ' ClassVar [int]',
2768 'ClassVar',
2769 ' ClassVar ',
2770 'typing.ClassVar[int]',
2771 'typing.ClassVar[str]',
2772 ' typing.ClassVar[str]',
2773 'typing .ClassVar[str]',
2774 'typing. ClassVar[str]',
2775 'typing.ClassVar [str]',
2776 'typing.ClassVar [ str]',
2777
2778 # Not syntactically valid, but these will
2779 # be treated as ClassVars.
2780 'typing.ClassVar.[int]',
2781 'typing.ClassVar+',
2782 ):
2783 with self.subTest(typestr=typestr):
2784 @dataclass
2785 class C:
2786 x: typestr
2787
2788 # x is a ClassVar, so C() takes no args.
2789 C()
2790
2791 # And it won't appear in the class's dict because it doesn't
2792 # have a default.
2793 self.assertNotIn('x', C.__dict__)
2794
2795 def test_isnt_classvar(self):
2796 for typestr in ('CV',
2797 't.ClassVar',
2798 't.ClassVar[int]',
2799 'typing..ClassVar[int]',
2800 'Classvar',
2801 'Classvar[int]',
2802 'typing.ClassVarx[int]',
2803 'typong.ClassVar[int]',
2804 'dataclasses.ClassVar[int]',
2805 'typingxClassVar[str]',
2806 ):
2807 with self.subTest(typestr=typestr):
2808 @dataclass
2809 class C:
2810 x: typestr
2811
2812 # x is not a ClassVar, so C() takes one arg.
2813 self.assertEqual(C(10).x, 10)
2814
2815 def test_initvar(self):
2816 # These tests assume that both "import dataclasses" and "from
2817 # dataclasses import *" have been run in this file.
2818 for typestr in ('InitVar[int]',
2819 'InitVar [int]'
2820 ' InitVar [int]',
2821 'InitVar',
2822 ' InitVar ',
2823 'dataclasses.InitVar[int]',
2824 'dataclasses.InitVar[str]',
2825 ' dataclasses.InitVar[str]',
2826 'dataclasses .InitVar[str]',
2827 'dataclasses. InitVar[str]',
2828 'dataclasses.InitVar [str]',
2829 'dataclasses.InitVar [ str]',
2830
2831 # Not syntactically valid, but these will
2832 # be treated as InitVars.
2833 'dataclasses.InitVar.[int]',
2834 'dataclasses.InitVar+',
2835 ):
2836 with self.subTest(typestr=typestr):
2837 @dataclass
2838 class C:
2839 x: typestr
2840
2841 # x is an InitVar, so doesn't create a member.
2842 with self.assertRaisesRegex(AttributeError,
2843 "object has no attribute 'x'"):
2844 C(1).x
2845
2846 def test_isnt_initvar(self):
2847 for typestr in ('IV',
2848 'dc.InitVar',
2849 'xdataclasses.xInitVar',
2850 'typing.xInitVar[int]',
2851 ):
2852 with self.subTest(typestr=typestr):
2853 @dataclass
2854 class C:
2855 x: typestr
2856
2857 # x is not an InitVar, so there will be a member x.
2858 self.assertEqual(C(10).x, 10)
2859
2860 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002861 from test import dataclass_module_1
2862 from test import dataclass_module_1_str
2863 from test import dataclass_module_2
2864 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002865
2866 for m in (dataclass_module_1, dataclass_module_1_str,
2867 dataclass_module_2, dataclass_module_2_str,
2868 ):
2869 with self.subTest(m=m):
2870 # There's a difference in how the ClassVars are
2871 # interpreted when using string annotations or
2872 # not. See the imported modules for details.
2873 if m.USING_STRINGS:
2874 c = m.CV(10)
2875 else:
2876 c = m.CV()
2877 self.assertEqual(c.cv0, 20)
2878
2879
2880 # There's a difference in how the InitVars are
2881 # interpreted when using string annotations or
2882 # not. See the imported modules for details.
2883 c = m.IV(0, 1, 2, 3, 4)
2884
2885 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2886 with self.subTest(field_name=field_name):
2887 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2888 # Since field_name is an InitVar, it's
2889 # not an instance field.
2890 getattr(c, field_name)
2891
2892 if m.USING_STRINGS:
2893 # iv4 is interpreted as a normal field.
2894 self.assertIn('not_iv4', c.__dict__)
2895 self.assertEqual(c.not_iv4, 4)
2896 else:
2897 # iv4 is interpreted as an InitVar, so it
2898 # won't exist on the instance.
2899 self.assertNotIn('not_iv4', c.__dict__)
2900
2901
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass()."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # The dataclass base contributes x, so one argument is not enough.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        for field in ['a', 'ab']:
            with self.subTest(field=field):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [field, 'a', field])

    def test_keyword_field_names(self):
        for field in ['for', 'async', 'await', 'as']:
            with self.subTest(field=field):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', field])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field, 'a'])

    def test_non_identifier_field_names(self):
        # The library's message was once misspelled "identifers"; accept
        # both that and the corrected "identifiers".
        msg = 'must be valid identifi?ers'
        for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=field):
                with self.assertRaisesRegex(TypeError, msg):
                    make_dataclass('C', ['a', field])
                with self.assertRaisesRegex(TypeError, msg):
                    make_dataclass('C', [field])
                with self.assertRaisesRegex(TypeError, msg):
                    make_dataclass('C', [field, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3065
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The test_recursive_repr* methods below exercise repr() rather than
    replace(); their expected strings embed the qualified name
    "TestReplace.<method>.<locals>.C", so moving them to another TestCase
    requires updating those strings.
    """

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields cannot be passed to replace(), either alongside
        # init=True fields or on their own.  Each raising call needs its own
        # context manager: any statement after the first raising call inside
        # a single `with` block would never execute.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        # Unknown names are forwarded to __init__, which rejects them.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # replace() only accepts dataclass *instances*: neither the class
        # itself nor an arbitrary object is allowed.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value: init=False fields are not
        # copied from the original instance.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.  The
        # name is passed through to __init__, which does not accept it.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        # Replacing a regular field works and leaves the ClassVar alone.
        c1 = replace(c, x=5)
        self.assertEqual(c1.x, 5)
        self.assertEqual(c1.y, 1000)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # The parentheses are escaped so the regex really checks for the
        # literal "replace()" rather than treating "()" as an empty group.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_initvar(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

        c = C(1, 10)

        # Without __post_init__ the InitVar never reaches the instance, so
        # replacing only the InitVar produces an equal object.
        self.assertEqual(c, replace(c, y=5))

        # The InitVar must still be supplied, even though it is unused.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=5)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3275
Eric V. Smith4e812962018-05-16 11:31:29 -04003276
if __name__ == "__main__":
    # Executed directly as a script: discover and run every TestCase here.
    unittest.main()