blob: ff6060c6d2838acf09e3a0519104a920aadf924a [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +03009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050014from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040016import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
17import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
18
class CustomError(Exception):
    """An arbitrary exception type that tests raise and then catch."""

21
22class TestCase(unittest.TestCase):
23 def test_no_fields(self):
24 @dataclass
25 class C:
26 pass
27
28 o = C()
29 self.assertEqual(len(fields(C)), 0)
30
Eric V. Smith56970b82018-03-22 16:28:48 -040031 def test_no_fields_but_member_variable(self):
32 @dataclass
33 class C:
34 i = 0
35
36 o = C()
37 self.assertEqual(len(fields(C)), 0)
38
Eric V. Smithf0db54a2017-12-04 16:58:55 -050039 def test_one_field_no_default(self):
40 @dataclass
41 class C:
42 x: int
43
44 o = C(42)
45 self.assertEqual(o.x, 42)
46
47 def test_named_init_params(self):
48 @dataclass
49 class C:
50 x: int
51
52 o = C(x=32)
53 self.assertEqual(o.x, 32)
54
55 def test_two_fields_one_default(self):
56 @dataclass
57 class C:
58 x: int
59 y: int = 0
60
61 o = C(3)
62 self.assertEqual((o.x, o.y), (3, 0))
63
64 # Non-defaults following defaults.
65 with self.assertRaisesRegex(TypeError,
66 "non-default argument 'y' follows "
67 "default argument"):
68 @dataclass
69 class C:
70 x: int = 0
71 y: int
72
73 # A derived class adds a non-default field after a default one.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int = 0
80
81 @dataclass
82 class C(B):
83 y: int
84
85 # Override a base class field and add a default to
86 # a field which didn't use to have a default.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class B:
92 x: int
93 y: int
94
95 @dataclass
96 class C(B):
97 x: int = 0
98
Eric V. Smithdbf9cff2018-02-25 21:30:17 -050099 def test_overwrite_hash(self):
100 # Test that declaring this class isn't an error. It should
101 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500102 @dataclass(frozen=True)
103 class C:
104 x: int
105 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500106 return 301
107 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500108
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500109 # Test that declaring this class isn't an error. It should
110 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500111 @dataclass(frozen=True)
112 class C:
113 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500114 def __eq__(self, other):
115 return False
116 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500117
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500118 # But this one should generate an exception, because with
119 # unsafe_hash=True, it's an error to have a __hash__ defined.
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __hash__'):
122 @dataclass(unsafe_hash=True)
123 class C:
124 def __hash__(self):
125 pass
126
127 # Creating this class should not generate an exception,
128 # because even though __hash__ exists before @dataclass is
129 # called, (due to __eq__ being defined), since it's None
130 # that's okay.
131 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500132 class C:
133 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500134 def __eq__(self):
135 pass
136 # The generated hash function works as we'd expect.
137 self.assertEqual(hash(C(10)), hash((10,)))
138
139 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400140 # __hash__ exists and is not None, which it would be if it
141 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 x: int
147 def __eq__(self):
148 pass
149 def __hash__(self):
150 pass
151
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500152 def test_overwrite_fields_in_derived_class(self):
153 # Note that x from C1 replaces x in Base, but the order remains
154 # the same as defined in Base.
155 @dataclass
156 class Base:
157 x: Any = 15.0
158 y: int = 0
159
160 @dataclass
161 class C1(Base):
162 z: int = 10
163 x: int = 15
164
165 o = Base()
166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
167
168 o = C1()
169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
170
171 o = C1(x=5)
172 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
173
174 def test_field_named_self(self):
175 @dataclass
176 class C:
177 self: str
178 c=C('foo')
179 self.assertEqual(c.self, 'foo')
180
181 # Make sure the first parameter is not named 'self'.
182 sig = inspect.signature(C.__init__)
183 first = next(iter(sig.parameters))
184 self.assertNotEqual('self', first)
185
186 # But we do use 'self' if no field named self.
187 @dataclass
188 class C:
189 selfx: str
190
191 # Make sure the first parameter is named 'self'.
192 sig = inspect.signature(C.__init__)
193 first = next(iter(sig.parameters))
194 self.assertEqual('self', first)
195
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300196 def test_field_named_object(self):
197 @dataclass
198 class C:
199 object: str
200 c = C('foo')
201 self.assertEqual(c.object, 'foo')
202
203 def test_field_named_object_frozen(self):
204 @dataclass(frozen=True)
205 class C:
206 object: str
207 c = C('foo')
208 self.assertEqual(c.object, 'foo')
209
210 def test_field_named_like_builtin(self):
211 # Attribute names can shadow built-in names
212 # since code generation is used.
213 # Ensure that this is not happening.
214 exclusions = {'None', 'True', 'False'}
215 builtins_names = sorted(
216 b for b in builtins.__dict__.keys()
217 if not b.startswith('__') and b not in exclusions
218 )
219 attributes = [(name, str) for name in builtins_names]
220 C = make_dataclass('C', attributes)
221
222 c = C(*[name for name in builtins_names])
223
224 for name in builtins_names:
225 self.assertEqual(getattr(c, name), name)
226
227 def test_field_named_like_builtin_frozen(self):
228 # Attribute names can shadow built-in names
229 # since code generation is used.
230 # Ensure that this is not happening
231 # for frozen data classes.
232 exclusions = {'None', 'True', 'False'}
233 builtins_names = sorted(
234 b for b in builtins.__dict__.keys()
235 if not b.startswith('__') and b not in exclusions
236 )
237 attributes = [(name, str) for name in builtins_names]
238 C = make_dataclass('C', attributes, frozen=True)
239
240 c = C(*[name for name in builtins_names])
241
242 for name in builtins_names:
243 self.assertEqual(getattr(c, name), name)
244
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500245 def test_0_field_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 pass
250
251 @dataclass(order=False)
252 class C1:
253 pass
254
255 for cls in [C0, C1]:
256 with self.subTest(cls=cls):
257 self.assertEqual(cls(), cls())
258 for idx, fn in enumerate([lambda a, b: a < b,
259 lambda a, b: a <= b,
260 lambda a, b: a > b,
261 lambda a, b: a >= b]):
262 with self.subTest(idx=idx):
263 with self.assertRaisesRegex(TypeError,
264 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
265 fn(cls(), cls())
266
267 @dataclass(order=True)
268 class C:
269 pass
270 self.assertLessEqual(C(), C())
271 self.assertGreaterEqual(C(), C())
272
273 def test_1_field_compare(self):
274 # Ensure that order=False is the default.
275 @dataclass
276 class C0:
277 x: int
278
279 @dataclass(order=False)
280 class C1:
281 x: int
282
283 for cls in [C0, C1]:
284 with self.subTest(cls=cls):
285 self.assertEqual(cls(1), cls(1))
286 self.assertNotEqual(cls(0), cls(1))
287 for idx, fn in enumerate([lambda a, b: a < b,
288 lambda a, b: a <= b,
289 lambda a, b: a > b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 with self.assertRaisesRegex(TypeError,
293 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
294 fn(cls(0), cls(0))
295
296 @dataclass(order=True)
297 class C:
298 x: int
299 self.assertLess(C(0), C(1))
300 self.assertLessEqual(C(0), C(1))
301 self.assertLessEqual(C(1), C(1))
302 self.assertGreater(C(1), C(0))
303 self.assertGreaterEqual(C(1), C(0))
304 self.assertGreaterEqual(C(1), C(1))
305
306 def test_simple_compare(self):
307 # Ensure that order=False is the default.
308 @dataclass
309 class C0:
310 x: int
311 y: int
312
313 @dataclass(order=False)
314 class C1:
315 x: int
316 y: int
317
318 for cls in [C0, C1]:
319 with self.subTest(cls=cls):
320 self.assertEqual(cls(0, 0), cls(0, 0))
321 self.assertEqual(cls(1, 2), cls(1, 2))
322 self.assertNotEqual(cls(1, 0), cls(0, 0))
323 self.assertNotEqual(cls(1, 0), cls(1, 1))
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
331 fn(cls(0, 0), cls(0, 0))
332
333 @dataclass(order=True)
334 class C:
335 x: int
336 y: int
337
338 for idx, fn in enumerate([lambda a, b: a == b,
339 lambda a, b: a <= b,
340 lambda a, b: a >= b]):
341 with self.subTest(idx=idx):
342 self.assertTrue(fn(C(0, 0), C(0, 0)))
343
344 for idx, fn in enumerate([lambda a, b: a < b,
345 lambda a, b: a <= b,
346 lambda a, b: a != b]):
347 with self.subTest(idx=idx):
348 self.assertTrue(fn(C(0, 0), C(0, 1)))
349 self.assertTrue(fn(C(0, 1), C(1, 0)))
350 self.assertTrue(fn(C(1, 0), C(1, 1)))
351
352 for idx, fn in enumerate([lambda a, b: a > b,
353 lambda a, b: a >= b,
354 lambda a, b: a != b]):
355 with self.subTest(idx=idx):
356 self.assertTrue(fn(C(0, 1), C(0, 0)))
357 self.assertTrue(fn(C(1, 0), C(0, 1)))
358 self.assertTrue(fn(C(1, 1), C(1, 0)))
359
360 def test_compare_subclasses(self):
361 # Comparisons fail for subclasses, even if no fields
362 # are added.
363 @dataclass
364 class B:
365 i: int
366
367 @dataclass
368 class C(B):
369 pass
370
371 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
372 (lambda a, b: a != b, True)]):
373 with self.subTest(idx=idx):
374 self.assertEqual(fn(B(0), C(0)), expected)
375
376 for idx, fn in enumerate([lambda a, b: a < b,
377 lambda a, b: a <= b,
378 lambda a, b: a > b,
379 lambda a, b: a >= b]):
380 with self.subTest(idx=idx):
381 with self.assertRaisesRegex(TypeError,
382 "not supported between instances of 'B' and 'C'"):
383 fn(B(0), C(0))
384
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500385 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500386 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500387 for (eq, order, result ) in [
388 (False, False, 'neither'),
389 (False, True, 'exception'),
390 (True, False, 'eq_only'),
391 (True, True, 'both'),
392 ]:
393 with self.subTest(eq=eq, order=order):
394 if result == 'exception':
395 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
396 @dataclass(eq=eq, order=order)
397 class C:
398 pass
399 else:
400 @dataclass(eq=eq, order=order)
401 class C:
402 pass
403
404 if result == 'neither':
405 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 self.assertNotIn('__lt__', C.__dict__)
407 self.assertNotIn('__le__', C.__dict__)
408 self.assertNotIn('__gt__', C.__dict__)
409 self.assertNotIn('__ge__', C.__dict__)
410 elif result == 'both':
411 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500412 self.assertIn('__lt__', C.__dict__)
413 self.assertIn('__le__', C.__dict__)
414 self.assertIn('__gt__', C.__dict__)
415 self.assertIn('__ge__', C.__dict__)
416 elif result == 'eq_only':
417 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500418 self.assertNotIn('__lt__', C.__dict__)
419 self.assertNotIn('__le__', C.__dict__)
420 self.assertNotIn('__gt__', C.__dict__)
421 self.assertNotIn('__ge__', C.__dict__)
422 else:
423 assert False, f'unknown result {result!r}'
424
425 def test_field_no_default(self):
426 @dataclass
427 class C:
428 x: int = field()
429
430 self.assertEqual(C(5).x, 5)
431
432 with self.assertRaisesRegex(TypeError,
433 r"__init__\(\) missing 1 required "
434 "positional argument: 'x'"):
435 C()
436
437 def test_field_default(self):
438 default = object()
439 @dataclass
440 class C:
441 x: object = field(default=default)
442
443 self.assertIs(C.x, default)
444 c = C(10)
445 self.assertEqual(c.x, 10)
446
447 # If we delete the instance attribute, we should then see the
448 # class attribute.
449 del c.x
450 self.assertIs(c.x, default)
451
452 self.assertIs(C().x, default)
453
454 def test_not_in_repr(self):
455 @dataclass
456 class C:
457 x: int = field(repr=False)
458 with self.assertRaises(TypeError):
459 C()
460 c = C(10)
461 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
462
463 @dataclass
464 class C:
465 x: int = field(repr=False)
466 y: int
467 c = C(10, 20)
468 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
469
470 def test_not_in_compare(self):
471 @dataclass
472 class C:
473 x: int = 0
474 y: int = field(compare=False, default=4)
475
476 self.assertEqual(C(), C(0, 20))
477 self.assertEqual(C(1, 10), C(1, 20))
478 self.assertNotEqual(C(3), C(4, 10))
479 self.assertNotEqual(C(3, 10), C(4, 10))
480
481 def test_hash_field_rules(self):
482 # Test all 6 cases of:
483 # hash=True/False/None
484 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500485 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500486 (True, False, 'field' ),
487 (True, True, 'field' ),
488 (False, False, 'absent'),
489 (False, True, 'absent'),
490 (None, False, 'absent'),
491 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500492 ]:
493 with self.subTest(hash=hash_, compare=compare):
494 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500495 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500496 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500497
498 if result == 'field':
499 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500500 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500501 elif result == 'absent':
502 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500503 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500504 else:
505 assert False, f'unknown result {result!r}'
506
507 def test_init_false_no_default(self):
508 # If init=False and no default value, then the field won't be
509 # present in the instance.
510 @dataclass
511 class C:
512 x: int = field(init=False)
513
514 self.assertNotIn('x', C().__dict__)
515
516 @dataclass
517 class C:
518 x: int
519 y: int = 0
520 z: int = field(init=False)
521 t: int = 10
522
523 self.assertNotIn('z', C(0).__dict__)
524 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
525
526 def test_class_marker(self):
527 @dataclass
528 class C:
529 x: int
530 y: str = field(init=False, default=None)
531 z: str = field(repr=False)
532
533 the_fields = fields(C)
534 # the_fields is a tuple of 3 items, each value
535 # is in __annotations__.
536 self.assertIsInstance(the_fields, tuple)
537 for f in the_fields:
538 self.assertIs(type(f), Field)
539 self.assertIn(f.name, C.__annotations__)
540
541 self.assertEqual(len(the_fields), 3)
542
543 self.assertEqual(the_fields[0].name, 'x')
544 self.assertEqual(the_fields[0].type, int)
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertTrue (the_fields[0].init)
547 self.assertTrue (the_fields[0].repr)
548 self.assertEqual(the_fields[1].name, 'y')
549 self.assertEqual(the_fields[1].type, str)
550 self.assertIsNone(getattr(C, 'y'))
551 self.assertFalse(the_fields[1].init)
552 self.assertTrue (the_fields[1].repr)
553 self.assertEqual(the_fields[2].name, 'z')
554 self.assertEqual(the_fields[2].type, str)
555 self.assertFalse(hasattr(C, 'z'))
556 self.assertTrue (the_fields[2].init)
557 self.assertFalse(the_fields[2].repr)
558
559 def test_field_order(self):
560 @dataclass
561 class B:
562 a: str = 'B:a'
563 b: str = 'B:b'
564 c: str = 'B:c'
565
566 @dataclass
567 class C(B):
568 b: str = 'C:b'
569
570 self.assertEqual([(f.name, f.default) for f in fields(C)],
571 [('a', 'B:a'),
572 ('b', 'C:b'),
573 ('c', 'B:c')])
574
575 @dataclass
576 class D(B):
577 c: str = 'D:c'
578
579 self.assertEqual([(f.name, f.default) for f in fields(D)],
580 [('a', 'B:a'),
581 ('b', 'B:b'),
582 ('c', 'D:c')])
583
584 @dataclass
585 class E(D):
586 a: str = 'E:a'
587 d: str = 'E:d'
588
589 self.assertEqual([(f.name, f.default) for f in fields(E)],
590 [('a', 'E:a'),
591 ('b', 'B:b'),
592 ('c', 'D:c'),
593 ('d', 'E:d')])
594
595 def test_class_attrs(self):
596 # We only have a class attribute if a default value is
597 # specified, either directly or via a field with a default.
598 default = object()
599 @dataclass
600 class C:
601 x: int
602 y: int = field(repr=False)
603 z: object = default
604 t: int = field(default=100)
605
606 self.assertFalse(hasattr(C, 'x'))
607 self.assertFalse(hasattr(C, 'y'))
608 self.assertIs (C.z, default)
609 self.assertEqual(C.t, 100)
610
611 def test_disallowed_mutable_defaults(self):
612 # For the known types, don't allow mutable default values.
613 for typ, empty, non_empty in [(list, [], [1]),
614 (dict, {}, {0:1}),
615 (set, set(), set([1])),
616 ]:
617 with self.subTest(typ=typ):
618 # Can't use a zero-length value.
619 with self.assertRaisesRegex(ValueError,
620 f'mutable default {typ} for field '
621 'x is not allowed'):
622 @dataclass
623 class Point:
624 x: typ = empty
625
626
627 # Nor a non-zero-length value
628 with self.assertRaisesRegex(ValueError,
629 f'mutable default {typ} for field '
630 'y is not allowed'):
631 @dataclass
632 class Point:
633 y: typ = non_empty
634
635 # Check subtypes also fail.
636 class Subclass(typ): pass
637
638 with self.assertRaisesRegex(ValueError,
639 f"mutable default .*Subclass'>"
640 ' for field z is not allowed'
641 ):
642 @dataclass
643 class Point:
644 z: typ = Subclass()
645
646 # Because this is a ClassVar, it can be mutable.
647 @dataclass
648 class C:
649 z: ClassVar[typ] = typ()
650
651 # Because this is a ClassVar, it can be mutable.
652 @dataclass
653 class C:
654 x: ClassVar[typ] = Subclass()
655
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500656 def test_deliberately_mutable_defaults(self):
657 # If a mutable default isn't in the known list of
658 # (list, dict, set), then it's okay.
659 class Mutable:
660 def __init__(self):
661 self.l = []
662
663 @dataclass
664 class C:
665 x: Mutable
666
667 # These 2 instances will share this value of x.
668 lst = Mutable()
669 o1 = C(lst)
670 o2 = C(lst)
671 self.assertEqual(o1, o2)
672 o1.x.l.extend([1, 2])
673 self.assertEqual(o1, o2)
674 self.assertEqual(o1.x.l, [1, 2])
675 self.assertIs(o1.x, o2.x)
676
677 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400678 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500679 @dataclass()
680 class C:
681 x: int
682
683 self.assertEqual(C(42).x, 42)
684
685 def test_not_tuple(self):
686 # Make sure we can't be compared to a tuple.
687 @dataclass
688 class Point:
689 x: int
690 y: int
691 self.assertNotEqual(Point(1, 2), (1, 2))
692
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400693 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500694 @dataclass
695 class C:
696 x: int
697 y: int
698 self.assertNotEqual(Point(1, 3), C(1, 3))
699
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500700 def test_not_tuple(self):
701 # Test that some of the problems with namedtuple don't happen
702 # here.
703 @dataclass
704 class Point3D:
705 x: int
706 y: int
707 z: int
708
709 @dataclass
710 class Date:
711 year: int
712 month: int
713 day: int
714
715 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
716 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
717
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400718 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200719 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 x, y, z = Point3D(4, 5, 6)
721
Eric V. Smith7c99e932018-01-28 19:18:55 -0500722 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # equal.
724 @dataclass
725 class Point3Dv1:
726 x: int = 0
727 y: int = 0
728 z: int = 0
729 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
730
731 def test_function_annotations(self):
732 # Some dummy class and instance to use as a default.
733 class F:
734 pass
735 f = F()
736
737 def validate_class(cls):
738 # First, check __annotations__, even though they're not
739 # function annotations.
740 self.assertEqual(cls.__annotations__['i'], int)
741 self.assertEqual(cls.__annotations__['j'], str)
742 self.assertEqual(cls.__annotations__['k'], F)
743 self.assertEqual(cls.__annotations__['l'], float)
744 self.assertEqual(cls.__annotations__['z'], complex)
745
746 # Verify __init__.
747
748 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400749 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500750 self.assertIs(signature.return_annotation, None)
751
752 # Check each parameter.
753 params = iter(signature.parameters.values())
754 param = next(params)
755 # This is testing an internal name, and probably shouldn't be tested.
756 self.assertEqual(param.name, 'self')
757 param = next(params)
758 self.assertEqual(param.name, 'i')
759 self.assertIs (param.annotation, int)
760 self.assertEqual(param.default, inspect.Parameter.empty)
761 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
762 param = next(params)
763 self.assertEqual(param.name, 'j')
764 self.assertIs (param.annotation, str)
765 self.assertEqual(param.default, inspect.Parameter.empty)
766 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
767 param = next(params)
768 self.assertEqual(param.name, 'k')
769 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400770 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
772 param = next(params)
773 self.assertEqual(param.name, 'l')
774 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400775 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500776 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
777 self.assertRaises(StopIteration, next, params)
778
779
780 @dataclass
781 class C:
782 i: int
783 j: str
784 k: F = f
785 l: float=field(default=None)
786 z: complex=field(default=3+4j, init=False)
787
788 validate_class(C)
789
790 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500791 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 class C:
793 i: int
794 j: str
795 k: F = f
796 l: float=field(default=None)
797 z: complex=field(default=3+4j, init=False)
798
799 validate_class(C)
800
Eric V. Smith03220fd2017-12-29 13:59:58 -0500801 def test_missing_default(self):
802 # Test that MISSING works the same as a default not being
803 # specified.
804 @dataclass
805 class C:
806 x: int=field(default=MISSING)
807 with self.assertRaisesRegex(TypeError,
808 r'__init__\(\) missing 1 required '
809 'positional argument'):
810 C()
811 self.assertNotIn('x', C.__dict__)
812
813 @dataclass
814 class D:
815 x: int
816 with self.assertRaisesRegex(TypeError,
817 r'__init__\(\) missing 1 required '
818 'positional argument'):
819 D()
820 self.assertNotIn('x', D.__dict__)
821
822 def test_missing_default_factory(self):
823 # Test that MISSING works the same as a default factory not
824 # being specified (which is really the same as a default not
825 # being specified, too).
826 @dataclass
827 class C:
828 x: int=field(default_factory=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int=field(default=MISSING, default_factory=MISSING)
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_repr(self):
845 self.assertIn('MISSING_TYPE object', repr(MISSING))
846
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 def test_dont_include_other_annotations(self):
848 @dataclass
849 class C:
850 i: int
851 def foo(self) -> int:
852 return 4
853 @property
854 def bar(self) -> int:
855 return 5
856 self.assertEqual(list(C.__annotations__), ['i'])
857 self.assertEqual(C(10).foo(), 4)
858 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400859 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500860
861 def test_post_init(self):
862 # Just make sure it gets called
863 @dataclass
864 class C:
865 def __post_init__(self):
866 raise CustomError()
867 with self.assertRaises(CustomError):
868 C()
869
870 @dataclass
871 class C:
872 i: int = 10
873 def __post_init__(self):
874 if self.i == 10:
875 raise CustomError()
876 with self.assertRaises(CustomError):
877 C()
878 # post-init gets called, but doesn't raise. This is just
879 # checking that self is used correctly.
880 C(5)
881
882 # If there's not an __init__, then post-init won't get called.
883 @dataclass(init=False)
884 class C:
885 def __post_init__(self):
886 raise CustomError()
887 # Creating the class won't raise
888 C()
889
890 @dataclass
891 class C:
892 x: int = 0
893 def __post_init__(self):
894 self.x *= 2
895 self.assertEqual(C().x, 0)
896 self.assertEqual(C(2).x, 4)
897
Mike53f7a7c2017-12-14 14:04:53 +0300898 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500899 # attributes.
900 @dataclass(frozen=True)
901 class C:
902 x: int = 0
903 def __post_init__(self):
904 self.x *= 2
905 with self.assertRaises(FrozenInstanceError):
906 C()
907
908 def test_post_init_super(self):
909 # Make sure super() post-init isn't called by default.
910 class B:
911 def __post_init__(self):
912 raise CustomError()
913
914 @dataclass
915 class C(B):
916 def __post_init__(self):
917 self.x = 5
918
919 self.assertEqual(C().x, 5)
920
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400921 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500922 @dataclass
923 class C(B):
924 def __post_init__(self):
925 super().__post_init__()
926
927 with self.assertRaises(CustomError):
928 C()
929
930 # Make sure post-init is called, even if not defined in our
931 # class.
932 @dataclass
933 class C(B):
934 pass
935
936 with self.assertRaises(CustomError):
937 C()
938
939 def test_post_init_staticmethod(self):
940 flag = False
941 @dataclass
942 class C:
943 x: int
944 y: int
945 @staticmethod
946 def __post_init__():
947 nonlocal flag
948 flag = True
949
950 self.assertFalse(flag)
951 c = C(3, 4)
952 self.assertEqual((c.x, c.y), (3, 4))
953 self.assertTrue(flag)
954
955 def test_post_init_classmethod(self):
956 @dataclass
957 class C:
958 flag = False
959 x: int
960 y: int
961 @classmethod
962 def __post_init__(cls):
963 cls.flag = True
964
965 self.assertFalse(C.flag)
966 c = C(3, 4)
967 self.assertEqual((c.x, c.y), (3, 4))
968 self.assertTrue(C.flag)
969
970 def test_class_var(self):
971 # Make sure ClassVars are ignored in __init__, __repr__, etc.
972 @dataclass
973 class C:
974 x: int
975 y: int = 10
976 z: ClassVar[int] = 1000
977 w: ClassVar[int] = 2000
978 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400979 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500980
981 c = C(5)
982 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400983 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400984 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500985 self.assertEqual(c.z, 1000)
986 self.assertEqual(c.w, 2000)
987 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400988 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500989 C.z += 1
990 self.assertEqual(c.z, 1001)
991 c = C(20)
992 self.assertEqual((c.x, c.y), (20, 10))
993 self.assertEqual(c.z, 1001)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400996 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500997
998 def test_class_var_no_default(self):
999 # If a ClassVar has no default value, it should not be set on the class.
1000 @dataclass
1001 class C:
1002 x: ClassVar[int]
1003
1004 self.assertNotIn('x', C.__dict__)
1005
1006 def test_class_var_default_factory(self):
1007 # It makes no sense for a ClassVar to have a default factory. When
1008 # would it be called? Call it yourself, since it's class-wide.
1009 with self.assertRaisesRegex(TypeError,
1010 'cannot have a default factory'):
1011 @dataclass
1012 class C:
1013 x: ClassVar[int] = field(default_factory=int)
1014
1015 self.assertNotIn('x', C.__dict__)
1016
1017 def test_class_var_with_default(self):
1018 # If a ClassVar has a default value, it should be set on the class.
1019 @dataclass
1020 class C:
1021 x: ClassVar[int] = 10
1022 self.assertEqual(C.x, 10)
1023
1024 @dataclass
1025 class C:
1026 x: ClassVar[int] = field(default=10)
1027 self.assertEqual(C.x, 10)
1028
1029 def test_class_var_frozen(self):
1030 # Make sure ClassVars work even if we're frozen.
1031 @dataclass(frozen=True)
1032 class C:
1033 x: int
1034 y: int = 10
1035 z: ClassVar[int] = 1000
1036 w: ClassVar[int] = 2000
1037 t: ClassVar[int] = 3000
1038
1039 c = C(5)
1040 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1041 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1042 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1043 self.assertEqual(c.z, 1000)
1044 self.assertEqual(c.w, 2000)
1045 self.assertEqual(c.t, 3000)
1046 # We can still modify the ClassVar, it's only instances that are
1047 # frozen.
1048 C.z += 1
1049 self.assertEqual(c.z, 1001)
1050 c = C(20)
1051 self.assertEqual((c.x, c.y), (20, 10))
1052 self.assertEqual(c.z, 1001)
1053 self.assertEqual(c.w, 2000)
1054 self.assertEqual(c.t, 3000)
1055
1056 def test_init_var_no_default(self):
1057 # If an InitVar has no default value, it should not be set on the class.
1058 @dataclass
1059 class C:
1060 x: InitVar[int]
1061
1062 self.assertNotIn('x', C.__dict__)
1063
1064 def test_init_var_default_factory(self):
1065 # It makes no sense for an InitVar to have a default factory. When
1066 # would it be called? Call it yourself, since it's class-wide.
1067 with self.assertRaisesRegex(TypeError,
1068 'cannot have a default factory'):
1069 @dataclass
1070 class C:
1071 x: InitVar[int] = field(default_factory=int)
1072
1073 self.assertNotIn('x', C.__dict__)
1074
1075 def test_init_var_with_default(self):
1076 # If an InitVar has a default value, it should be set on the class.
1077 @dataclass
1078 class C:
1079 x: InitVar[int] = 10
1080 self.assertEqual(C.x, 10)
1081
1082 @dataclass
1083 class C:
1084 x: InitVar[int] = field(default=10)
1085 self.assertEqual(C.x, 10)
1086
1087 def test_init_var(self):
1088 @dataclass
1089 class C:
1090 x: int = None
1091 init_param: InitVar[int] = None
1092
1093 def __post_init__(self, init_param):
1094 if self.x is None:
1095 self.x = init_param*2
1096
1097 c = C(init_param=10)
1098 self.assertEqual(c.x, 20)
1099
1100 def test_init_var_inheritance(self):
1101 # Note that this deliberately tests that a dataclass need not
1102 # have a __post_init__ function if it has an InitVar field.
1103 # It could just be used in a derived class, as shown here.
1104 @dataclass
1105 class Base:
1106 x: int
1107 init_base: InitVar[int]
1108
1109 # We can instantiate by passing the InitVar, even though
1110 # it's not used.
1111 b = Base(0, 10)
1112 self.assertEqual(vars(b), {'x': 0})
1113
1114 @dataclass
1115 class C(Base):
1116 y: int
1117 init_derived: InitVar[int]
1118
1119 def __post_init__(self, init_base, init_derived):
1120 self.x = self.x + init_base
1121 self.y = self.y + init_derived
1122
1123 c = C(10, 11, 50, 51)
1124 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1125
1126 def test_default_factory(self):
1127 # Test a factory that returns a new list.
1128 @dataclass
1129 class C:
1130 x: int
1131 y: list = field(default_factory=list)
1132
1133 c0 = C(3)
1134 c1 = C(3)
1135 self.assertEqual(c0.x, 3)
1136 self.assertEqual(c0.y, [])
1137 self.assertEqual(c0, c1)
1138 self.assertIsNot(c0.y, c1.y)
1139 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1140
1141 # Test a factory that returns a shared list.
1142 l = []
1143 @dataclass
1144 class C:
1145 x: int
1146 y: list = field(default_factory=lambda: l)
1147
1148 c0 = C(3)
1149 c1 = C(3)
1150 self.assertEqual(c0.x, 3)
1151 self.assertEqual(c0.y, [])
1152 self.assertEqual(c0, c1)
1153 self.assertIs(c0.y, c1.y)
1154 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1155
1156 # Test various other field flags.
1157 # repr
1158 @dataclass
1159 class C:
1160 x: list = field(default_factory=list, repr=False)
1161 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1162 self.assertEqual(C().x, [])
1163
1164 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001165 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001166 class C:
1167 x: list = field(default_factory=list, hash=False)
1168 self.assertEqual(astuple(C()), ([],))
1169 self.assertEqual(hash(C()), hash(()))
1170
1171 # init (see also test_default_factory_with_no_init)
1172 @dataclass
1173 class C:
1174 x: list = field(default_factory=list, init=False)
1175 self.assertEqual(astuple(C()), ([],))
1176
1177 # compare
1178 @dataclass
1179 class C:
1180 x: list = field(default_factory=list, compare=False)
1181 self.assertEqual(C(), C([1]))
1182
1183 def test_default_factory_with_no_init(self):
1184 # We need a factory with a side effect.
1185 factory = Mock()
1186
1187 @dataclass
1188 class C:
1189 x: list = field(default_factory=factory, init=False)
1190
1191 # Make sure the default factory is called for each new instance.
1192 C().x
1193 self.assertEqual(factory.call_count, 1)
1194 C().x
1195 self.assertEqual(factory.call_count, 2)
1196
1197 def test_default_factory_not_called_if_value_given(self):
1198 # We need a factory that we can test if it's been called.
1199 factory = Mock()
1200
1201 @dataclass
1202 class C:
1203 x: int = field(default_factory=factory)
1204
1205 # Make sure that if a field has a default factory function,
1206 # it's not called if a value is specified.
1207 C().x
1208 self.assertEqual(factory.call_count, 1)
1209 self.assertEqual(C(10).x, 10)
1210 self.assertEqual(factory.call_count, 1)
1211 C().x
1212 self.assertEqual(factory.call_count, 2)
1213
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001214 def test_default_factory_derived(self):
1215 # See bpo-32896.
1216 @dataclass
1217 class Foo:
1218 x: dict = field(default_factory=dict)
1219
1220 @dataclass
1221 class Bar(Foo):
1222 y: int = 1
1223
1224 self.assertEqual(Foo().x, {})
1225 self.assertEqual(Bar().x, {})
1226 self.assertEqual(Bar().y, 1)
1227
1228 @dataclass
1229 class Baz(Foo):
1230 pass
1231 self.assertEqual(Baz().x, {})
1232
1233 def test_intermediate_non_dataclass(self):
1234 # Test that an intermediate class that defines
1235 # annotations does not define fields.
1236
1237 @dataclass
1238 class A:
1239 x: int
1240
1241 class B(A):
1242 y: int
1243
1244 @dataclass
1245 class C(B):
1246 z: int
1247
1248 c = C(1, 3)
1249 self.assertEqual((c.x, c.z), (1, 3))
1250
1251 # .y was not initialized.
1252 with self.assertRaisesRegex(AttributeError,
1253 'object has no attribute'):
1254 c.y
1255
1256 # And if we again derive a non-dataclass, no fields are added.
1257 class D(C):
1258 t: int
1259 d = D(4, 5)
1260 self.assertEqual((d.x, d.z), (4, 5))
1261
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001262 def test_classvar_default_factory(self):
1263 # It's an error for a ClassVar to have a factory function.
1264 with self.assertRaisesRegex(TypeError,
1265 'cannot have a default factory'):
1266 @dataclass
1267 class C:
1268 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001269
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001270 def test_is_dataclass(self):
1271 class NotDataClass:
1272 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001273
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001274 self.assertFalse(is_dataclass(0))
1275 self.assertFalse(is_dataclass(int))
1276 self.assertFalse(is_dataclass(NotDataClass))
1277 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001278
1279 @dataclass
1280 class C:
1281 x: int
1282
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001283 @dataclass
1284 class D:
1285 d: C
1286 e: int
1287
1288 c = C(10)
1289 d = D(c, 4)
1290
1291 self.assertTrue(is_dataclass(C))
1292 self.assertTrue(is_dataclass(c))
1293 self.assertFalse(is_dataclass(c.x))
1294 self.assertTrue(is_dataclass(d.d))
1295 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001296
1297 def test_helper_fields_with_class_instance(self):
1298 # Check that we can call fields() on either a class or instance,
1299 # and get back the same thing.
1300 @dataclass
1301 class C:
1302 x: int
1303 y: float
1304
1305 self.assertEqual(fields(C), fields(C(0, 0.0)))
1306
1307 def test_helper_fields_exception(self):
1308 # Check that TypeError is raised if not passed a dataclass or
1309 # instance.
1310 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1311 fields(0)
1312
1313 class C: pass
1314 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1315 fields(C)
1316 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1317 fields(C())
1318
1319 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001320 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001321 @dataclass
1322 class C:
1323 x: int
1324 y: int
1325 c = C(1, 2)
1326
1327 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1328 self.assertEqual(asdict(c), asdict(c))
1329 self.assertIsNot(asdict(c), asdict(c))
1330 c.x = 42
1331 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1332 self.assertIs(type(asdict(c)), dict)
1333
1334 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001335 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001336 @dataclass
1337 class C:
1338 x: int
1339 y: int
1340 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1341 asdict(C)
1342 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1343 asdict(int)
1344
1345 def test_helper_asdict_copy_values(self):
1346 @dataclass
1347 class C:
1348 x: int
1349 y: List[int] = field(default_factory=list)
1350 initial = []
1351 c = C(1, initial)
1352 d = asdict(c)
1353 self.assertEqual(d['y'], initial)
1354 self.assertIsNot(d['y'], initial)
1355 c = C(1)
1356 d = asdict(c)
1357 d['y'].append(1)
1358 self.assertEqual(c.y, [])
1359
1360 def test_helper_asdict_nested(self):
1361 @dataclass
1362 class UserId:
1363 token: int
1364 group: int
1365 @dataclass
1366 class User:
1367 name: str
1368 id: UserId
1369 u = User('Joe', UserId(123, 1))
1370 d = asdict(u)
1371 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1372 self.assertIsNot(asdict(u), asdict(u))
1373 u.id.group = 2
1374 self.assertEqual(asdict(u), {'name': 'Joe',
1375 'id': {'token': 123, 'group': 2}})
1376
1377 def test_helper_asdict_builtin_containers(self):
1378 @dataclass
1379 class User:
1380 name: str
1381 id: int
1382 @dataclass
1383 class GroupList:
1384 id: int
1385 users: List[User]
1386 @dataclass
1387 class GroupTuple:
1388 id: int
1389 users: Tuple[User, ...]
1390 @dataclass
1391 class GroupDict:
1392 id: int
1393 users: Dict[str, User]
1394 a = User('Alice', 1)
1395 b = User('Bob', 2)
1396 gl = GroupList(0, [a, b])
1397 gt = GroupTuple(0, (a, b))
1398 gd = GroupDict(0, {'first': a, 'second': b})
1399 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1400 {'name': 'Bob', 'id': 2}]})
1401 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1402 {'name': 'Bob', 'id': 2})})
1403 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1404 'second': {'name': 'Bob', 'id': 2}}})
1405
1406 def test_helper_asdict_builtin_containers(self):
1407 @dataclass
1408 class Child:
1409 d: object
1410
1411 @dataclass
1412 class Parent:
1413 child: Child
1414
1415 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1416 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1417
1418 def test_helper_asdict_factory(self):
1419 @dataclass
1420 class C:
1421 x: int
1422 y: int
1423 c = C(1, 2)
1424 d = asdict(c, dict_factory=OrderedDict)
1425 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1426 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1427 c.x = 42
1428 d = asdict(c, dict_factory=OrderedDict)
1429 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1430 self.assertIs(type(d), OrderedDict)
1431
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001432 def test_helper_asdict_namedtuple(self):
1433 T = namedtuple('T', 'a b c')
1434 @dataclass
1435 class C:
1436 x: str
1437 y: T
1438 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1439
1440 d = asdict(c)
1441 self.assertEqual(d, {'x': 'outer',
1442 'y': T(1,
1443 {'x': 'inner',
1444 'y': T(11, 12, 13)},
1445 2),
1446 }
1447 )
1448
1449 # Now with a dict_factory. OrderedDict is convenient, but
1450 # since it compares to dicts, we also need to have separate
1451 # assertIs tests.
1452 d = asdict(c, dict_factory=OrderedDict)
1453 self.assertEqual(d, {'x': 'outer',
1454 'y': T(1,
1455 {'x': 'inner',
1456 'y': T(11, 12, 13)},
1457 2),
1458 }
1459 )
1460
1461 # Make sure that the returned dicts are actuall OrderedDicts.
1462 self.assertIs(type(d), OrderedDict)
1463 self.assertIs(type(d['y'][1]), OrderedDict)
1464
1465 def test_helper_asdict_namedtuple_key(self):
1466 # Ensure that a field that contains a dict which has a
1467 # namedtuple as a key works with asdict().
1468
1469 @dataclass
1470 class C:
1471 f: dict
1472 T = namedtuple('T', 'a')
1473
1474 c = C({T('an a'): 0})
1475
1476 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1477
1478 def test_helper_asdict_namedtuple_derived(self):
1479 class T(namedtuple('Tbase', 'a')):
1480 def my_a(self):
1481 return self.a
1482
1483 @dataclass
1484 class C:
1485 f: T
1486
1487 t = T(6)
1488 c = C(t)
1489
1490 d = asdict(c)
1491 self.assertEqual(d, {'f': T(a=6)})
1492 # Make sure that t has been copied, not used directly.
1493 self.assertIsNot(d['f'], t)
1494 self.assertEqual(d['f'].my_a(), 6)
1495
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001496 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001497 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001498 @dataclass
1499 class C:
1500 x: int
1501 y: int = 0
1502 c = C(1)
1503
1504 self.assertEqual(astuple(c), (1, 0))
1505 self.assertEqual(astuple(c), astuple(c))
1506 self.assertIsNot(astuple(c), astuple(c))
1507 c.y = 42
1508 self.assertEqual(astuple(c), (1, 42))
1509 self.assertIs(type(astuple(c)), tuple)
1510
1511 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001512 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001513 @dataclass
1514 class C:
1515 x: int
1516 y: int
1517 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1518 astuple(C)
1519 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1520 astuple(int)
1521
1522 def test_helper_astuple_copy_values(self):
1523 @dataclass
1524 class C:
1525 x: int
1526 y: List[int] = field(default_factory=list)
1527 initial = []
1528 c = C(1, initial)
1529 t = astuple(c)
1530 self.assertEqual(t[1], initial)
1531 self.assertIsNot(t[1], initial)
1532 c = C(1)
1533 t = astuple(c)
1534 t[1].append(1)
1535 self.assertEqual(c.y, [])
1536
1537 def test_helper_astuple_nested(self):
1538 @dataclass
1539 class UserId:
1540 token: int
1541 group: int
1542 @dataclass
1543 class User:
1544 name: str
1545 id: UserId
1546 u = User('Joe', UserId(123, 1))
1547 t = astuple(u)
1548 self.assertEqual(t, ('Joe', (123, 1)))
1549 self.assertIsNot(astuple(u), astuple(u))
1550 u.id.group = 2
1551 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1552
1553 def test_helper_astuple_builtin_containers(self):
1554 @dataclass
1555 class User:
1556 name: str
1557 id: int
1558 @dataclass
1559 class GroupList:
1560 id: int
1561 users: List[User]
1562 @dataclass
1563 class GroupTuple:
1564 id: int
1565 users: Tuple[User, ...]
1566 @dataclass
1567 class GroupDict:
1568 id: int
1569 users: Dict[str, User]
1570 a = User('Alice', 1)
1571 b = User('Bob', 2)
1572 gl = GroupList(0, [a, b])
1573 gt = GroupTuple(0, (a, b))
1574 gd = GroupDict(0, {'first': a, 'second': b})
1575 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1576 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1577 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1578
1579 def test_helper_astuple_builtin_containers(self):
1580 @dataclass
1581 class Child:
1582 d: object
1583
1584 @dataclass
1585 class Parent:
1586 child: Child
1587
1588 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1589 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1590
1591 def test_helper_astuple_factory(self):
1592 @dataclass
1593 class C:
1594 x: int
1595 y: int
1596 NT = namedtuple('NT', 'x y')
1597 def nt(lst):
1598 return NT(*lst)
1599 c = C(1, 2)
1600 t = astuple(c, tuple_factory=nt)
1601 self.assertEqual(t, NT(1, 2))
1602 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1603 c.x = 42
1604 t = astuple(c, tuple_factory=nt)
1605 self.assertEqual(t, NT(42, 2))
1606 self.assertIs(type(t), NT)
1607
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001608 def test_helper_astuple_namedtuple(self):
1609 T = namedtuple('T', 'a b c')
1610 @dataclass
1611 class C:
1612 x: str
1613 y: T
1614 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1615
1616 t = astuple(c)
1617 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1618
1619 # Now, using a tuple_factory. list is convenient here.
1620 t = astuple(c, tuple_factory=list)
1621 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1622
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001623 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001624 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001625 }
1626
1627 # Create the class.
1628 cls = type('C', (), cls_dict)
1629
1630 # Make it a dataclass.
1631 cls1 = dataclass(cls)
1632
1633 self.assertEqual(cls1, cls)
1634 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1635
1636 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001637 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001638 'y': field(default=5),
1639 }
1640
1641 # Create the class.
1642 cls = type('C', (), cls_dict)
1643
1644 # Make it a dataclass.
1645 cls1 = dataclass(cls)
1646
1647 self.assertEqual(cls1, cls)
1648 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1649
1650 def test_init_in_order(self):
1651 @dataclass
1652 class C:
1653 a: int
1654 b: int = field()
1655 c: list = field(default_factory=list, init=False)
1656 d: list = field(default_factory=list)
1657 e: int = field(default=4, init=False)
1658 f: int = 4
1659
1660 calls = []
1661 def setattr(self, name, value):
1662 calls.append((name, value))
1663
1664 C.__setattr__ = setattr
1665 c = C(0, 1)
1666 self.assertEqual(('a', 0), calls[0])
1667 self.assertEqual(('b', 1), calls[1])
1668 self.assertEqual(('c', []), calls[2])
1669 self.assertEqual(('d', []), calls[3])
1670 self.assertNotIn(('e', 4), calls)
1671 self.assertEqual(('f', 4), calls[4])
1672
1673 def test_items_in_dicts(self):
1674 @dataclass
1675 class C:
1676 a: int
1677 b: list = field(default_factory=list, init=False)
1678 c: list = field(default_factory=list)
1679 d: int = field(default=4, init=False)
1680 e: int = 0
1681
1682 c = C(0)
1683 # Class dict
1684 self.assertNotIn('a', C.__dict__)
1685 self.assertNotIn('b', C.__dict__)
1686 self.assertNotIn('c', C.__dict__)
1687 self.assertIn('d', C.__dict__)
1688 self.assertEqual(C.d, 4)
1689 self.assertIn('e', C.__dict__)
1690 self.assertEqual(C.e, 0)
1691 # Instance dict
1692 self.assertIn('a', c.__dict__)
1693 self.assertEqual(c.a, 0)
1694 self.assertIn('b', c.__dict__)
1695 self.assertEqual(c.b, [])
1696 self.assertIn('c', c.__dict__)
1697 self.assertEqual(c.c, [])
1698 self.assertNotIn('d', c.__dict__)
1699 self.assertIn('e', c.__dict__)
1700 self.assertEqual(c.e, 0)
1701
1702 def test_alternate_classmethod_constructor(self):
1703 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001704 # alternate constructor. This is mostly an example to show
1705 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001706 @dataclass
1707 class C:
1708 x: int
1709 @classmethod
1710 def from_file(cls, filename):
1711 # In a real example, create a new instance
1712 # and populate 'x' from contents of a file.
1713 value_in_file = 20
1714 return cls(value_in_file)
1715
1716 self.assertEqual(C.from_file('filename').x, 20)
1717
1718 def test_field_metadata_default(self):
1719 # Make sure the default metadata is read-only and of
1720 # zero length.
1721 @dataclass
1722 class C:
1723 i: int
1724
1725 self.assertFalse(fields(C)[0].metadata)
1726 self.assertEqual(len(fields(C)[0].metadata), 0)
1727 with self.assertRaisesRegex(TypeError,
1728 'does not support item assignment'):
1729 fields(C)[0].metadata['test'] = 3
1730
1731 def test_field_metadata_mapping(self):
1732 # Make sure only a mapping can be passed as metadata
1733 # zero length.
1734 with self.assertRaises(TypeError):
1735 @dataclass
1736 class C:
1737 i: int = field(metadata=0)
1738
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001739 # Make sure an empty dict works.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001740 @dataclass
1741 class C:
1742 i: int = field(metadata={})
1743 self.assertFalse(fields(C)[0].metadata)
1744 self.assertEqual(len(fields(C)[0].metadata), 0)
1745 with self.assertRaisesRegex(TypeError,
1746 'does not support item assignment'):
1747 fields(C)[0].metadata['test'] = 3
1748
1749 # Make sure a non-empty dict works.
1750 @dataclass
1751 class C:
1752 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1753 self.assertEqual(len(fields(C)[0].metadata), 3)
1754 self.assertEqual(fields(C)[0].metadata['test'], 10)
1755 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1756 self.assertEqual(fields(C)[0].metadata[3], 'three')
1757 with self.assertRaises(KeyError):
1758 # Non-existent key.
1759 fields(C)[0].metadata['baz']
1760 with self.assertRaisesRegex(TypeError,
1761 'does not support item assignment'):
1762 fields(C)[0].metadata['test'] = 3
1763
1764 def test_field_metadata_custom_mapping(self):
1765 # Try a custom mapping.
1766 class SimpleNameSpace:
1767 def __init__(self, **kw):
1768 self.__dict__.update(kw)
1769
1770 def __getitem__(self, item):
1771 if item == 'xyzzy':
1772 return 'plugh'
1773 return getattr(self, item)
1774
1775 def __len__(self):
1776 return self.__dict__.__len__()
1777
1778 @dataclass
1779 class C:
1780 i: int = field(metadata=SimpleNameSpace(a=10))
1781
1782 self.assertEqual(len(fields(C)[0].metadata), 1)
1783 self.assertEqual(fields(C)[0].metadata['a'], 10)
1784 with self.assertRaises(AttributeError):
1785 fields(C)[0].metadata['b']
1786 # Make sure we're still talking to our custom mapping.
1787 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1788
1789 def test_generic_dataclasses(self):
1790 T = TypeVar('T')
1791
1792 @dataclass
1793 class LabeledBox(Generic[T]):
1794 content: T
1795 label: str = '<unknown>'
1796
1797 box = LabeledBox(42)
1798 self.assertEqual(box.content, 42)
1799 self.assertEqual(box.label, '<unknown>')
1800
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001801 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001802 Alias = List[LabeledBox[int]]
1803
1804 def test_generic_extending(self):
1805 S = TypeVar('S')
1806 T = TypeVar('T')
1807
1808 @dataclass
1809 class Base(Generic[T, S]):
1810 x: T
1811 y: S
1812
1813 @dataclass
1814 class DataDerived(Base[int, T]):
1815 new_field: str
1816 Alias = DataDerived[str]
1817 c = Alias(0, 'test1', 'test2')
1818 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1819
1820 class NonDataDerived(Base[int, T]):
1821 def new_method(self):
1822 return self.y
1823 Alias = NonDataDerived[float]
1824 c = Alias(10, 1.0)
1825 self.assertEqual(c.new_method(), 1.0)
1826
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001827 def test_generic_dynamic(self):
1828 T = TypeVar('T')
1829
1830 @dataclass
1831 class Parent(Generic[T]):
1832 x: T
1833 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1834 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1835 self.assertIs(Child[int](1, 2).z, None)
1836 self.assertEqual(Child[int](1, 2, 3).z, 3)
1837 self.assertEqual(Child[int](1, 2, 3).other, 42)
1838 # Check that type aliases work correctly.
1839 Alias = Child[T]
1840 self.assertEqual(Alias[int](1, 2).x, 1)
1841 # Check MRO resolution.
1842 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1843
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001844 def test_dataclassses_pickleable(self):
1845 global P, Q, R
1846 @dataclass
1847 class P:
1848 x: int
1849 y: int = 0
1850 @dataclass
1851 class Q:
1852 x: int
1853 y: int = field(default=0, init=False)
1854 @dataclass
1855 class R:
1856 x: int
1857 y: List[int] = field(default_factory=list)
1858 q = Q(1)
1859 q.y = 2
1860 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1861 for sample in samples:
1862 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1863 with self.subTest(sample=sample, proto=proto):
1864 new_sample = pickle.loads(pickle.dumps(sample, proto))
1865 self.assertEqual(sample.x, new_sample.x)
1866 self.assertEqual(sample.y, new_sample.y)
1867 self.assertIsNot(sample, new_sample)
1868 new_sample.x = 42
1869 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1870 self.assertEqual(new_sample.x, another_new_sample.x)
1871 self.assertEqual(sample.y, another_new_sample.y)
1872
Eric V. Smithea8fc522018-01-27 19:07:40 -05001873
class TestFieldNoAnnotation(unittest.TestCase):
    """A field() with no annotation in the decorated class is an error."""

    def test_field_without_annotation(self):
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: an annotation inherited from a dataclass base
        # must not be picked up for the subclass's field().
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # The same check when the annotating base is an ordinary class.
        class B:
            f: int

        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(B):
                f = field()
