blob: 238335e7d9edad0650de2719d12849ea601ca9a0 [file] [log] [blame]
# Deliberately use "from dataclasses import *". Every name in __all__
# is tested, so they all must be present. This is a way to catch
# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +03009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050014from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040016import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
17import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
18
# Just any custom exception we can catch.
class CustomError(Exception):
    """Marker exception raised by test hooks so tests can catch it precisely."""

22class TestCase(unittest.TestCase):
23 def test_no_fields(self):
24 @dataclass
25 class C:
26 pass
27
28 o = C()
29 self.assertEqual(len(fields(C)), 0)
30
Eric V. Smith56970b82018-03-22 16:28:48 -040031 def test_no_fields_but_member_variable(self):
32 @dataclass
33 class C:
34 i = 0
35
36 o = C()
37 self.assertEqual(len(fields(C)), 0)
38
Eric V. Smithf0db54a2017-12-04 16:58:55 -050039 def test_one_field_no_default(self):
40 @dataclass
41 class C:
42 x: int
43
44 o = C(42)
45 self.assertEqual(o.x, 42)
46
47 def test_named_init_params(self):
48 @dataclass
49 class C:
50 x: int
51
52 o = C(x=32)
53 self.assertEqual(o.x, 32)
54
55 def test_two_fields_one_default(self):
56 @dataclass
57 class C:
58 x: int
59 y: int = 0
60
61 o = C(3)
62 self.assertEqual((o.x, o.y), (3, 0))
63
64 # Non-defaults following defaults.
65 with self.assertRaisesRegex(TypeError,
66 "non-default argument 'y' follows "
67 "default argument"):
68 @dataclass
69 class C:
70 x: int = 0
71 y: int
72
73 # A derived class adds a non-default field after a default one.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int = 0
80
81 @dataclass
82 class C(B):
83 y: int
84
85 # Override a base class field and add a default to
86 # a field which didn't use to have a default.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class B:
92 x: int
93 y: int
94
95 @dataclass
96 class C(B):
97 x: int = 0
98
Eric V. Smithdbf9cff2018-02-25 21:30:17 -050099 def test_overwrite_hash(self):
100 # Test that declaring this class isn't an error. It should
101 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500102 @dataclass(frozen=True)
103 class C:
104 x: int
105 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500106 return 301
107 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500108
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500109 # Test that declaring this class isn't an error. It should
110 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500111 @dataclass(frozen=True)
112 class C:
113 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500114 def __eq__(self, other):
115 return False
116 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500117
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500118 # But this one should generate an exception, because with
119 # unsafe_hash=True, it's an error to have a __hash__ defined.
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __hash__'):
122 @dataclass(unsafe_hash=True)
123 class C:
124 def __hash__(self):
125 pass
126
127 # Creating this class should not generate an exception,
128 # because even though __hash__ exists before @dataclass is
129 # called, (due to __eq__ being defined), since it's None
130 # that's okay.
131 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500132 class C:
133 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500134 def __eq__(self):
135 pass
136 # The generated hash function works as we'd expect.
137 self.assertEqual(hash(C(10)), hash((10,)))
138
139 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400140 # __hash__ exists and is not None, which it would be if it
141 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 x: int
147 def __eq__(self):
148 pass
149 def __hash__(self):
150 pass
151
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500152 def test_overwrite_fields_in_derived_class(self):
153 # Note that x from C1 replaces x in Base, but the order remains
154 # the same as defined in Base.
155 @dataclass
156 class Base:
157 x: Any = 15.0
158 y: int = 0
159
160 @dataclass
161 class C1(Base):
162 z: int = 10
163 x: int = 15
164
165 o = Base()
166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
167
168 o = C1()
169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
170
171 o = C1(x=5)
172 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
173
174 def test_field_named_self(self):
175 @dataclass
176 class C:
177 self: str
178 c=C('foo')
179 self.assertEqual(c.self, 'foo')
180
181 # Make sure the first parameter is not named 'self'.
182 sig = inspect.signature(C.__init__)
183 first = next(iter(sig.parameters))
184 self.assertNotEqual('self', first)
185
186 # But we do use 'self' if no field named self.
187 @dataclass
188 class C:
189 selfx: str
190
191 # Make sure the first parameter is named 'self'.
192 sig = inspect.signature(C.__init__)
193 first = next(iter(sig.parameters))
194 self.assertEqual('self', first)
195
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300196 def test_field_named_object(self):
197 @dataclass
198 class C:
199 object: str
200 c = C('foo')
201 self.assertEqual(c.object, 'foo')
202
203 def test_field_named_object_frozen(self):
204 @dataclass(frozen=True)
205 class C:
206 object: str
207 c = C('foo')
208 self.assertEqual(c.object, 'foo')
209
210 def test_field_named_like_builtin(self):
211 # Attribute names can shadow built-in names
212 # since code generation is used.
213 # Ensure that this is not happening.
214 exclusions = {'None', 'True', 'False'}
215 builtins_names = sorted(
216 b for b in builtins.__dict__.keys()
217 if not b.startswith('__') and b not in exclusions
218 )
219 attributes = [(name, str) for name in builtins_names]
220 C = make_dataclass('C', attributes)
221
222 c = C(*[name for name in builtins_names])
223
224 for name in builtins_names:
225 self.assertEqual(getattr(c, name), name)
226
227 def test_field_named_like_builtin_frozen(self):
228 # Attribute names can shadow built-in names
229 # since code generation is used.
230 # Ensure that this is not happening
231 # for frozen data classes.
232 exclusions = {'None', 'True', 'False'}
233 builtins_names = sorted(
234 b for b in builtins.__dict__.keys()
235 if not b.startswith('__') and b not in exclusions
236 )
237 attributes = [(name, str) for name in builtins_names]
238 C = make_dataclass('C', attributes, frozen=True)
239
240 c = C(*[name for name in builtins_names])
241
242 for name in builtins_names:
243 self.assertEqual(getattr(c, name), name)
244
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500245 def test_0_field_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 pass
250
251 @dataclass(order=False)
252 class C1:
253 pass
254
255 for cls in [C0, C1]:
256 with self.subTest(cls=cls):
257 self.assertEqual(cls(), cls())
258 for idx, fn in enumerate([lambda a, b: a < b,
259 lambda a, b: a <= b,
260 lambda a, b: a > b,
261 lambda a, b: a >= b]):
262 with self.subTest(idx=idx):
263 with self.assertRaisesRegex(TypeError,
264 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
265 fn(cls(), cls())
266
267 @dataclass(order=True)
268 class C:
269 pass
270 self.assertLessEqual(C(), C())
271 self.assertGreaterEqual(C(), C())
272
273 def test_1_field_compare(self):
274 # Ensure that order=False is the default.
275 @dataclass
276 class C0:
277 x: int
278
279 @dataclass(order=False)
280 class C1:
281 x: int
282
283 for cls in [C0, C1]:
284 with self.subTest(cls=cls):
285 self.assertEqual(cls(1), cls(1))
286 self.assertNotEqual(cls(0), cls(1))
287 for idx, fn in enumerate([lambda a, b: a < b,
288 lambda a, b: a <= b,
289 lambda a, b: a > b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 with self.assertRaisesRegex(TypeError,
293 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
294 fn(cls(0), cls(0))
295
296 @dataclass(order=True)
297 class C:
298 x: int
299 self.assertLess(C(0), C(1))
300 self.assertLessEqual(C(0), C(1))
301 self.assertLessEqual(C(1), C(1))
302 self.assertGreater(C(1), C(0))
303 self.assertGreaterEqual(C(1), C(0))
304 self.assertGreaterEqual(C(1), C(1))
305
306 def test_simple_compare(self):
307 # Ensure that order=False is the default.
308 @dataclass
309 class C0:
310 x: int
311 y: int
312
313 @dataclass(order=False)
314 class C1:
315 x: int
316 y: int
317
318 for cls in [C0, C1]:
319 with self.subTest(cls=cls):
320 self.assertEqual(cls(0, 0), cls(0, 0))
321 self.assertEqual(cls(1, 2), cls(1, 2))
322 self.assertNotEqual(cls(1, 0), cls(0, 0))
323 self.assertNotEqual(cls(1, 0), cls(1, 1))
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
331 fn(cls(0, 0), cls(0, 0))
332
333 @dataclass(order=True)
334 class C:
335 x: int
336 y: int
337
338 for idx, fn in enumerate([lambda a, b: a == b,
339 lambda a, b: a <= b,
340 lambda a, b: a >= b]):
341 with self.subTest(idx=idx):
342 self.assertTrue(fn(C(0, 0), C(0, 0)))
343
344 for idx, fn in enumerate([lambda a, b: a < b,
345 lambda a, b: a <= b,
346 lambda a, b: a != b]):
347 with self.subTest(idx=idx):
348 self.assertTrue(fn(C(0, 0), C(0, 1)))
349 self.assertTrue(fn(C(0, 1), C(1, 0)))
350 self.assertTrue(fn(C(1, 0), C(1, 1)))
351
352 for idx, fn in enumerate([lambda a, b: a > b,
353 lambda a, b: a >= b,
354 lambda a, b: a != b]):
355 with self.subTest(idx=idx):
356 self.assertTrue(fn(C(0, 1), C(0, 0)))
357 self.assertTrue(fn(C(1, 0), C(0, 1)))
358 self.assertTrue(fn(C(1, 1), C(1, 0)))
359
360 def test_compare_subclasses(self):
361 # Comparisons fail for subclasses, even if no fields
362 # are added.
363 @dataclass
364 class B:
365 i: int
366
367 @dataclass
368 class C(B):
369 pass
370
371 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
372 (lambda a, b: a != b, True)]):
373 with self.subTest(idx=idx):
374 self.assertEqual(fn(B(0), C(0)), expected)
375
376 for idx, fn in enumerate([lambda a, b: a < b,
377 lambda a, b: a <= b,
378 lambda a, b: a > b,
379 lambda a, b: a >= b]):
380 with self.subTest(idx=idx):
381 with self.assertRaisesRegex(TypeError,
382 "not supported between instances of 'B' and 'C'"):
383 fn(B(0), C(0))
384
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500385 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500386 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500387 for (eq, order, result ) in [
388 (False, False, 'neither'),
389 (False, True, 'exception'),
390 (True, False, 'eq_only'),
391 (True, True, 'both'),
392 ]:
393 with self.subTest(eq=eq, order=order):
394 if result == 'exception':
395 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
396 @dataclass(eq=eq, order=order)
397 class C:
398 pass
399 else:
400 @dataclass(eq=eq, order=order)
401 class C:
402 pass
403
404 if result == 'neither':
405 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 self.assertNotIn('__lt__', C.__dict__)
407 self.assertNotIn('__le__', C.__dict__)
408 self.assertNotIn('__gt__', C.__dict__)
409 self.assertNotIn('__ge__', C.__dict__)
410 elif result == 'both':
411 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500412 self.assertIn('__lt__', C.__dict__)
413 self.assertIn('__le__', C.__dict__)
414 self.assertIn('__gt__', C.__dict__)
415 self.assertIn('__ge__', C.__dict__)
416 elif result == 'eq_only':
417 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500418 self.assertNotIn('__lt__', C.__dict__)
419 self.assertNotIn('__le__', C.__dict__)
420 self.assertNotIn('__gt__', C.__dict__)
421 self.assertNotIn('__ge__', C.__dict__)
422 else:
423 assert False, f'unknown result {result!r}'
424
425 def test_field_no_default(self):
426 @dataclass
427 class C:
428 x: int = field()
429
430 self.assertEqual(C(5).x, 5)
431
432 with self.assertRaisesRegex(TypeError,
433 r"__init__\(\) missing 1 required "
434 "positional argument: 'x'"):
435 C()
436
437 def test_field_default(self):
438 default = object()
439 @dataclass
440 class C:
441 x: object = field(default=default)
442
443 self.assertIs(C.x, default)
444 c = C(10)
445 self.assertEqual(c.x, 10)
446
447 # If we delete the instance attribute, we should then see the
448 # class attribute.
449 del c.x
450 self.assertIs(c.x, default)
451
452 self.assertIs(C().x, default)
453
454 def test_not_in_repr(self):
455 @dataclass
456 class C:
457 x: int = field(repr=False)
458 with self.assertRaises(TypeError):
459 C()
460 c = C(10)
461 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
462
463 @dataclass
464 class C:
465 x: int = field(repr=False)
466 y: int
467 c = C(10, 20)
468 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
469
470 def test_not_in_compare(self):
471 @dataclass
472 class C:
473 x: int = 0
474 y: int = field(compare=False, default=4)
475
476 self.assertEqual(C(), C(0, 20))
477 self.assertEqual(C(1, 10), C(1, 20))
478 self.assertNotEqual(C(3), C(4, 10))
479 self.assertNotEqual(C(3, 10), C(4, 10))
480
481 def test_hash_field_rules(self):
482 # Test all 6 cases of:
483 # hash=True/False/None
484 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500485 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500486 (True, False, 'field' ),
487 (True, True, 'field' ),
488 (False, False, 'absent'),
489 (False, True, 'absent'),
490 (None, False, 'absent'),
491 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500492 ]:
493 with self.subTest(hash=hash_, compare=compare):
494 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500495 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500496 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500497
498 if result == 'field':
499 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500500 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500501 elif result == 'absent':
502 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500503 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500504 else:
505 assert False, f'unknown result {result!r}'
506
507 def test_init_false_no_default(self):
508 # If init=False and no default value, then the field won't be
509 # present in the instance.
510 @dataclass
511 class C:
512 x: int = field(init=False)
513
514 self.assertNotIn('x', C().__dict__)
515
516 @dataclass
517 class C:
518 x: int
519 y: int = 0
520 z: int = field(init=False)
521 t: int = 10
522
523 self.assertNotIn('z', C(0).__dict__)
524 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
525
526 def test_class_marker(self):
527 @dataclass
528 class C:
529 x: int
530 y: str = field(init=False, default=None)
531 z: str = field(repr=False)
532
533 the_fields = fields(C)
534 # the_fields is a tuple of 3 items, each value
535 # is in __annotations__.
536 self.assertIsInstance(the_fields, tuple)
537 for f in the_fields:
538 self.assertIs(type(f), Field)
539 self.assertIn(f.name, C.__annotations__)
540
541 self.assertEqual(len(the_fields), 3)
542
543 self.assertEqual(the_fields[0].name, 'x')
544 self.assertEqual(the_fields[0].type, int)
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertTrue (the_fields[0].init)
547 self.assertTrue (the_fields[0].repr)
548 self.assertEqual(the_fields[1].name, 'y')
549 self.assertEqual(the_fields[1].type, str)
550 self.assertIsNone(getattr(C, 'y'))
551 self.assertFalse(the_fields[1].init)
552 self.assertTrue (the_fields[1].repr)
553 self.assertEqual(the_fields[2].name, 'z')
554 self.assertEqual(the_fields[2].type, str)
555 self.assertFalse(hasattr(C, 'z'))
556 self.assertTrue (the_fields[2].init)
557 self.assertFalse(the_fields[2].repr)
558
559 def test_field_order(self):
560 @dataclass
561 class B:
562 a: str = 'B:a'
563 b: str = 'B:b'
564 c: str = 'B:c'
565
566 @dataclass
567 class C(B):
568 b: str = 'C:b'
569
570 self.assertEqual([(f.name, f.default) for f in fields(C)],
571 [('a', 'B:a'),
572 ('b', 'C:b'),
573 ('c', 'B:c')])
574
575 @dataclass
576 class D(B):
577 c: str = 'D:c'
578
579 self.assertEqual([(f.name, f.default) for f in fields(D)],
580 [('a', 'B:a'),
581 ('b', 'B:b'),
582 ('c', 'D:c')])
583
584 @dataclass
585 class E(D):
586 a: str = 'E:a'
587 d: str = 'E:d'
588
589 self.assertEqual([(f.name, f.default) for f in fields(E)],
590 [('a', 'E:a'),
591 ('b', 'B:b'),
592 ('c', 'D:c'),
593 ('d', 'E:d')])
594
595 def test_class_attrs(self):
596 # We only have a class attribute if a default value is
597 # specified, either directly or via a field with a default.
598 default = object()
599 @dataclass
600 class C:
601 x: int
602 y: int = field(repr=False)
603 z: object = default
604 t: int = field(default=100)
605
606 self.assertFalse(hasattr(C, 'x'))
607 self.assertFalse(hasattr(C, 'y'))
608 self.assertIs (C.z, default)
609 self.assertEqual(C.t, 100)
610
611 def test_disallowed_mutable_defaults(self):
612 # For the known types, don't allow mutable default values.
613 for typ, empty, non_empty in [(list, [], [1]),
614 (dict, {}, {0:1}),
615 (set, set(), set([1])),
616 ]:
617 with self.subTest(typ=typ):
618 # Can't use a zero-length value.
619 with self.assertRaisesRegex(ValueError,
620 f'mutable default {typ} for field '
621 'x is not allowed'):
622 @dataclass
623 class Point:
624 x: typ = empty
625
626
627 # Nor a non-zero-length value
628 with self.assertRaisesRegex(ValueError,
629 f'mutable default {typ} for field '
630 'y is not allowed'):
631 @dataclass
632 class Point:
633 y: typ = non_empty
634
635 # Check subtypes also fail.
636 class Subclass(typ): pass
637
638 with self.assertRaisesRegex(ValueError,
639 f"mutable default .*Subclass'>"
640 ' for field z is not allowed'
641 ):
642 @dataclass
643 class Point:
644 z: typ = Subclass()
645
646 # Because this is a ClassVar, it can be mutable.
647 @dataclass
648 class C:
649 z: ClassVar[typ] = typ()
650
651 # Because this is a ClassVar, it can be mutable.
652 @dataclass
653 class C:
654 x: ClassVar[typ] = Subclass()
655
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500656 def test_deliberately_mutable_defaults(self):
657 # If a mutable default isn't in the known list of
658 # (list, dict, set), then it's okay.
659 class Mutable:
660 def __init__(self):
661 self.l = []
662
663 @dataclass
664 class C:
665 x: Mutable
666
667 # These 2 instances will share this value of x.
668 lst = Mutable()
669 o1 = C(lst)
670 o2 = C(lst)
671 self.assertEqual(o1, o2)
672 o1.x.l.extend([1, 2])
673 self.assertEqual(o1, o2)
674 self.assertEqual(o1.x.l, [1, 2])
675 self.assertIs(o1.x, o2.x)
676
677 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400678 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500679 @dataclass()
680 class C:
681 x: int
682
683 self.assertEqual(C(42).x, 42)
684
685 def test_not_tuple(self):
686 # Make sure we can't be compared to a tuple.
687 @dataclass
688 class Point:
689 x: int
690 y: int
691 self.assertNotEqual(Point(1, 2), (1, 2))
692
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400693 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500694 @dataclass
695 class C:
696 x: int
697 y: int
698 self.assertNotEqual(Point(1, 3), C(1, 3))
699
Windson yangbe372d72019-04-23 02:45:34 +0800700 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500701 # Test that some of the problems with namedtuple don't happen
702 # here.
703 @dataclass
704 class Point3D:
705 x: int
706 y: int
707 z: int
708
709 @dataclass
710 class Date:
711 year: int
712 month: int
713 day: int
714
715 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
716 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
717
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400718 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200719 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 x, y, z = Point3D(4, 5, 6)
721
Eric V. Smith7c99e932018-01-28 19:18:55 -0500722 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # equal.
724 @dataclass
725 class Point3Dv1:
726 x: int = 0
727 y: int = 0
728 z: int = 0
729 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
730
731 def test_function_annotations(self):
732 # Some dummy class and instance to use as a default.
733 class F:
734 pass
735 f = F()
736
737 def validate_class(cls):
738 # First, check __annotations__, even though they're not
739 # function annotations.
740 self.assertEqual(cls.__annotations__['i'], int)
741 self.assertEqual(cls.__annotations__['j'], str)
742 self.assertEqual(cls.__annotations__['k'], F)
743 self.assertEqual(cls.__annotations__['l'], float)
744 self.assertEqual(cls.__annotations__['z'], complex)
745
746 # Verify __init__.
747
748 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400749 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500750 self.assertIs(signature.return_annotation, None)
751
752 # Check each parameter.
753 params = iter(signature.parameters.values())
754 param = next(params)
755 # This is testing an internal name, and probably shouldn't be tested.
756 self.assertEqual(param.name, 'self')
757 param = next(params)
758 self.assertEqual(param.name, 'i')
759 self.assertIs (param.annotation, int)
760 self.assertEqual(param.default, inspect.Parameter.empty)
761 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
762 param = next(params)
763 self.assertEqual(param.name, 'j')
764 self.assertIs (param.annotation, str)
765 self.assertEqual(param.default, inspect.Parameter.empty)
766 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
767 param = next(params)
768 self.assertEqual(param.name, 'k')
769 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400770 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
772 param = next(params)
773 self.assertEqual(param.name, 'l')
774 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400775 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500776 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
777 self.assertRaises(StopIteration, next, params)
778
779
780 @dataclass
781 class C:
782 i: int
783 j: str
784 k: F = f
785 l: float=field(default=None)
786 z: complex=field(default=3+4j, init=False)
787
788 validate_class(C)
789
790 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500791 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 class C:
793 i: int
794 j: str
795 k: F = f
796 l: float=field(default=None)
797 z: complex=field(default=3+4j, init=False)
798
799 validate_class(C)
800
Eric V. Smith03220fd2017-12-29 13:59:58 -0500801 def test_missing_default(self):
802 # Test that MISSING works the same as a default not being
803 # specified.
804 @dataclass
805 class C:
806 x: int=field(default=MISSING)
807 with self.assertRaisesRegex(TypeError,
808 r'__init__\(\) missing 1 required '
809 'positional argument'):
810 C()
811 self.assertNotIn('x', C.__dict__)
812
813 @dataclass
814 class D:
815 x: int
816 with self.assertRaisesRegex(TypeError,
817 r'__init__\(\) missing 1 required '
818 'positional argument'):
819 D()
820 self.assertNotIn('x', D.__dict__)
821
822 def test_missing_default_factory(self):
823 # Test that MISSING works the same as a default factory not
824 # being specified (which is really the same as a default not
825 # being specified, too).
826 @dataclass
827 class C:
828 x: int=field(default_factory=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int=field(default=MISSING, default_factory=MISSING)
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_repr(self):
845 self.assertIn('MISSING_TYPE object', repr(MISSING))
846
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 def test_dont_include_other_annotations(self):
848 @dataclass
849 class C:
850 i: int
851 def foo(self) -> int:
852 return 4
853 @property
854 def bar(self) -> int:
855 return 5
856 self.assertEqual(list(C.__annotations__), ['i'])
857 self.assertEqual(C(10).foo(), 4)
858 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400859 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500860
861 def test_post_init(self):
862 # Just make sure it gets called
863 @dataclass
864 class C:
865 def __post_init__(self):
866 raise CustomError()
867 with self.assertRaises(CustomError):
868 C()
869
870 @dataclass
871 class C:
872 i: int = 10
873 def __post_init__(self):
874 if self.i == 10:
875 raise CustomError()
876 with self.assertRaises(CustomError):
877 C()
878 # post-init gets called, but doesn't raise. This is just
879 # checking that self is used correctly.
880 C(5)
881
882 # If there's not an __init__, then post-init won't get called.
883 @dataclass(init=False)
884 class C:
885 def __post_init__(self):
886 raise CustomError()
887 # Creating the class won't raise
888 C()
889
890 @dataclass
891 class C:
892 x: int = 0
893 def __post_init__(self):
894 self.x *= 2
895 self.assertEqual(C().x, 0)
896 self.assertEqual(C(2).x, 4)
897
Mike53f7a7c2017-12-14 14:04:53 +0300898 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500899 # attributes.
900 @dataclass(frozen=True)
901 class C:
902 x: int = 0
903 def __post_init__(self):
904 self.x *= 2
905 with self.assertRaises(FrozenInstanceError):
906 C()
907
908 def test_post_init_super(self):
909 # Make sure super() post-init isn't called by default.
910 class B:
911 def __post_init__(self):
912 raise CustomError()
913
914 @dataclass
915 class C(B):
916 def __post_init__(self):
917 self.x = 5
918
919 self.assertEqual(C().x, 5)
920
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400921 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500922 @dataclass
923 class C(B):
924 def __post_init__(self):
925 super().__post_init__()
926
927 with self.assertRaises(CustomError):
928 C()
929
930 # Make sure post-init is called, even if not defined in our
931 # class.
932 @dataclass
933 class C(B):
934 pass
935
936 with self.assertRaises(CustomError):
937 C()
938
939 def test_post_init_staticmethod(self):
940 flag = False
941 @dataclass
942 class C:
943 x: int
944 y: int
945 @staticmethod
946 def __post_init__():
947 nonlocal flag
948 flag = True
949
950 self.assertFalse(flag)
951 c = C(3, 4)
952 self.assertEqual((c.x, c.y), (3, 4))
953 self.assertTrue(flag)
954
955 def test_post_init_classmethod(self):
956 @dataclass
957 class C:
958 flag = False
959 x: int
960 y: int
961 @classmethod
962 def __post_init__(cls):
963 cls.flag = True
964
965 self.assertFalse(C.flag)
966 c = C(3, 4)
967 self.assertEqual((c.x, c.y), (3, 4))
968 self.assertTrue(C.flag)
969
970 def test_class_var(self):
971 # Make sure ClassVars are ignored in __init__, __repr__, etc.
972 @dataclass
973 class C:
974 x: int
975 y: int = 10
976 z: ClassVar[int] = 1000
977 w: ClassVar[int] = 2000
978 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400979 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500980
981 c = C(5)
982 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400983 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400984 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500985 self.assertEqual(c.z, 1000)
986 self.assertEqual(c.w, 2000)
987 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400988 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500989 C.z += 1
990 self.assertEqual(c.z, 1001)
991 c = C(20)
992 self.assertEqual((c.x, c.y), (20, 10))
993 self.assertEqual(c.z, 1001)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400996 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500997
998 def test_class_var_no_default(self):
999 # If a ClassVar has no default value, it should not be set on the class.
1000 @dataclass
1001 class C:
1002 x: ClassVar[int]
1003
1004 self.assertNotIn('x', C.__dict__)
1005
1006 def test_class_var_default_factory(self):
1007 # It makes no sense for a ClassVar to have a default factory. When
1008 # would it be called? Call it yourself, since it's class-wide.
1009 with self.assertRaisesRegex(TypeError,
1010 'cannot have a default factory'):
1011 @dataclass
1012 class C:
1013 x: ClassVar[int] = field(default_factory=int)
1014
1015 self.assertNotIn('x', C.__dict__)
1016
1017 def test_class_var_with_default(self):
1018 # If a ClassVar has a default value, it should be set on the class.
1019 @dataclass
1020 class C:
1021 x: ClassVar[int] = 10
1022 self.assertEqual(C.x, 10)
1023
1024 @dataclass
1025 class C:
1026 x: ClassVar[int] = field(default=10)
1027 self.assertEqual(C.x, 10)
1028
1029 def test_class_var_frozen(self):
1030 # Make sure ClassVars work even if we're frozen.
1031 @dataclass(frozen=True)
1032 class C:
1033 x: int
1034 y: int = 10
1035 z: ClassVar[int] = 1000
1036 w: ClassVar[int] = 2000
1037 t: ClassVar[int] = 3000
1038
1039 c = C(5)
1040 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1041 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1042 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1043 self.assertEqual(c.z, 1000)
1044 self.assertEqual(c.w, 2000)
1045 self.assertEqual(c.t, 3000)
1046 # We can still modify the ClassVar, it's only instances that are
1047 # frozen.
1048 C.z += 1
1049 self.assertEqual(c.z, 1001)
1050 c = C(20)
1051 self.assertEqual((c.x, c.y), (20, 10))
1052 self.assertEqual(c.z, 1001)
1053 self.assertEqual(c.w, 2000)
1054 self.assertEqual(c.t, 3000)
1055
1056 def test_init_var_no_default(self):
1057 # If an InitVar has no default value, it should not be set on the class.
1058 @dataclass
1059 class C:
1060 x: InitVar[int]
1061
1062 self.assertNotIn('x', C.__dict__)
1063
1064 def test_init_var_default_factory(self):
1065 # It makes no sense for an InitVar to have a default factory. When
1066 # would it be called? Call it yourself, since it's class-wide.
1067 with self.assertRaisesRegex(TypeError,
1068 'cannot have a default factory'):
1069 @dataclass
1070 class C:
1071 x: InitVar[int] = field(default_factory=int)
1072
1073 self.assertNotIn('x', C.__dict__)
1074
1075 def test_init_var_with_default(self):
1076 # If an InitVar has a default value, it should be set on the class.
1077 @dataclass
1078 class C:
1079 x: InitVar[int] = 10
1080 self.assertEqual(C.x, 10)
1081
1082 @dataclass
1083 class C:
1084 x: InitVar[int] = field(default=10)
1085 self.assertEqual(C.x, 10)
1086
1087 def test_init_var(self):
1088 @dataclass
1089 class C:
1090 x: int = None
1091 init_param: InitVar[int] = None
1092
1093 def __post_init__(self, init_param):
1094 if self.x is None:
1095 self.x = init_param*2
1096
1097 c = C(init_param=10)
1098 self.assertEqual(c.x, 20)
1099
Augusto Hack01ee12b2019-06-02 23:14:48 -03001100 def test_init_var_preserve_type(self):
1101 self.assertEqual(InitVar[int].type, int)
1102
1103 # Make sure the repr is correct.
1104 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001105 self.assertEqual(repr(InitVar[List[int]]),
1106 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001107
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001108 def test_init_var_inheritance(self):
1109 # Note that this deliberately tests that a dataclass need not
1110 # have a __post_init__ function if it has an InitVar field.
1111 # It could just be used in a derived class, as shown here.
1112 @dataclass
1113 class Base:
1114 x: int
1115 init_base: InitVar[int]
1116
1117 # We can instantiate by passing the InitVar, even though
1118 # it's not used.
1119 b = Base(0, 10)
1120 self.assertEqual(vars(b), {'x': 0})
1121
1122 @dataclass
1123 class C(Base):
1124 y: int
1125 init_derived: InitVar[int]
1126
1127 def __post_init__(self, init_base, init_derived):
1128 self.x = self.x + init_base
1129 self.y = self.y + init_derived
1130
1131 c = C(10, 11, 50, 51)
1132 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1133
1134 def test_default_factory(self):
1135 # Test a factory that returns a new list.
1136 @dataclass
1137 class C:
1138 x: int
1139 y: list = field(default_factory=list)
1140
1141 c0 = C(3)
1142 c1 = C(3)
1143 self.assertEqual(c0.x, 3)
1144 self.assertEqual(c0.y, [])
1145 self.assertEqual(c0, c1)
1146 self.assertIsNot(c0.y, c1.y)
1147 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1148
1149 # Test a factory that returns a shared list.
1150 l = []
1151 @dataclass
1152 class C:
1153 x: int
1154 y: list = field(default_factory=lambda: l)
1155
1156 c0 = C(3)
1157 c1 = C(3)
1158 self.assertEqual(c0.x, 3)
1159 self.assertEqual(c0.y, [])
1160 self.assertEqual(c0, c1)
1161 self.assertIs(c0.y, c1.y)
1162 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1163
1164 # Test various other field flags.
1165 # repr
1166 @dataclass
1167 class C:
1168 x: list = field(default_factory=list, repr=False)
1169 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1170 self.assertEqual(C().x, [])
1171
1172 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001173 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001174 class C:
1175 x: list = field(default_factory=list, hash=False)
1176 self.assertEqual(astuple(C()), ([],))
1177 self.assertEqual(hash(C()), hash(()))
1178
1179 # init (see also test_default_factory_with_no_init)
1180 @dataclass
1181 class C:
1182 x: list = field(default_factory=list, init=False)
1183 self.assertEqual(astuple(C()), ([],))
1184
1185 # compare
1186 @dataclass
1187 class C:
1188 x: list = field(default_factory=list, compare=False)
1189 self.assertEqual(C(), C([1]))
1190
1191 def test_default_factory_with_no_init(self):
1192 # We need a factory with a side effect.
1193 factory = Mock()
1194
1195 @dataclass
1196 class C:
1197 x: list = field(default_factory=factory, init=False)
1198
1199 # Make sure the default factory is called for each new instance.
1200 C().x
1201 self.assertEqual(factory.call_count, 1)
1202 C().x
1203 self.assertEqual(factory.call_count, 2)
1204
1205 def test_default_factory_not_called_if_value_given(self):
1206 # We need a factory that we can test if it's been called.
1207 factory = Mock()
1208
1209 @dataclass
1210 class C:
1211 x: int = field(default_factory=factory)
1212
1213 # Make sure that if a field has a default factory function,
1214 # it's not called if a value is specified.
1215 C().x
1216 self.assertEqual(factory.call_count, 1)
1217 self.assertEqual(C(10).x, 10)
1218 self.assertEqual(factory.call_count, 1)
1219 C().x
1220 self.assertEqual(factory.call_count, 2)
1221
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001222 def test_default_factory_derived(self):
1223 # See bpo-32896.
1224 @dataclass
1225 class Foo:
1226 x: dict = field(default_factory=dict)
1227
1228 @dataclass
1229 class Bar(Foo):
1230 y: int = 1
1231
1232 self.assertEqual(Foo().x, {})
1233 self.assertEqual(Bar().x, {})
1234 self.assertEqual(Bar().y, 1)
1235
1236 @dataclass
1237 class Baz(Foo):
1238 pass
1239 self.assertEqual(Baz().x, {})
1240
1241 def test_intermediate_non_dataclass(self):
1242 # Test that an intermediate class that defines
1243 # annotations does not define fields.
1244
1245 @dataclass
1246 class A:
1247 x: int
1248
1249 class B(A):
1250 y: int
1251
1252 @dataclass
1253 class C(B):
1254 z: int
1255
1256 c = C(1, 3)
1257 self.assertEqual((c.x, c.z), (1, 3))
1258
1259 # .y was not initialized.
1260 with self.assertRaisesRegex(AttributeError,
1261 'object has no attribute'):
1262 c.y
1263
1264 # And if we again derive a non-dataclass, no fields are added.
1265 class D(C):
1266 t: int
1267 d = D(4, 5)
1268 self.assertEqual((d.x, d.z), (4, 5))
1269
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001270 def test_classvar_default_factory(self):
1271 # It's an error for a ClassVar to have a factory function.
1272 with self.assertRaisesRegex(TypeError,
1273 'cannot have a default factory'):
1274 @dataclass
1275 class C:
1276 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001277
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001278 def test_is_dataclass(self):
1279 class NotDataClass:
1280 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001281
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001282 self.assertFalse(is_dataclass(0))
1283 self.assertFalse(is_dataclass(int))
1284 self.assertFalse(is_dataclass(NotDataClass))
1285 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001286
1287 @dataclass
1288 class C:
1289 x: int
1290
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001291 @dataclass
1292 class D:
1293 d: C
1294 e: int
1295
1296 c = C(10)
1297 d = D(c, 4)
1298
1299 self.assertTrue(is_dataclass(C))
1300 self.assertTrue(is_dataclass(c))
1301 self.assertFalse(is_dataclass(c.x))
1302 self.assertTrue(is_dataclass(d.d))
1303 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001304
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001305 def test_is_dataclass_when_getattr_always_returns(self):
1306 # See bpo-37868.
1307 class A:
1308 def __getattr__(self, key):
1309 return 0
1310 self.assertFalse(is_dataclass(A))
1311 a = A()
1312
1313 # Also test for an instance attribute.
1314 class B:
1315 pass
1316 b = B()
1317 b.__dataclass_fields__ = []
1318
1319 for obj in a, b:
1320 with self.subTest(obj=obj):
1321 self.assertFalse(is_dataclass(obj))
1322
1323 # Indirect tests for _is_dataclass_instance().
1324 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1325 asdict(obj)
1326 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1327 astuple(obj)
1328 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1329 replace(obj, x=0)
1330
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001331 def test_helper_fields_with_class_instance(self):
1332 # Check that we can call fields() on either a class or instance,
1333 # and get back the same thing.
1334 @dataclass
1335 class C:
1336 x: int
1337 y: float
1338
1339 self.assertEqual(fields(C), fields(C(0, 0.0)))
1340
1341 def test_helper_fields_exception(self):
1342 # Check that TypeError is raised if not passed a dataclass or
1343 # instance.
1344 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1345 fields(0)
1346
1347 class C: pass
1348 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1349 fields(C)
1350 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1351 fields(C())
1352
1353 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001354 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001355 @dataclass
1356 class C:
1357 x: int
1358 y: int
1359 c = C(1, 2)
1360
1361 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1362 self.assertEqual(asdict(c), asdict(c))
1363 self.assertIsNot(asdict(c), asdict(c))
1364 c.x = 42
1365 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1366 self.assertIs(type(asdict(c)), dict)
1367
1368 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001369 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001370 @dataclass
1371 class C:
1372 x: int
1373 y: int
1374 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1375 asdict(C)
1376 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1377 asdict(int)
1378
1379 def test_helper_asdict_copy_values(self):
1380 @dataclass
1381 class C:
1382 x: int
1383 y: List[int] = field(default_factory=list)
1384 initial = []
1385 c = C(1, initial)
1386 d = asdict(c)
1387 self.assertEqual(d['y'], initial)
1388 self.assertIsNot(d['y'], initial)
1389 c = C(1)
1390 d = asdict(c)
1391 d['y'].append(1)
1392 self.assertEqual(c.y, [])
1393
1394 def test_helper_asdict_nested(self):
1395 @dataclass
1396 class UserId:
1397 token: int
1398 group: int
1399 @dataclass
1400 class User:
1401 name: str
1402 id: UserId
1403 u = User('Joe', UserId(123, 1))
1404 d = asdict(u)
1405 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1406 self.assertIsNot(asdict(u), asdict(u))
1407 u.id.group = 2
1408 self.assertEqual(asdict(u), {'name': 'Joe',
1409 'id': {'token': 123, 'group': 2}})
1410
1411 def test_helper_asdict_builtin_containers(self):
1412 @dataclass
1413 class User:
1414 name: str
1415 id: int
1416 @dataclass
1417 class GroupList:
1418 id: int
1419 users: List[User]
1420 @dataclass
1421 class GroupTuple:
1422 id: int
1423 users: Tuple[User, ...]
1424 @dataclass
1425 class GroupDict:
1426 id: int
1427 users: Dict[str, User]
1428 a = User('Alice', 1)
1429 b = User('Bob', 2)
1430 gl = GroupList(0, [a, b])
1431 gt = GroupTuple(0, (a, b))
1432 gd = GroupDict(0, {'first': a, 'second': b})
1433 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1434 {'name': 'Bob', 'id': 2}]})
1435 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1436 {'name': 'Bob', 'id': 2})})
1437 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1438 'second': {'name': 'Bob', 'id': 2}}})
1439
Windson yangbe372d72019-04-23 02:45:34 +08001440 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001441 @dataclass
1442 class Child:
1443 d: object
1444
1445 @dataclass
1446 class Parent:
1447 child: Child
1448
1449 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1450 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1451
1452 def test_helper_asdict_factory(self):
1453 @dataclass
1454 class C:
1455 x: int
1456 y: int
1457 c = C(1, 2)
1458 d = asdict(c, dict_factory=OrderedDict)
1459 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1460 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1461 c.x = 42
1462 d = asdict(c, dict_factory=OrderedDict)
1463 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1464 self.assertIs(type(d), OrderedDict)
1465
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001466 def test_helper_asdict_namedtuple(self):
1467 T = namedtuple('T', 'a b c')
1468 @dataclass
1469 class C:
1470 x: str
1471 y: T
1472 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1473
1474 d = asdict(c)
1475 self.assertEqual(d, {'x': 'outer',
1476 'y': T(1,
1477 {'x': 'inner',
1478 'y': T(11, 12, 13)},
1479 2),
1480 }
1481 )
1482
1483 # Now with a dict_factory. OrderedDict is convenient, but
1484 # since it compares to dicts, we also need to have separate
1485 # assertIs tests.
1486 d = asdict(c, dict_factory=OrderedDict)
1487 self.assertEqual(d, {'x': 'outer',
1488 'y': T(1,
1489 {'x': 'inner',
1490 'y': T(11, 12, 13)},
1491 2),
1492 }
1493 )
1494
penguindustin96466302019-05-06 14:57:17 -04001495 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001496 self.assertIs(type(d), OrderedDict)
1497 self.assertIs(type(d['y'][1]), OrderedDict)
1498
1499 def test_helper_asdict_namedtuple_key(self):
1500 # Ensure that a field that contains a dict which has a
1501 # namedtuple as a key works with asdict().
1502
1503 @dataclass
1504 class C:
1505 f: dict
1506 T = namedtuple('T', 'a')
1507
1508 c = C({T('an a'): 0})
1509
1510 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1511
1512 def test_helper_asdict_namedtuple_derived(self):
1513 class T(namedtuple('Tbase', 'a')):
1514 def my_a(self):
1515 return self.a
1516
1517 @dataclass
1518 class C:
1519 f: T
1520
1521 t = T(6)
1522 c = C(t)
1523
1524 d = asdict(c)
1525 self.assertEqual(d, {'f': T(a=6)})
1526 # Make sure that t has been copied, not used directly.
1527 self.assertIsNot(d['f'], t)
1528 self.assertEqual(d['f'].my_a(), 6)
1529
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001530 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001531 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001532 @dataclass
1533 class C:
1534 x: int
1535 y: int = 0
1536 c = C(1)
1537
1538 self.assertEqual(astuple(c), (1, 0))
1539 self.assertEqual(astuple(c), astuple(c))
1540 self.assertIsNot(astuple(c), astuple(c))
1541 c.y = 42
1542 self.assertEqual(astuple(c), (1, 42))
1543 self.assertIs(type(astuple(c)), tuple)
1544
1545 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001546 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001547 @dataclass
1548 class C:
1549 x: int
1550 y: int
1551 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1552 astuple(C)
1553 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1554 astuple(int)
1555
1556 def test_helper_astuple_copy_values(self):
1557 @dataclass
1558 class C:
1559 x: int
1560 y: List[int] = field(default_factory=list)
1561 initial = []
1562 c = C(1, initial)
1563 t = astuple(c)
1564 self.assertEqual(t[1], initial)
1565 self.assertIsNot(t[1], initial)
1566 c = C(1)
1567 t = astuple(c)
1568 t[1].append(1)
1569 self.assertEqual(c.y, [])
1570
1571 def test_helper_astuple_nested(self):
1572 @dataclass
1573 class UserId:
1574 token: int
1575 group: int
1576 @dataclass
1577 class User:
1578 name: str
1579 id: UserId
1580 u = User('Joe', UserId(123, 1))
1581 t = astuple(u)
1582 self.assertEqual(t, ('Joe', (123, 1)))
1583 self.assertIsNot(astuple(u), astuple(u))
1584 u.id.group = 2
1585 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1586
1587 def test_helper_astuple_builtin_containers(self):
1588 @dataclass
1589 class User:
1590 name: str
1591 id: int
1592 @dataclass
1593 class GroupList:
1594 id: int
1595 users: List[User]
1596 @dataclass
1597 class GroupTuple:
1598 id: int
1599 users: Tuple[User, ...]
1600 @dataclass
1601 class GroupDict:
1602 id: int
1603 users: Dict[str, User]
1604 a = User('Alice', 1)
1605 b = User('Bob', 2)
1606 gl = GroupList(0, [a, b])
1607 gt = GroupTuple(0, (a, b))
1608 gd = GroupDict(0, {'first': a, 'second': b})
1609 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1610 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1611 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1612
Windson yangbe372d72019-04-23 02:45:34 +08001613 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001614 @dataclass
1615 class Child:
1616 d: object
1617
1618 @dataclass
1619 class Parent:
1620 child: Child
1621
1622 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1623 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1624
1625 def test_helper_astuple_factory(self):
1626 @dataclass
1627 class C:
1628 x: int
1629 y: int
1630 NT = namedtuple('NT', 'x y')
1631 def nt(lst):
1632 return NT(*lst)
1633 c = C(1, 2)
1634 t = astuple(c, tuple_factory=nt)
1635 self.assertEqual(t, NT(1, 2))
1636 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1637 c.x = 42
1638 t = astuple(c, tuple_factory=nt)
1639 self.assertEqual(t, NT(42, 2))
1640 self.assertIs(type(t), NT)
1641
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001642 def test_helper_astuple_namedtuple(self):
1643 T = namedtuple('T', 'a b c')
1644 @dataclass
1645 class C:
1646 x: str
1647 y: T
1648 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1649
1650 t = astuple(c)
1651 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1652
1653 # Now, using a tuple_factory. list is convenient here.
1654 t = astuple(c, tuple_factory=list)
1655 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1656
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001657 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001658 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001659 }
1660
1661 # Create the class.
1662 cls = type('C', (), cls_dict)
1663
1664 # Make it a dataclass.
1665 cls1 = dataclass(cls)
1666
1667 self.assertEqual(cls1, cls)
1668 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1669
1670 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001671 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001672 'y': field(default=5),
1673 }
1674
1675 # Create the class.
1676 cls = type('C', (), cls_dict)
1677
1678 # Make it a dataclass.
1679 cls1 = dataclass(cls)
1680
1681 self.assertEqual(cls1, cls)
1682 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1683
1684 def test_init_in_order(self):
1685 @dataclass
1686 class C:
1687 a: int
1688 b: int = field()
1689 c: list = field(default_factory=list, init=False)
1690 d: list = field(default_factory=list)
1691 e: int = field(default=4, init=False)
1692 f: int = 4
1693
1694 calls = []
1695 def setattr(self, name, value):
1696 calls.append((name, value))
1697
1698 C.__setattr__ = setattr
1699 c = C(0, 1)
1700 self.assertEqual(('a', 0), calls[0])
1701 self.assertEqual(('b', 1), calls[1])
1702 self.assertEqual(('c', []), calls[2])
1703 self.assertEqual(('d', []), calls[3])
1704 self.assertNotIn(('e', 4), calls)
1705 self.assertEqual(('f', 4), calls[4])
1706
1707 def test_items_in_dicts(self):
1708 @dataclass
1709 class C:
1710 a: int
1711 b: list = field(default_factory=list, init=False)
1712 c: list = field(default_factory=list)
1713 d: int = field(default=4, init=False)
1714 e: int = 0
1715
1716 c = C(0)
1717 # Class dict
1718 self.assertNotIn('a', C.__dict__)
1719 self.assertNotIn('b', C.__dict__)
1720 self.assertNotIn('c', C.__dict__)
1721 self.assertIn('d', C.__dict__)
1722 self.assertEqual(C.d, 4)
1723 self.assertIn('e', C.__dict__)
1724 self.assertEqual(C.e, 0)
1725 # Instance dict
1726 self.assertIn('a', c.__dict__)
1727 self.assertEqual(c.a, 0)
1728 self.assertIn('b', c.__dict__)
1729 self.assertEqual(c.b, [])
1730 self.assertIn('c', c.__dict__)
1731 self.assertEqual(c.c, [])
1732 self.assertNotIn('d', c.__dict__)
1733 self.assertIn('e', c.__dict__)
1734 self.assertEqual(c.e, 0)
1735
1736 def test_alternate_classmethod_constructor(self):
1737 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001738 # alternate constructor. This is mostly an example to show
1739 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001740 @dataclass
1741 class C:
1742 x: int
1743 @classmethod
1744 def from_file(cls, filename):
1745 # In a real example, create a new instance
1746 # and populate 'x' from contents of a file.
1747 value_in_file = 20
1748 return cls(value_in_file)
1749
1750 self.assertEqual(C.from_file('filename').x, 20)
1751
1752 def test_field_metadata_default(self):
1753 # Make sure the default metadata is read-only and of
1754 # zero length.
1755 @dataclass
1756 class C:
1757 i: int
1758
1759 self.assertFalse(fields(C)[0].metadata)
1760 self.assertEqual(len(fields(C)[0].metadata), 0)
1761 with self.assertRaisesRegex(TypeError,
1762 'does not support item assignment'):
1763 fields(C)[0].metadata['test'] = 3
1764
1765 def test_field_metadata_mapping(self):
1766 # Make sure only a mapping can be passed as metadata
1767 # zero length.
1768 with self.assertRaises(TypeError):
1769 @dataclass
1770 class C:
1771 i: int = field(metadata=0)
1772
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001773 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001774 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001775 @dataclass
1776 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001777 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001778 self.assertFalse(fields(C)[0].metadata)
1779 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001780 # Update should work (see bpo-35960).
1781 d['foo'] = 1
1782 self.assertEqual(len(fields(C)[0].metadata), 1)
1783 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001784 with self.assertRaisesRegex(TypeError,
1785 'does not support item assignment'):
1786 fields(C)[0].metadata['test'] = 3
1787
1788 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001789 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001790 @dataclass
1791 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001792 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001793 self.assertEqual(len(fields(C)[0].metadata), 3)
1794 self.assertEqual(fields(C)[0].metadata['test'], 10)
1795 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1796 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001797 # Update should work.
1798 d['foo'] = 1
1799 self.assertEqual(len(fields(C)[0].metadata), 4)
1800 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001801 with self.assertRaises(KeyError):
1802 # Non-existent key.
1803 fields(C)[0].metadata['baz']
1804 with self.assertRaisesRegex(TypeError,
1805 'does not support item assignment'):
1806 fields(C)[0].metadata['test'] = 3
1807
1808 def test_field_metadata_custom_mapping(self):
1809 # Try a custom mapping.
1810 class SimpleNameSpace:
1811 def __init__(self, **kw):
1812 self.__dict__.update(kw)
1813
1814 def __getitem__(self, item):
1815 if item == 'xyzzy':
1816 return 'plugh'
1817 return getattr(self, item)
1818
1819 def __len__(self):
1820 return self.__dict__.__len__()
1821
1822 @dataclass
1823 class C:
1824 i: int = field(metadata=SimpleNameSpace(a=10))
1825
1826 self.assertEqual(len(fields(C)[0].metadata), 1)
1827 self.assertEqual(fields(C)[0].metadata['a'], 10)
1828 with self.assertRaises(AttributeError):
1829 fields(C)[0].metadata['b']
1830 # Make sure we're still talking to our custom mapping.
1831 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1832
1833 def test_generic_dataclasses(self):
1834 T = TypeVar('T')
1835
1836 @dataclass
1837 class LabeledBox(Generic[T]):
1838 content: T
1839 label: str = '<unknown>'
1840
1841 box = LabeledBox(42)
1842 self.assertEqual(box.content, 42)
1843 self.assertEqual(box.label, '<unknown>')
1844
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001845 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001846 Alias = List[LabeledBox[int]]
1847
1848 def test_generic_extending(self):
1849 S = TypeVar('S')
1850 T = TypeVar('T')
1851
1852 @dataclass
1853 class Base(Generic[T, S]):
1854 x: T
1855 y: S
1856
1857 @dataclass
1858 class DataDerived(Base[int, T]):
1859 new_field: str
1860 Alias = DataDerived[str]
1861 c = Alias(0, 'test1', 'test2')
1862 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1863
1864 class NonDataDerived(Base[int, T]):
1865 def new_method(self):
1866 return self.y
1867 Alias = NonDataDerived[float]
1868 c = Alias(10, 1.0)
1869 self.assertEqual(c.new_method(), 1.0)
1870
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001871 def test_generic_dynamic(self):
1872 T = TypeVar('T')
1873
1874 @dataclass
1875 class Parent(Generic[T]):
1876 x: T
1877 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1878 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1879 self.assertIs(Child[int](1, 2).z, None)
1880 self.assertEqual(Child[int](1, 2, 3).z, 3)
1881 self.assertEqual(Child[int](1, 2, 3).other, 42)
1882 # Check that type aliases work correctly.
1883 Alias = Child[T]
1884 self.assertEqual(Alias[int](1, 2).x, 1)
1885 # Check MRO resolution.
1886 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1887
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001888 def test_dataclassses_pickleable(self):
1889 global P, Q, R
1890 @dataclass
1891 class P:
1892 x: int
1893 y: int = 0
1894 @dataclass
1895 class Q:
1896 x: int
1897 y: int = field(default=0, init=False)
1898 @dataclass
1899 class R:
1900 x: int
1901 y: List[int] = field(default_factory=list)
1902 q = Q(1)
1903 q.y = 2
1904 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1905 for sample in samples:
1906 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1907 with self.subTest(sample=sample, proto=proto):
1908 new_sample = pickle.loads(pickle.dumps(sample, proto))
1909 self.assertEqual(sample.x, new_sample.x)
1910 self.assertEqual(sample.y, new_sample.y)
1911 self.assertIsNot(sample, new_sample)
1912 new_sample.x = 42
1913 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1914 self.assertEqual(new_sample.x, another_new_sample.x)
1915 self.assertEqual(sample.y, another_new_sample.y)
1916
Eric V. Smithea8fc522018-01-27 19:07:40 -05001917
class TestFieldNoAnnotation(unittest.TestCase):
    # A name bound to field() must be annotated in that same class body;
    # an annotation inherited from a base class does not count.

    def _expect_missing_annotation(self):
        # Context manager for the TypeError dataclass() raises for 'f'.
        return self.assertRaisesRegex(
            TypeError, "'f' is a field but has no type annotation")

    def test_field_without_annotation(self):
        with self._expect_missing_annotation():
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: B's annotation must not be picked up for C.
        with self._expect_missing_annotation():
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, but the base is a plain (non-dataclass) class.
        class B:
            f: int

        with self._expect_missing_annotation():
            @dataclass
            class C(B):
                f = field()
1951
1952
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001953class TestDocString(unittest.TestCase):
1954 def assertDocStrEqual(self, a, b):
1955 # Because 3.6 and 3.7 differ in how inspect.signature work
1956 # (see bpo #32108), for the time being just compare them with
1957 # whitespace stripped.
1958 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1959
1960 def test_existing_docstring_not_overridden(self):
1961 @dataclass
1962 class C:
1963 """Lorem ipsum"""
1964 x: int
1965
1966 self.assertEqual(C.__doc__, "Lorem ipsum")
1967
1968 def test_docstring_no_fields(self):
1969 @dataclass
1970 class C:
1971 pass
1972
1973 self.assertDocStrEqual(C.__doc__, "C()")
1974
1975 def test_docstring_one_field(self):
1976 @dataclass
1977 class C:
1978 x: int
1979
1980 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1981
1982 def test_docstring_two_fields(self):
1983 @dataclass
1984 class C:
1985 x: int
1986 y: int
1987
1988 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1989
1990 def test_docstring_three_fields(self):
1991 @dataclass
1992 class C:
1993 x: int
1994 y: int
1995 z: str
1996
1997 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1998
1999 def test_docstring_one_field_with_default(self):
2000 @dataclass
2001 class C:
2002 x: int = 3
2003
2004 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2005
2006 def test_docstring_one_field_with_default_none(self):
2007 @dataclass
2008 class C:
2009 x: Union[int, type(None)] = None
2010
2011 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2012
2013 def test_docstring_list_field(self):
2014 @dataclass
2015 class C:
2016 x: List[int]
2017
2018 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2019
2020 def test_docstring_list_field_with_default_factory(self):
2021 @dataclass
2022 class C:
2023 x: List[int] = field(default_factory=list)
2024
2025 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2026
2027 def test_docstring_deque_field(self):
2028 @dataclass
2029 class C:
2030 x: deque
2031
2032 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2033
2034 def test_docstring_deque_field_with_default_factory(self):
2035 @dataclass
2036 class C:
2037 x: deque = field(default_factory=deque)
2038
2039 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2040
2041
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise. A base class __init__ does
        # not stop dataclass() from adding __init__ to the subclass, and
        # the generated __init__ does not call B.__init__ ('z' unset).
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False nothing is generated, so B.__init__ runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied: previously the bare
        # expression `dataclass(init=False)` left the classes undecorated,
        # so this test never exercised dataclass() at all.

        # With init=False and no user __init__, no __init__ is added; the
        # instance falls back to the class-level default.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # A user-defined __init__ is kept.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2106