1907
1908
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001909class TestDocString(unittest.TestCase):
1910 def assertDocStrEqual(self, a, b):
1911 # Because 3.6 and 3.7 differ in how inspect.signature work
1912 # (see bpo #32108), for the time being just compare them with
1913 # whitespace stripped.
1914 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1915
1916 def test_existing_docstring_not_overridden(self):
1917 @dataclass
1918 class C:
1919 """Lorem ipsum"""
1920 x: int
1921
1922 self.assertEqual(C.__doc__, "Lorem ipsum")
1923
1924 def test_docstring_no_fields(self):
1925 @dataclass
1926 class C:
1927 pass
1928
1929 self.assertDocStrEqual(C.__doc__, "C()")
1930
1931 def test_docstring_one_field(self):
1932 @dataclass
1933 class C:
1934 x: int
1935
1936 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1937
1938 def test_docstring_two_fields(self):
1939 @dataclass
1940 class C:
1941 x: int
1942 y: int
1943
1944 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1945
1946 def test_docstring_three_fields(self):
1947 @dataclass
1948 class C:
1949 x: int
1950 y: int
1951 z: str
1952
1953 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1954
1955 def test_docstring_one_field_with_default(self):
1956 @dataclass
1957 class C:
1958 x: int = 3
1959
1960 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1961
1962 def test_docstring_one_field_with_default_none(self):
1963 @dataclass
1964 class C:
1965 x: Union[int, type(None)] = None
1966
1967 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1968
1969 def test_docstring_list_field(self):
1970 @dataclass
1971 class C:
1972 x: List[int]
1973
1974 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1975
1976 def test_docstring_list_field_with_default_factory(self):
1977 @dataclass
1978 class C:
1979 x: List[int] = field(default_factory=list)
1980
1981 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1982
1983 def test_docstring_deque_field(self):
1984 @dataclass
1985 class C:
1986 x: deque
1987
1988 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1989
1990 def test_docstring_deque_field_with_default_factory(self):
1991 @dataclass
1992 class C:
1993 x: deque = field(default_factory=deque)
1994
1995 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1996
1997
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise.  A dataclass may add its own
        # __init__ even when a base class already defines one.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call B.__init__, so z is unset.
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so B.__init__ runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied (note the '@').  A bare
        # `dataclass(init=False)` expression is discarded and would leave
        # C as an ordinary class, so nothing would be tested.

        # init=False and no user __init__: the class attribute supplies i.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # init=False with a user-written __init__: that __init__ is used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class defines __init__, it is used whatever init= says.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2062