2107
class TestRepr(unittest.TestCase):
    """Tests for the generated __repr__ and the repr= parameter."""

    def test_repr(self):
        # Generated reprs use the class __qualname__, which for classes
        # defined inside this method carries the '<locals>' marker.
        qualprefix = 'TestRepr.test_repr.<locals>.'

        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        self.assertEqual(repr(C(4)), qualprefix + 'C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20

        self.assertEqual(repr(D()), qualprefix + 'D(x=20, y=10)')

        # Nested dataclasses, with and without fields.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int

            @dataclass
            class E:
                pass

        self.assertEqual(repr(C.D(0)), qualprefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), qualprefix + 'C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ of our own: object's default repr.
        @dataclass(repr=False)
        class C:
            x: int
        default_prefix = f'{__name__}.TestRepr.test_no_repr.<locals>.C object at'
        self.assertIn(default_prefix, repr(C(3)))

        # repr=False with a __repr__ of our own: ours is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ defined in the class wins whatever repr= says.
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorate
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2177
2178
class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and the eq= parameter."""

    def test_no_eq(self):
        # eq=False and no __eq__ of our own: comparison falls back to
        # object identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with an __eq__ of our own: ours is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ defined in the class wins whatever eq= says.  Each
        # variant compares equal only to its own sentinel value.
        variants = ((dataclass, 3),
                    (dataclass(eq=True), 4),
                    (dataclass(eq=False), 5))
        for decorate, sentinel in variants:
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == sentinel
            self.assertEqual(C(1), sentinel)
            self.assertNotEqual(C(1), 1)
2224
2225
class TestOrdering(unittest.TestCase):
    """Tests for ordering methods and the order= parameter."""

    def test_functools_total_ordering(self):
        # functools.total_ordering can derive the remaining comparisons
        # from a user-supplied __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the assertions below only pass
                # if this method really is the one being used.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # order=False adds none of the ordering methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # With a user-defined __lt__, the other three are still not added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace an ordering method the class
        # already defines, and the error suggests functools.total_ordering.
        def overwrite_error(name):
            return f'Cannot overwrite attribute {name}.*using functools.total_ordering'

        with self.assertRaisesRegex(TypeError, overwrite_error('__lt__')):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError, overwrite_error('__le__')):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError, overwrite_error('__gt__')):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError, overwrite_error('__ge__')):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2301
class TestHash(unittest.TestCase):
    """Tests for when and how @dataclass sets __hash__."""

    def test_unsafe_hash(self):
        # unsafe_hash=True: the generated hash equals the hash of the
        # tuple of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        # Table-driven check of what ends up in C.__dict__['__hash__'] for
        # every combination of unsafe_hash/eq/frozen, both with and
        # without a user-defined __hash__.

        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a class with the given decorator arguments (and a
            # user __hash__ if with_hash) and check it against result:
            #   'fn'        -- a non-None __hash__ in the class dict
            #   ''          -- no __hash__ in the class dict (only
            #                  checked when with_hash is false)
            #   'none'      -- __hash__ in the class dict, set to None
            #   'exception' -- decorating raises TypeError
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash,  eq,    frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                  (False,        False, False,  '',                  ''),
                  (False,        False, True,   '',                  ''),
                  (False,        True,  False,  'none',              ''),
                  (False,        True,  True,   'fn',                ''),
                  (True,         False, False,  'fn',                'exception'),
                  (True,         False, True,   'fn',                'exception'),
                  (True,         True,  False,  'fn',                'exception'),
                  (True,         True,  True,   'fn',                'exception'),
                  ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            #  make sure the data-driven table in the decorator
            #  handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        #  and set to None.  This is normal Python behavior, not
        #  related to dataclasses.  Make sure we don't interfere with
        #  that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        #  unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        #  specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # With no fields, the generated hash equals hash(()).
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With one field, the generated hash equals the hash of a
        # 1-tuple of its value.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        #  make sure that if the @dataclass parameter name is changed
        #  or the non-default hashing behavior changes, the default
        #  hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        #  specify any value in the decorator).
        # expected is one of:
        #   'unhashable' -- hash() raises TypeError
        #   'base'       -- Base.__hash__ is used (returns 301)
        #   'object'     -- object.__hash__ is used
        #   'tuple'      -- hash equals the hash of the field tuple
        for frozen, eq,    base,   expected       in [
                (None,  None,  object, 'unhashable'),
                (None,  None,  Base,   'unhashable'),
                (None,  False, object, 'object'),
                (None,  False, Base,   'base'),
                (None,  True,  object, 'unhashable'),
                (None,  True,  Base,   'unhashable'),
                (False, None,  object, 'unhashable'),
                (False, None,  Base,   'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base,   'base'),
                (False, True,  object, 'unhashable'),
                (False, True,  Base,   'unhashable'),
                (True,  None,  object, 'tuple'),
                (True,  None,  Base,   'tuple'),
                (True,  False, object, 'object'),
                (True,  False, Base,   'base'),
                (True,  True,  object, 'tuple'),
                (True,  True,  Base,   'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # The value of object's default hash isn't something
                    #  this test can predict, so just check that the
                    #  function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2520
Eric V. Smithea8fc522018-01-27 19:07:40 -05002521
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True instances and frozen/non-frozen inheritance."""

    @staticmethod
    def _maybe_intermediate(base, intermediate_class):
        # Return base itself or, when intermediate_class is true, a plain
        # (non-dataclass) subclass of it to sit between base and the
        # class under test.
        if not intermediate_class:
            return base

        class I(base):
            pass
        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        # Assignment is rejected and leaves the value untouched.
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        # Both the inherited field and the new one are frozen.
        for name, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, name, value)
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    # The next three tests run both with a plain (non-dataclass) class
    # between the base and the derived dataclass, and without one.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                I = self._maybe_intermediate(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                I = self._maybe_intermediate(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                I = self._maybe_intermediate(C, intermediate_class)

                @dataclass(frozen=True)
                class D(I):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        # A plain subclass of a frozen dataclass may gain new attributes...
        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # ...but the inherited frozen fields still can't be changed.
        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, name, 5)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True needs __setattr__ and __delattr__, so a class that
        # already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen=True a user __setattr__ is kept, and it is the
        # one invoked when __init__ assigns x.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing works for an immutable field value...
        hash(C(3))

        # ...and raises for an unhashable (mutable) one.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2670
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002671
class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__ themselves."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: the slot's member descriptor on the class must
        # not be mistaken for a default value, so x is still required.
        missing_x = r"__init__\(\) missing 1 required positional argument: 'x'"
        with self.assertRaisesRegex(TypeError, missing_x):
            C()

        # x can be set at construction and reassigned afterwards.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # Only the declared slot is assignable.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # Derived declares no __slots__ of its own, so new attributes can
        # be added to its instances.
        d.z = 10
2713
class TestDescriptors(unittest.TestCase):
    """Tests for field defaults that implement __set_name__ (PEP 487)."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that records (with an 'x' suffix) the attribute
        # name it gets bound to.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is None:
                    return self
                return 1

        # Ordinary descriptor behavior: the class statement itself calls
        # __set_name__, without any help from dataclass.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # The interesting case: the descriptor is the default of an
        # init=False field (with init=True it would be overwritten
        # anyway).  Here it is wrapped in field(), yet __set_name__ must
        # still reach it.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors too.
        # This default defines __set_name__ but no __get__.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.  A __set_name__ attached only to the default
        # instance (not its type) must not be called.
        class Dummy:
            pass

        instance = Dummy()
        instance.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=instance, init=False)

        self.assertEqual(instance.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.  A __set_name__ defined on the default's type is
        # called exactly once.
        class Dummy:
            pass
        Dummy.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=Dummy(), init=False)

        self.assertEqual(Dummy.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002784
Eric V. Smith7389fd92018-03-19 21:07:51 -04002785
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002786class TestStringAnnotations(unittest.TestCase):
2787 def test_classvar(self):
2788 # Some expressions recognized as ClassVar really aren't. But
2789 # if you're using string annotations, it's not an exact
2790 # science.
2791 # These tests assume that both "import typing" and "from
2792 # typing import *" have been run in this file.
2793 for typestr in ('ClassVar[int]',
2794 'ClassVar [int]'
2795 ' ClassVar [int]',
2796 'ClassVar',
2797 ' ClassVar ',
2798 'typing.ClassVar[int]',
2799 'typing.ClassVar[str]',
2800 ' typing.ClassVar[str]',
2801 'typing .ClassVar[str]',
2802 'typing. ClassVar[str]',
2803 'typing.ClassVar [str]',
2804 'typing.ClassVar [ str]',
2805
2806 # Not syntactically valid, but these will
2807 # be treated as ClassVars.
2808 'typing.ClassVar.[int]',
2809 'typing.ClassVar+',
2810 ):
2811 with self.subTest(typestr=typestr):
2812 @dataclass
2813 class C:
2814 x: typestr
2815
2816 # x is a ClassVar, so C() takes no args.
2817 C()
2818
2819 # And it won't appear in the class's dict because it doesn't
2820 # have a default.
2821 self.assertNotIn('x', C.__dict__)
2822
2823 def test_isnt_classvar(self):
2824 for typestr in ('CV',
2825 't.ClassVar',
2826 't.ClassVar[int]',
2827 'typing..ClassVar[int]',
2828 'Classvar',
2829 'Classvar[int]',
2830 'typing.ClassVarx[int]',
2831 'typong.ClassVar[int]',
2832 'dataclasses.ClassVar[int]',
2833 'typingxClassVar[str]',
2834 ):
2835 with self.subTest(typestr=typestr):
2836 @dataclass
2837 class C:
2838 x: typestr
2839
2840 # x is not a ClassVar, so C() takes one arg.
2841 self.assertEqual(C(10).x, 10)
2842
2843 def test_initvar(self):
2844 # These tests assume that both "import dataclasses" and "from
2845 # dataclasses import *" have been run in this file.
2846 for typestr in ('InitVar[int]',
2847 'InitVar [int]'
2848 ' InitVar [int]',
2849 'InitVar',
2850 ' InitVar ',
2851 'dataclasses.InitVar[int]',
2852 'dataclasses.InitVar[str]',
2853 ' dataclasses.InitVar[str]',
2854 'dataclasses .InitVar[str]',
2855 'dataclasses. InitVar[str]',
2856 'dataclasses.InitVar [str]',
2857 'dataclasses.InitVar [ str]',
2858
2859 # Not syntactically valid, but these will
2860 # be treated as InitVars.
2861 'dataclasses.InitVar.[int]',
2862 'dataclasses.InitVar+',
2863 ):
2864 with self.subTest(typestr=typestr):
2865 @dataclass
2866 class C:
2867 x: typestr
2868
2869 # x is an InitVar, so doesn't create a member.
2870 with self.assertRaisesRegex(AttributeError,
2871 "object has no attribute 'x'"):
2872 C(1).x
2873
2874 def test_isnt_initvar(self):
2875 for typestr in ('IV',
2876 'dc.InitVar',
2877 'xdataclasses.xInitVar',
2878 'typing.xInitVar[int]',
2879 ):
2880 with self.subTest(typestr=typestr):
2881 @dataclass
2882 class C:
2883 x: typestr
2884
2885 # x is not an InitVar, so there will be a member x.
2886 self.assertEqual(C(10).x, 10)
2887
2888 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002889 from test import dataclass_module_1
2890 from test import dataclass_module_1_str
2891 from test import dataclass_module_2
2892 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002893
2894 for m in (dataclass_module_1, dataclass_module_1_str,
2895 dataclass_module_2, dataclass_module_2_str,
2896 ):
2897 with self.subTest(m=m):
2898 # There's a difference in how the ClassVars are
2899 # interpreted when using string annotations or
2900 # not. See the imported modules for details.
2901 if m.USING_STRINGS:
2902 c = m.CV(10)
2903 else:
2904 c = m.CV()
2905 self.assertEqual(c.cv0, 20)
2906
2907
2908 # There's a difference in how the InitVars are
2909 # interpreted when using string annotations or
2910 # not. See the imported modules for details.
2911 c = m.IV(0, 1, 2, 3, 4)
2912
2913 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2914 with self.subTest(field_name=field_name):
2915 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2916 # Since field_name is an InitVar, it's
2917 # not an instance field.
2918 getattr(c, field_name)
2919
2920 if m.USING_STRINGS:
2921 # iv4 is interpreted as a normal field.
2922 self.assertIn('not_iv4', c.__dict__)
2923 self.assertEqual(c.not_iv4, 4)
2924 else:
2925 # iv4 is interpreted as an InitVar, so it
2926 # won't exist on the instance.
2927 self.assertNotIn('not_iv4', c.__dict__)
2928
2929
class TestMakeDataclass(unittest.TestCase):
    """Tests for building dataclasses dynamically with make_dataclass()."""

    def test_simple(self):
        specs = [('x', int),
                 ('y', int, field(default=5))]
        extra = {'add_one': lambda self: self.x + 1}
        C = make_dataclass('C', specs, namespace=extra)

        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # The namespace mapping passed in must be left untouched.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C', [('x', int)], bases=(Base1, Base2))
        c = C(2)
        for cls in (C, Base1, Base2):
            self.assertIsInstance(c, cls)

    def test_base_dataclass(self):
        # Fields inherited from a dataclass base come first in __init__.
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C', [('y', int)], bases=(Base1, Base2))

        # A single argument fills the inherited x and leaves y missing.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        for cls in (C, Base1, Base2):
            self.assertIsInstance(c, cls)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        # An InitVar reaches __post_init__ but is not stored as a field.
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int), ('y', InitVar[int])],
                           namespace={'__post_init__': post_init})
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        # ClassVar entries become class attributes, not fields; the value
        # may be given directly or through field().
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual((C.y, C.z), (10, 20))

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))],
                           init=False)
        # init=False suppresses __init__, but a __repr__ is still added.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C', [], xxinit=False)

    def test_no_types(self):
        # Bare names get the string annotation 'typing.Any'; explicit
        # (name, type) pairs keep their type.
        any_ann = 'typing.Any'

        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__,
                         {'x': any_ann, 'y': any_ann, 'z': any_ann})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__,
                         {'x': any_ann, 'y': int, 'z': any_ann})

    def test_invalid_type_specification(self):
        # Tuples of the wrong length are rejected...
        for bad_field in [(), (1, 2, 3, 4)]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # ...and so are objects that have no len() at all.
        for bad_field in [float, lambda x: x]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        # Keywords are rejected wherever they appear in the field list.
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                for spec in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                        make_dataclass('C', spec)

    def test_non_identifier_field_names(self):
        # Non-identifiers are rejected wherever they appear in the list.
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                for spec in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                        make_dataclass('C', spec)

    def test_underscore_field_names(self):
        # Unlike namedtuple, dataclasses accept field names that start or
        # end with an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # Odd class names are not rejected, since types.new_class accepts
        # them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3093
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The test_recursive_repr* methods at the end exercise the generated
    __repr__ on self-referencing instances rather than replace().  Their
    expected strings embed this class's qualified name ("TestReplace"),
    so they must be updated if they are ever moved to another TestCase.
    """

    def test(self):
        # Replacing one field copies the remaining fields unchanged.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields may not be passed to replace(), whether alone
        # or together with replaceable fields.  Each call gets its own
        # context manager: a call placed after one that raises inside the
        # same block would never run.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        # An unknown field name ends up as an unexpected keyword argument
        # to the generated __init__, which rejects it.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        # replace() needs a dataclass *instance*: neither the dataclass
        # itself nor an arbitrary object is accepted.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # An init=False field is not copied from the source instance:
        # the new object gets the field's default (10), not the 20
        # assigned above.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        # Replacing an ordinary field still works, and the ClassVar is
        # left untouched.
        c1 = replace(c, x=5)
        self.assertEqual(c1.x, 5)
        self.assertEqual(c1.y, 1000)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # y is only consumed by __post_init__ and never stored, so
        # replace() has no value to copy and requires it explicitly.
        # The parentheses are escaped so the pattern matches the literal
        # text "replace()" instead of "replace" plus an empty group.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3303
Eric V. Smith4e812962018-05-16 11:31:29 -04003304
if __name__ == '__main__':
    # Executed as a script rather than imported: hand control to
    # unittest's command-line runner, which collects the TestCase
    # classes defined in this module.
    unittest.main()