2063
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The generated __repr__ starts with the class __qualname__ and
        # lists inherited fields first.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        self.assertEqual(repr(C(4)), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring x keeps its position but picks up the new default.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses show their full dotted qualname.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no user __repr__: object's default repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # repr=False with a user __repr__: the user's method is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A user-defined __repr__ is always kept, whatever repr= says.
        for decorator in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorator
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2133
2134
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: object identity decides equality.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ is always kept, whatever eq= says.
        def build(decorator, target):
            @decorator
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            return C

        cases = ((dataclass, 3), (dataclass(eq=True), 4), (dataclass(eq=False), 5))
        for decorator, target in cases:
            C = build(decorator, target)
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2180
2181
class TestOrdering(unittest.TestCase):
    """Tests for the order= parameter and its interplay with user methods."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added when order=False.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # Test that a user-defined __lt__ is kept and actually called:
        # it always answers False, whereas a generated field-wise __lt__
        # would report C(0) < C(1) as True.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertFalse(C(0) < C(1))
        # Make sure the other ordering methods aren't added.
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any ordering method that the class
        # already has, and points the user at functools.total_ordering.
        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            with self.subTest(name=name):
                class C:
                    x: int
                # Put the method into C.__dict__, exactly as a def in the
                # class body would, before applying the decorator.
                setattr(C, name, lambda self: None)
                with self.assertRaisesRegex(TypeError,
                                            f'Cannot overwrite attribute {name}'
                                            '.*using functools.total_ordering'):
                    dataclass(order=True)(C)
2257
class TestHash(unittest.TestCase):
    """Tests for __hash__ generation driven by unsafe_hash=, eq= and frozen=."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.  C has
                    # no fields, so it hashes like an empty tuple.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                    self.assertEqual(hash(C()), hash(()))

                elif result == '':
                    if with_hash:
                        # The class's own __hash__ must be left in place.
                        self.assertEqual(hash(C()), 0)
                    else:
                        # __hash__ is not present in our class.
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None, so instances are unhashable.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                    with self.assertRaises(TypeError):
                        hash(C())

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        # __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',   ''),
                (False, False, True,  '',   ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn', ''),
                (True,  False, False, 'fn', 'exception'),
                (True,  False, True,  'fn', 'exception'),
                (True,  True,  False, 'fn', 'exception'),
                (True,  True,  True,  'fn', 'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo=32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's default hash depends on the instance's
                    # identity, so its value isn't predictable.  Just
                    # check that the function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    self.fail(f'unknown value for expected={expected!r}')
2476
Eric V. Smithea8fc522018-01-27 19:07:40 -05002477
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True and for mixing frozen/non-frozen inheritance."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # The rejected assignment left the value untouched.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        # Inherited and newly declared fields are both read-only.
        for name, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, name, value)
        self.assertEqual((obj.i, obj.j), (0, 10))

    # The next three tests run twice: once deriving directly from the
    # base, and once through an ordinary (non-dataclass) middle class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class Parent(C):
                        pass
                else:
                    Parent = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(Parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class Parent(C):
                        pass
                else:
                    Parent = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(Parent):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class Parent(C):
                        pass
                else:
                    Parent = C

                @dataclass(frozen=True)
                class D(Parent):
                    i: int

                with self.assertRaises(FrozenInstanceError):
                    D(10).i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.
        # An ordinary subclass of a frozen dataclass may grow new
        # attributes, while the inherited fields stay frozen.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        sub = S(3)
        self.assertEqual((sub.x, sub.y), (3, 10))
        sub.cached = True

        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(sub, name, 5)
        self.assertEqual((sub.x, sub.y), (3, 10))
        self.assertEqual(sub.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True is implemented with __setattr__ and __delattr__, so
        # it refuses to replace either one if the class already has it.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen=True, a user __setattr__ is kept, and the
        # generated __init__ goes through it.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # An immutable field value hashes without complaint.
        hash(C(3))

        # A mutable (unhashable) field value makes the instance unhashable.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2626
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002627
class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__ themselves."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: the slot descriptor for 'x' was once mistaken
        # for a default value (of type types.MemberDescriptorType), which
        # made the argument optional.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # The slot can be set by __init__ and reassigned afterwards.
        inst = C(10)
        self.assertEqual(inst.x, 10)
        inst.x = 5
        self.assertEqual(inst.x, 5)

        # Only the declared slot exists; any other attribute is rejected.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            inst.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        obj = Derived(1, 2)
        self.assertEqual((obj.x, obj.y), (1, 2))

        # Derived declares no __slots__, so extra attributes are allowed.
        obj.z = 10
2669
class TestDescriptors(unittest.TestCase):
    """Tests for __set_name__ handling of field defaults."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that records the name it was bound under (with an
        # 'x' appended) and yields 1 when read through an instance.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Plain default: ordinary class creation calls __set_name__; no
        # dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Default wrapped in field(..., init=False).  This is the case
        # that really matters: with init=True the descriptor would be
        # overwritten anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors, so
        # use an object that defines __set_name__ and nothing else.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        # __set_name__ attached to the default *instance* (not its type)
        # must not be called.
        class D:
            pass

        default = D()
        default.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=default, init=False)

        self.assertEqual(default.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        # __set_name__ defined on the default's *type* must be called once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002740
Eric V. Smith7389fd92018-03-19 21:07:51 -04002741
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002742class TestStringAnnotations(unittest.TestCase):
2743 def test_classvar(self):
2744 # Some expressions recognized as ClassVar really aren't. But
2745 # if you're using string annotations, it's not an exact
2746 # science.
2747 # These tests assume that both "import typing" and "from
2748 # typing import *" have been run in this file.
2749 for typestr in ('ClassVar[int]',
2750 'ClassVar [int]'
2751 ' ClassVar [int]',
2752 'ClassVar',
2753 ' ClassVar ',
2754 'typing.ClassVar[int]',
2755 'typing.ClassVar[str]',
2756 ' typing.ClassVar[str]',
2757 'typing .ClassVar[str]',
2758 'typing. ClassVar[str]',
2759 'typing.ClassVar [str]',
2760 'typing.ClassVar [ str]',
2761
2762 # Not syntactically valid, but these will
2763 # be treated as ClassVars.
2764 'typing.ClassVar.[int]',
2765 'typing.ClassVar+',
2766 ):
2767 with self.subTest(typestr=typestr):
2768 @dataclass
2769 class C:
2770 x: typestr
2771
2772 # x is a ClassVar, so C() takes no args.
2773 C()
2774
2775 # And it won't appear in the class's dict because it doesn't
2776 # have a default.
2777 self.assertNotIn('x', C.__dict__)
2778
2779 def test_isnt_classvar(self):
2780 for typestr in ('CV',
2781 't.ClassVar',
2782 't.ClassVar[int]',
2783 'typing..ClassVar[int]',
2784 'Classvar',
2785 'Classvar[int]',
2786 'typing.ClassVarx[int]',
2787 'typong.ClassVar[int]',
2788 'dataclasses.ClassVar[int]',
2789 'typingxClassVar[str]',
2790 ):
2791 with self.subTest(typestr=typestr):
2792 @dataclass
2793 class C:
2794 x: typestr
2795
2796 # x is not a ClassVar, so C() takes one arg.
2797 self.assertEqual(C(10).x, 10)
2798
2799 def test_initvar(self):
2800 # These tests assume that both "import dataclasses" and "from
2801 # dataclasses import *" have been run in this file.
2802 for typestr in ('InitVar[int]',
2803 'InitVar [int]'
2804 ' InitVar [int]',
2805 'InitVar',
2806 ' InitVar ',
2807 'dataclasses.InitVar[int]',
2808 'dataclasses.InitVar[str]',
2809 ' dataclasses.InitVar[str]',
2810 'dataclasses .InitVar[str]',
2811 'dataclasses. InitVar[str]',
2812 'dataclasses.InitVar [str]',
2813 'dataclasses.InitVar [ str]',
2814
2815 # Not syntactically valid, but these will
2816 # be treated as InitVars.
2817 'dataclasses.InitVar.[int]',
2818 'dataclasses.InitVar+',
2819 ):
2820 with self.subTest(typestr=typestr):
2821 @dataclass
2822 class C:
2823 x: typestr
2824
2825 # x is an InitVar, so doesn't create a member.
2826 with self.assertRaisesRegex(AttributeError,
2827 "object has no attribute 'x'"):
2828 C(1).x
2829
2830 def test_isnt_initvar(self):
2831 for typestr in ('IV',
2832 'dc.InitVar',
2833 'xdataclasses.xInitVar',
2834 'typing.xInitVar[int]',
2835 ):
2836 with self.subTest(typestr=typestr):
2837 @dataclass
2838 class C:
2839 x: typestr
2840
2841 # x is not an InitVar, so there will be a member x.
2842 self.assertEqual(C(10).x, 10)
2843
2844 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002845 from test import dataclass_module_1
2846 from test import dataclass_module_1_str
2847 from test import dataclass_module_2
2848 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002849
2850 for m in (dataclass_module_1, dataclass_module_1_str,
2851 dataclass_module_2, dataclass_module_2_str,
2852 ):
2853 with self.subTest(m=m):
2854 # There's a difference in how the ClassVars are
2855 # interpreted when using string annotations or
2856 # not. See the imported modules for details.
2857 if m.USING_STRINGS:
2858 c = m.CV(10)
2859 else:
2860 c = m.CV()
2861 self.assertEqual(c.cv0, 20)
2862
2863
2864 # There's a difference in how the InitVars are
2865 # interpreted when using string annotations or
2866 # not. See the imported modules for details.
2867 c = m.IV(0, 1, 2, 3, 4)
2868
2869 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2870 with self.subTest(field_name=field_name):
2871 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2872 # Since field_name is an InitVar, it's
2873 # not an instance field.
2874 getattr(c, field_name)
2875
2876 if m.USING_STRINGS:
2877 # iv4 is interpreted as a normal field.
2878 self.assertIn('not_iv4', c.__dict__)
2879 self.assertEqual(c.not_iv4, 4)
2880 else:
2881 # iv4 is interpreted as an InitVar, so it
2882 # won't exist on the instance.
2883 self.assertNotIn('not_iv4', c.__dict__)
2884
2885
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass()."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # Base1's field x precedes y, so both are required.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # The loops below use "name" rather than "field" so that the
    # dataclasses.field function is not shadowed inside the test.
    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        # The library message has been spelled both "identifers" and
        # "identifiers"; accept either so the test checks the rejection
        # itself rather than the spelling.
        message = 'must be valid identifi?ers'
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, message):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, message):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, message):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3049
Eric V. Smithe7adf2b2018-06-07 14:43:59 -04003050class TestReplace(unittest.TestCase):
3051 def test(self):
3052 @dataclass(frozen=True)
3053 class C:
3054 x: int
3055 y: int
3056
3057 c = C(1, 2)
3058 c1 = replace(c, x=3)
3059 self.assertEqual(c1.x, 3)
3060 self.assertEqual(c1.y, 2)
3061
3062 def test_frozen(self):
3063 @dataclass(frozen=True)
3064 class C:
3065 x: int
3066 y: int
3067 z: int = field(init=False, default=10)
3068 t: int = field(init=False, default=100)
3069
3070 c = C(1, 2)
3071 c1 = replace(c, x=3)
3072 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
3073 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
3074
3075
3076 with self.assertRaisesRegex(ValueError, 'init=False'):
3077 replace(c, x=3, z=20, t=50)
3078 with self.assertRaisesRegex(ValueError, 'init=False'):
3079 replace(c, z=20)
3080 replace(c, x=3, z=20, t=50)
3081
3082 # Make sure the result is still frozen.
3083 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
3084 c1.x = 3
3085
3086 # Make sure we can't replace an attribute that doesn't exist,
3087 # if we're also replacing one that does exist. Test this
3088 # here, because setting attributes on frozen instances is
3089 # handled slightly differently from non-frozen ones.
3090 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
3091 "keyword argument 'a'"):
3092 c1 = replace(c, x=20, a=5)
3093
3094 def test_invalid_field_name(self):
3095 @dataclass(frozen=True)
3096 class C:
3097 x: int
3098 y: int
3099
3100 c = C(1, 2)
3101 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
3102 "keyword argument 'z'"):
3103 c1 = replace(c, z=3)
3104
3105 def test_invalid_object(self):
3106 @dataclass(frozen=True)
3107 class C:
3108 x: int
3109 y: int
3110
3111 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
3112 replace(C, x=3)
3113
3114 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
3115 replace(0, x=3)
3116
3117 def test_no_init(self):
3118 @dataclass
3119 class C:
3120 x: int
3121 y: int = field(init=False, default=10)
3122
3123 c = C(1)
3124 c.y = 20
3125
3126 # Make sure y gets the default value.
3127 c1 = replace(c, x=5)
3128 self.assertEqual((c1.x, c1.y), (5, 10))
3129
3130 # Trying to replace y is an error.
3131 with self.assertRaisesRegex(ValueError, 'init=False'):
3132 replace(c, x=2, y=30)
3133
3134 with self.assertRaisesRegex(ValueError, 'init=False'):
3135 replace(c, y=30)
3136
3137 def test_classvar(self):
3138 @dataclass
3139 class C:
3140 x: int
3141 y: ClassVar[int] = 1000
3142
3143 c = C(1)
3144 d = C(2)
3145
3146 self.assertIs(c.y, d.y)
3147 self.assertEqual(c.y, 1000)
3148
3149 # Trying to replace y is an error: can't replace ClassVars.
3150 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
3151 "unexpected keyword argument 'y'"):
3152 replace(c, y=30)
3153
3154 replace(c, x=5)
3155
def test_initvar_is_specified(self):
    # InitVar values are only handed to __init__/__post_init__, so replace()
    # cannot take them from the source object: the caller must pass them.
    @dataclass
    class C:
        x: int
        y: InitVar[int]

        def __post_init__(self, y):
            self.x *= y

    c = C(1, 10)
    self.assertEqual(c.x, 10)

    # Omitting the InitVar is a ValueError.  The parentheses are escaped so
    # the regex matches the literal "replace()" in the message; left bare,
    # "()" is an empty regex group and that part of the text goes unchecked.
    with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                            r"specified with replace\(\)"):
        replace(c, x=3)

    # Supplying it runs __post_init__ again with the new value: 3 * 5 == 15.
    c = replace(c, x=3, y=5)
    self.assertEqual(c.x, 15)
Srinivas Thatiparthy (శ్రీనివాస్ తాటిపర్తి)dd13c882018-10-19 22:24:50 +05303172
def test_recursive_repr(self):
    # A single field that refers back to its own instance: the generated
    # __repr__ prints "..." for the cycle instead of recursing forever.
    @dataclass
    class C:
        f: "C"

    node = C(None)
    node.f = node
    expected = "TestReplace.test_recursive_repr.<locals>.C(f=...)"
    self.assertEqual(repr(node), expected)
3181
def test_recursive_repr_two_attrs(self):
    # Both fields point back at the instance itself; each occurrence of
    # the cycle is rendered as "...".
    @dataclass
    class C:
        f: "C"
        g: "C"

    node = C(None, None)
    node.f = node.g = node  # assigns f first, then g
    self.assertEqual(repr(node), "TestReplace.test_recursive_repr_two_attrs"
                                 ".<locals>.C(f=..., g=...)")
3193
def test_recursive_repr_indirection(self):
    # Two dataclasses pointing at each other (C -> D -> C).  The cycle is
    # elided only when C's repr is re-entered, so D is printed once.
    @dataclass
    class C:
        f: "D"

    @dataclass
    class D:
        f: "C"

    outer, inner = C(None), D(None)
    outer.f, inner.f = inner, outer
    self.assertEqual(repr(outer), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")
3210
def test_recursive_repr_indirection_two(self):
    # A three-class cycle C -> D -> E -> C: only re-entering C's repr is
    # elided, so D and E each appear once in full.
    @dataclass
    class C:
        f: "D"

    @dataclass
    class D:
        f: "E"

    @dataclass
    class E:
        f: "C"

    first, second, third = C(None), D(None), E(None)
    first.f, second.f, third.f = second, third, first
    self.assertEqual(repr(first), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")
3234
def test_recursive_repr_shared_child(self):
    # Two fields referring to the same *non-recursive* object.  The repr
    # recursion guard must elide only genuine cycles, so the shared child
    # is printed in full both times rather than as "...".
    # (This slot previously held a verbatim second copy of
    # test_recursive_repr_two_attrs, which silently shadowed the first.)
    @dataclass
    class C:
        f: "D"
        g: "D"

    @dataclass
    class D:
        x: int

    child = D(1)
    parent = C(child, child)
    self.assertEqual(repr(parent), "TestReplace.test_recursive_repr_shared_child"
                                   ".<locals>.C(f=TestReplace.test_recursive_repr_shared_child"
                                   ".<locals>.D(x=1), g=TestReplace.test_recursive_repr_shared_child"
                                   ".<locals>.D(x=1))")
3246
def test_recursive_repr_misc_attrs(self):
    # Only the self-referencing field is elided; an ordinary int field next
    # to it is printed normally.
    @dataclass
    class C:
        f: "C"
        g: int

    node = C(None, 1)
    node.f = node
    self.assertEqual(repr(node), "TestReplace.test_recursive_repr_misc_attrs"
                                 ".<locals>.C(f=..., g=1)")
3257
Eric V. Smithe7adf2b2018-06-07 14:43:59 -04003258 ## def test_initvar(self):
3259 ## @dataclass
3260 ## class C:
3261 ## x: int
3262 ## y: InitVar[int]
3263
3264 ## c = C(1, 10)
3265 ## d = C(2, 20)
3266
3267 ## # In our case, replacing an InitVar is a no-op
3268 ## self.assertEqual(c, replace(c, y=5))
3269
3270 ## replace(c, x=5)
3271
Eric V. Smith4e812962018-05-16 11:31:29 -04003272
if __name__ == '__main__':
    # Runs only when the module is executed as a script; importing it (as
    # test runners do) skips this and leaves discovery to the runner.
    unittest.main()