blob: 8f9fb2ce8c169c0b745051d18cf2cc68cd2c1507 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +03009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Yury Selivanovd219cc42019-12-09 09:54:20 -050013from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050015from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050016
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040017import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
18import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
19
Eric V. Smithf0db54a2017-12-04 16:58:55 -050020# Just any custom exception we can catch.
class CustomError(Exception):
    # A dedicated exception type, so tests can assert on exactly this
    # error and nothing else.
    pass
22
23class TestCase(unittest.TestCase):
24 def test_no_fields(self):
25 @dataclass
26 class C:
27 pass
28
29 o = C()
30 self.assertEqual(len(fields(C)), 0)
31
Eric V. Smith56970b82018-03-22 16:28:48 -040032 def test_no_fields_but_member_variable(self):
33 @dataclass
34 class C:
35 i = 0
36
37 o = C()
38 self.assertEqual(len(fields(C)), 0)
39
Eric V. Smithf0db54a2017-12-04 16:58:55 -050040 def test_one_field_no_default(self):
41 @dataclass
42 class C:
43 x: int
44
45 o = C(42)
46 self.assertEqual(o.x, 42)
47
48 def test_named_init_params(self):
49 @dataclass
50 class C:
51 x: int
52
53 o = C(x=32)
54 self.assertEqual(o.x, 32)
55
56 def test_two_fields_one_default(self):
57 @dataclass
58 class C:
59 x: int
60 y: int = 0
61
62 o = C(3)
63 self.assertEqual((o.x, o.y), (3, 0))
64
65 # Non-defaults following defaults.
66 with self.assertRaisesRegex(TypeError,
67 "non-default argument 'y' follows "
68 "default argument"):
69 @dataclass
70 class C:
71 x: int = 0
72 y: int
73
74 # A derived class adds a non-default field after a default one.
75 with self.assertRaisesRegex(TypeError,
76 "non-default argument 'y' follows "
77 "default argument"):
78 @dataclass
79 class B:
80 x: int = 0
81
82 @dataclass
83 class C(B):
84 y: int
85
86 # Override a base class field and add a default to
87 # a field which didn't use to have a default.
88 with self.assertRaisesRegex(TypeError,
89 "non-default argument 'y' follows "
90 "default argument"):
91 @dataclass
92 class B:
93 x: int
94 y: int
95
96 @dataclass
97 class C(B):
98 x: int = 0
99
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500100 def test_overwrite_hash(self):
101 # Test that declaring this class isn't an error. It should
102 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500103 @dataclass(frozen=True)
104 class C:
105 x: int
106 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500107 return 301
108 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500109
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500110 # Test that declaring this class isn't an error. It should
111 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500112 @dataclass(frozen=True)
113 class C:
114 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500115 def __eq__(self, other):
116 return False
117 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500118
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500119 # But this one should generate an exception, because with
120 # unsafe_hash=True, it's an error to have a __hash__ defined.
121 with self.assertRaisesRegex(TypeError,
122 'Cannot overwrite attribute __hash__'):
123 @dataclass(unsafe_hash=True)
124 class C:
125 def __hash__(self):
126 pass
127
128 # Creating this class should not generate an exception,
129 # because even though __hash__ exists before @dataclass is
130 # called, (due to __eq__ being defined), since it's None
131 # that's okay.
132 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500133 class C:
134 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500135 def __eq__(self):
136 pass
137 # The generated hash function works as we'd expect.
138 self.assertEqual(hash(C(10)), hash((10,)))
139
140 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400141 # __hash__ exists and is not None, which it would be if it
142 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500143 with self.assertRaisesRegex(TypeError,
144 'Cannot overwrite attribute __hash__'):
145 @dataclass(unsafe_hash=True)
146 class C:
147 x: int
148 def __eq__(self):
149 pass
150 def __hash__(self):
151 pass
152
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500153 def test_overwrite_fields_in_derived_class(self):
154 # Note that x from C1 replaces x in Base, but the order remains
155 # the same as defined in Base.
156 @dataclass
157 class Base:
158 x: Any = 15.0
159 y: int = 0
160
161 @dataclass
162 class C1(Base):
163 z: int = 10
164 x: int = 15
165
166 o = Base()
167 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
168
169 o = C1()
170 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
171
172 o = C1(x=5)
173 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
174
175 def test_field_named_self(self):
176 @dataclass
177 class C:
178 self: str
179 c=C('foo')
180 self.assertEqual(c.self, 'foo')
181
182 # Make sure the first parameter is not named 'self'.
183 sig = inspect.signature(C.__init__)
184 first = next(iter(sig.parameters))
185 self.assertNotEqual('self', first)
186
187 # But we do use 'self' if no field named self.
188 @dataclass
189 class C:
190 selfx: str
191
192 # Make sure the first parameter is named 'self'.
193 sig = inspect.signature(C.__init__)
194 first = next(iter(sig.parameters))
195 self.assertEqual('self', first)
196
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300197 def test_field_named_object(self):
198 @dataclass
199 class C:
200 object: str
201 c = C('foo')
202 self.assertEqual(c.object, 'foo')
203
204 def test_field_named_object_frozen(self):
205 @dataclass(frozen=True)
206 class C:
207 object: str
208 c = C('foo')
209 self.assertEqual(c.object, 'foo')
210
211 def test_field_named_like_builtin(self):
212 # Attribute names can shadow built-in names
213 # since code generation is used.
214 # Ensure that this is not happening.
215 exclusions = {'None', 'True', 'False'}
216 builtins_names = sorted(
217 b for b in builtins.__dict__.keys()
218 if not b.startswith('__') and b not in exclusions
219 )
220 attributes = [(name, str) for name in builtins_names]
221 C = make_dataclass('C', attributes)
222
223 c = C(*[name for name in builtins_names])
224
225 for name in builtins_names:
226 self.assertEqual(getattr(c, name), name)
227
228 def test_field_named_like_builtin_frozen(self):
229 # Attribute names can shadow built-in names
230 # since code generation is used.
231 # Ensure that this is not happening
232 # for frozen data classes.
233 exclusions = {'None', 'True', 'False'}
234 builtins_names = sorted(
235 b for b in builtins.__dict__.keys()
236 if not b.startswith('__') and b not in exclusions
237 )
238 attributes = [(name, str) for name in builtins_names]
239 C = make_dataclass('C', attributes, frozen=True)
240
241 c = C(*[name for name in builtins_names])
242
243 for name in builtins_names:
244 self.assertEqual(getattr(c, name), name)
245
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500246 def test_0_field_compare(self):
247 # Ensure that order=False is the default.
248 @dataclass
249 class C0:
250 pass
251
252 @dataclass(order=False)
253 class C1:
254 pass
255
256 for cls in [C0, C1]:
257 with self.subTest(cls=cls):
258 self.assertEqual(cls(), cls())
259 for idx, fn in enumerate([lambda a, b: a < b,
260 lambda a, b: a <= b,
261 lambda a, b: a > b,
262 lambda a, b: a >= b]):
263 with self.subTest(idx=idx):
264 with self.assertRaisesRegex(TypeError,
265 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
266 fn(cls(), cls())
267
268 @dataclass(order=True)
269 class C:
270 pass
271 self.assertLessEqual(C(), C())
272 self.assertGreaterEqual(C(), C())
273
274 def test_1_field_compare(self):
275 # Ensure that order=False is the default.
276 @dataclass
277 class C0:
278 x: int
279
280 @dataclass(order=False)
281 class C1:
282 x: int
283
284 for cls in [C0, C1]:
285 with self.subTest(cls=cls):
286 self.assertEqual(cls(1), cls(1))
287 self.assertNotEqual(cls(0), cls(1))
288 for idx, fn in enumerate([lambda a, b: a < b,
289 lambda a, b: a <= b,
290 lambda a, b: a > b,
291 lambda a, b: a >= b]):
292 with self.subTest(idx=idx):
293 with self.assertRaisesRegex(TypeError,
294 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
295 fn(cls(0), cls(0))
296
297 @dataclass(order=True)
298 class C:
299 x: int
300 self.assertLess(C(0), C(1))
301 self.assertLessEqual(C(0), C(1))
302 self.assertLessEqual(C(1), C(1))
303 self.assertGreater(C(1), C(0))
304 self.assertGreaterEqual(C(1), C(0))
305 self.assertGreaterEqual(C(1), C(1))
306
307 def test_simple_compare(self):
308 # Ensure that order=False is the default.
309 @dataclass
310 class C0:
311 x: int
312 y: int
313
314 @dataclass(order=False)
315 class C1:
316 x: int
317 y: int
318
319 for cls in [C0, C1]:
320 with self.subTest(cls=cls):
321 self.assertEqual(cls(0, 0), cls(0, 0))
322 self.assertEqual(cls(1, 2), cls(1, 2))
323 self.assertNotEqual(cls(1, 0), cls(0, 0))
324 self.assertNotEqual(cls(1, 0), cls(1, 1))
325 for idx, fn in enumerate([lambda a, b: a < b,
326 lambda a, b: a <= b,
327 lambda a, b: a > b,
328 lambda a, b: a >= b]):
329 with self.subTest(idx=idx):
330 with self.assertRaisesRegex(TypeError,
331 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
332 fn(cls(0, 0), cls(0, 0))
333
334 @dataclass(order=True)
335 class C:
336 x: int
337 y: int
338
339 for idx, fn in enumerate([lambda a, b: a == b,
340 lambda a, b: a <= b,
341 lambda a, b: a >= b]):
342 with self.subTest(idx=idx):
343 self.assertTrue(fn(C(0, 0), C(0, 0)))
344
345 for idx, fn in enumerate([lambda a, b: a < b,
346 lambda a, b: a <= b,
347 lambda a, b: a != b]):
348 with self.subTest(idx=idx):
349 self.assertTrue(fn(C(0, 0), C(0, 1)))
350 self.assertTrue(fn(C(0, 1), C(1, 0)))
351 self.assertTrue(fn(C(1, 0), C(1, 1)))
352
353 for idx, fn in enumerate([lambda a, b: a > b,
354 lambda a, b: a >= b,
355 lambda a, b: a != b]):
356 with self.subTest(idx=idx):
357 self.assertTrue(fn(C(0, 1), C(0, 0)))
358 self.assertTrue(fn(C(1, 0), C(0, 1)))
359 self.assertTrue(fn(C(1, 1), C(1, 0)))
360
361 def test_compare_subclasses(self):
362 # Comparisons fail for subclasses, even if no fields
363 # are added.
364 @dataclass
365 class B:
366 i: int
367
368 @dataclass
369 class C(B):
370 pass
371
372 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
373 (lambda a, b: a != b, True)]):
374 with self.subTest(idx=idx):
375 self.assertEqual(fn(B(0), C(0)), expected)
376
377 for idx, fn in enumerate([lambda a, b: a < b,
378 lambda a, b: a <= b,
379 lambda a, b: a > b,
380 lambda a, b: a >= b]):
381 with self.subTest(idx=idx):
382 with self.assertRaisesRegex(TypeError,
383 "not supported between instances of 'B' and 'C'"):
384 fn(B(0), C(0))
385
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500386 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500387 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500388 for (eq, order, result ) in [
389 (False, False, 'neither'),
390 (False, True, 'exception'),
391 (True, False, 'eq_only'),
392 (True, True, 'both'),
393 ]:
394 with self.subTest(eq=eq, order=order):
395 if result == 'exception':
396 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
397 @dataclass(eq=eq, order=order)
398 class C:
399 pass
400 else:
401 @dataclass(eq=eq, order=order)
402 class C:
403 pass
404
405 if result == 'neither':
406 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500407 self.assertNotIn('__lt__', C.__dict__)
408 self.assertNotIn('__le__', C.__dict__)
409 self.assertNotIn('__gt__', C.__dict__)
410 self.assertNotIn('__ge__', C.__dict__)
411 elif result == 'both':
412 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500413 self.assertIn('__lt__', C.__dict__)
414 self.assertIn('__le__', C.__dict__)
415 self.assertIn('__gt__', C.__dict__)
416 self.assertIn('__ge__', C.__dict__)
417 elif result == 'eq_only':
418 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500419 self.assertNotIn('__lt__', C.__dict__)
420 self.assertNotIn('__le__', C.__dict__)
421 self.assertNotIn('__gt__', C.__dict__)
422 self.assertNotIn('__ge__', C.__dict__)
423 else:
424 assert False, f'unknown result {result!r}'
425
426 def test_field_no_default(self):
427 @dataclass
428 class C:
429 x: int = field()
430
431 self.assertEqual(C(5).x, 5)
432
433 with self.assertRaisesRegex(TypeError,
434 r"__init__\(\) missing 1 required "
435 "positional argument: 'x'"):
436 C()
437
438 def test_field_default(self):
439 default = object()
440 @dataclass
441 class C:
442 x: object = field(default=default)
443
444 self.assertIs(C.x, default)
445 c = C(10)
446 self.assertEqual(c.x, 10)
447
448 # If we delete the instance attribute, we should then see the
449 # class attribute.
450 del c.x
451 self.assertIs(c.x, default)
452
453 self.assertIs(C().x, default)
454
455 def test_not_in_repr(self):
456 @dataclass
457 class C:
458 x: int = field(repr=False)
459 with self.assertRaises(TypeError):
460 C()
461 c = C(10)
462 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
463
464 @dataclass
465 class C:
466 x: int = field(repr=False)
467 y: int
468 c = C(10, 20)
469 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
470
471 def test_not_in_compare(self):
472 @dataclass
473 class C:
474 x: int = 0
475 y: int = field(compare=False, default=4)
476
477 self.assertEqual(C(), C(0, 20))
478 self.assertEqual(C(1, 10), C(1, 20))
479 self.assertNotEqual(C(3), C(4, 10))
480 self.assertNotEqual(C(3, 10), C(4, 10))
481
482 def test_hash_field_rules(self):
483 # Test all 6 cases of:
484 # hash=True/False/None
485 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500486 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500487 (True, False, 'field' ),
488 (True, True, 'field' ),
489 (False, False, 'absent'),
490 (False, True, 'absent'),
491 (None, False, 'absent'),
492 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500493 ]:
494 with self.subTest(hash=hash_, compare=compare):
495 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500496 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500497 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500498
499 if result == 'field':
500 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500501 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500502 elif result == 'absent':
503 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500504 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500505 else:
506 assert False, f'unknown result {result!r}'
507
508 def test_init_false_no_default(self):
509 # If init=False and no default value, then the field won't be
510 # present in the instance.
511 @dataclass
512 class C:
513 x: int = field(init=False)
514
515 self.assertNotIn('x', C().__dict__)
516
517 @dataclass
518 class C:
519 x: int
520 y: int = 0
521 z: int = field(init=False)
522 t: int = 10
523
524 self.assertNotIn('z', C(0).__dict__)
525 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
526
527 def test_class_marker(self):
528 @dataclass
529 class C:
530 x: int
531 y: str = field(init=False, default=None)
532 z: str = field(repr=False)
533
534 the_fields = fields(C)
535 # the_fields is a tuple of 3 items, each value
536 # is in __annotations__.
537 self.assertIsInstance(the_fields, tuple)
538 for f in the_fields:
539 self.assertIs(type(f), Field)
540 self.assertIn(f.name, C.__annotations__)
541
542 self.assertEqual(len(the_fields), 3)
543
544 self.assertEqual(the_fields[0].name, 'x')
545 self.assertEqual(the_fields[0].type, int)
546 self.assertFalse(hasattr(C, 'x'))
547 self.assertTrue (the_fields[0].init)
548 self.assertTrue (the_fields[0].repr)
549 self.assertEqual(the_fields[1].name, 'y')
550 self.assertEqual(the_fields[1].type, str)
551 self.assertIsNone(getattr(C, 'y'))
552 self.assertFalse(the_fields[1].init)
553 self.assertTrue (the_fields[1].repr)
554 self.assertEqual(the_fields[2].name, 'z')
555 self.assertEqual(the_fields[2].type, str)
556 self.assertFalse(hasattr(C, 'z'))
557 self.assertTrue (the_fields[2].init)
558 self.assertFalse(the_fields[2].repr)
559
560 def test_field_order(self):
561 @dataclass
562 class B:
563 a: str = 'B:a'
564 b: str = 'B:b'
565 c: str = 'B:c'
566
567 @dataclass
568 class C(B):
569 b: str = 'C:b'
570
571 self.assertEqual([(f.name, f.default) for f in fields(C)],
572 [('a', 'B:a'),
573 ('b', 'C:b'),
574 ('c', 'B:c')])
575
576 @dataclass
577 class D(B):
578 c: str = 'D:c'
579
580 self.assertEqual([(f.name, f.default) for f in fields(D)],
581 [('a', 'B:a'),
582 ('b', 'B:b'),
583 ('c', 'D:c')])
584
585 @dataclass
586 class E(D):
587 a: str = 'E:a'
588 d: str = 'E:d'
589
590 self.assertEqual([(f.name, f.default) for f in fields(E)],
591 [('a', 'E:a'),
592 ('b', 'B:b'),
593 ('c', 'D:c'),
594 ('d', 'E:d')])
595
596 def test_class_attrs(self):
597 # We only have a class attribute if a default value is
598 # specified, either directly or via a field with a default.
599 default = object()
600 @dataclass
601 class C:
602 x: int
603 y: int = field(repr=False)
604 z: object = default
605 t: int = field(default=100)
606
607 self.assertFalse(hasattr(C, 'x'))
608 self.assertFalse(hasattr(C, 'y'))
609 self.assertIs (C.z, default)
610 self.assertEqual(C.t, 100)
611
612 def test_disallowed_mutable_defaults(self):
613 # For the known types, don't allow mutable default values.
614 for typ, empty, non_empty in [(list, [], [1]),
615 (dict, {}, {0:1}),
616 (set, set(), set([1])),
617 ]:
618 with self.subTest(typ=typ):
619 # Can't use a zero-length value.
620 with self.assertRaisesRegex(ValueError,
621 f'mutable default {typ} for field '
622 'x is not allowed'):
623 @dataclass
624 class Point:
625 x: typ = empty
626
627
628 # Nor a non-zero-length value
629 with self.assertRaisesRegex(ValueError,
630 f'mutable default {typ} for field '
631 'y is not allowed'):
632 @dataclass
633 class Point:
634 y: typ = non_empty
635
636 # Check subtypes also fail.
637 class Subclass(typ): pass
638
639 with self.assertRaisesRegex(ValueError,
640 f"mutable default .*Subclass'>"
641 ' for field z is not allowed'
642 ):
643 @dataclass
644 class Point:
645 z: typ = Subclass()
646
647 # Because this is a ClassVar, it can be mutable.
648 @dataclass
649 class C:
650 z: ClassVar[typ] = typ()
651
652 # Because this is a ClassVar, it can be mutable.
653 @dataclass
654 class C:
655 x: ClassVar[typ] = Subclass()
656
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500657 def test_deliberately_mutable_defaults(self):
658 # If a mutable default isn't in the known list of
659 # (list, dict, set), then it's okay.
660 class Mutable:
661 def __init__(self):
662 self.l = []
663
664 @dataclass
665 class C:
666 x: Mutable
667
668 # These 2 instances will share this value of x.
669 lst = Mutable()
670 o1 = C(lst)
671 o2 = C(lst)
672 self.assertEqual(o1, o2)
673 o1.x.l.extend([1, 2])
674 self.assertEqual(o1, o2)
675 self.assertEqual(o1.x.l, [1, 2])
676 self.assertIs(o1.x, o2.x)
677
678 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400679 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500680 @dataclass()
681 class C:
682 x: int
683
684 self.assertEqual(C(42).x, 42)
685
686 def test_not_tuple(self):
687 # Make sure we can't be compared to a tuple.
688 @dataclass
689 class Point:
690 x: int
691 y: int
692 self.assertNotEqual(Point(1, 2), (1, 2))
693
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400694 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500695 @dataclass
696 class C:
697 x: int
698 y: int
699 self.assertNotEqual(Point(1, 3), C(1, 3))
700
Windson yangbe372d72019-04-23 02:45:34 +0800701 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500702 # Test that some of the problems with namedtuple don't happen
703 # here.
704 @dataclass
705 class Point3D:
706 x: int
707 y: int
708 z: int
709
710 @dataclass
711 class Date:
712 year: int
713 month: int
714 day: int
715
716 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
717 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
718
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400719 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200720 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500721 x, y, z = Point3D(4, 5, 6)
722
Eric V. Smith7c99e932018-01-28 19:18:55 -0500723 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500724 # equal.
725 @dataclass
726 class Point3Dv1:
727 x: int = 0
728 y: int = 0
729 z: int = 0
730 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
731
732 def test_function_annotations(self):
733 # Some dummy class and instance to use as a default.
734 class F:
735 pass
736 f = F()
737
738 def validate_class(cls):
739 # First, check __annotations__, even though they're not
740 # function annotations.
741 self.assertEqual(cls.__annotations__['i'], int)
742 self.assertEqual(cls.__annotations__['j'], str)
743 self.assertEqual(cls.__annotations__['k'], F)
744 self.assertEqual(cls.__annotations__['l'], float)
745 self.assertEqual(cls.__annotations__['z'], complex)
746
747 # Verify __init__.
748
749 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400750 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500751 self.assertIs(signature.return_annotation, None)
752
753 # Check each parameter.
754 params = iter(signature.parameters.values())
755 param = next(params)
756 # This is testing an internal name, and probably shouldn't be tested.
757 self.assertEqual(param.name, 'self')
758 param = next(params)
759 self.assertEqual(param.name, 'i')
760 self.assertIs (param.annotation, int)
761 self.assertEqual(param.default, inspect.Parameter.empty)
762 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
763 param = next(params)
764 self.assertEqual(param.name, 'j')
765 self.assertIs (param.annotation, str)
766 self.assertEqual(param.default, inspect.Parameter.empty)
767 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
768 param = next(params)
769 self.assertEqual(param.name, 'k')
770 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400771 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500772 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
773 param = next(params)
774 self.assertEqual(param.name, 'l')
775 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400776 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500777 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
778 self.assertRaises(StopIteration, next, params)
779
780
781 @dataclass
782 class C:
783 i: int
784 j: str
785 k: F = f
786 l: float=field(default=None)
787 z: complex=field(default=3+4j, init=False)
788
789 validate_class(C)
790
791 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500792 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500793 class C:
794 i: int
795 j: str
796 k: F = f
797 l: float=field(default=None)
798 z: complex=field(default=3+4j, init=False)
799
800 validate_class(C)
801
Eric V. Smith03220fd2017-12-29 13:59:58 -0500802 def test_missing_default(self):
803 # Test that MISSING works the same as a default not being
804 # specified.
805 @dataclass
806 class C:
807 x: int=field(default=MISSING)
808 with self.assertRaisesRegex(TypeError,
809 r'__init__\(\) missing 1 required '
810 'positional argument'):
811 C()
812 self.assertNotIn('x', C.__dict__)
813
814 @dataclass
815 class D:
816 x: int
817 with self.assertRaisesRegex(TypeError,
818 r'__init__\(\) missing 1 required '
819 'positional argument'):
820 D()
821 self.assertNotIn('x', D.__dict__)
822
823 def test_missing_default_factory(self):
824 # Test that MISSING works the same as a default factory not
825 # being specified (which is really the same as a default not
826 # being specified, too).
827 @dataclass
828 class C:
829 x: int=field(default_factory=MISSING)
830 with self.assertRaisesRegex(TypeError,
831 r'__init__\(\) missing 1 required '
832 'positional argument'):
833 C()
834 self.assertNotIn('x', C.__dict__)
835
836 @dataclass
837 class D:
838 x: int=field(default=MISSING, default_factory=MISSING)
839 with self.assertRaisesRegex(TypeError,
840 r'__init__\(\) missing 1 required '
841 'positional argument'):
842 D()
843 self.assertNotIn('x', D.__dict__)
844
845 def test_missing_repr(self):
846 self.assertIn('MISSING_TYPE object', repr(MISSING))
847
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500848 def test_dont_include_other_annotations(self):
849 @dataclass
850 class C:
851 i: int
852 def foo(self) -> int:
853 return 4
854 @property
855 def bar(self) -> int:
856 return 5
857 self.assertEqual(list(C.__annotations__), ['i'])
858 self.assertEqual(C(10).foo(), 4)
859 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400860 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500861
862 def test_post_init(self):
863 # Just make sure it gets called
864 @dataclass
865 class C:
866 def __post_init__(self):
867 raise CustomError()
868 with self.assertRaises(CustomError):
869 C()
870
871 @dataclass
872 class C:
873 i: int = 10
874 def __post_init__(self):
875 if self.i == 10:
876 raise CustomError()
877 with self.assertRaises(CustomError):
878 C()
879 # post-init gets called, but doesn't raise. This is just
880 # checking that self is used correctly.
881 C(5)
882
883 # If there's not an __init__, then post-init won't get called.
884 @dataclass(init=False)
885 class C:
886 def __post_init__(self):
887 raise CustomError()
888 # Creating the class won't raise
889 C()
890
891 @dataclass
892 class C:
893 x: int = 0
894 def __post_init__(self):
895 self.x *= 2
896 self.assertEqual(C().x, 0)
897 self.assertEqual(C(2).x, 4)
898
Mike53f7a7c2017-12-14 14:04:53 +0300899 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500900 # attributes.
901 @dataclass(frozen=True)
902 class C:
903 x: int = 0
904 def __post_init__(self):
905 self.x *= 2
906 with self.assertRaises(FrozenInstanceError):
907 C()
908
909 def test_post_init_super(self):
910 # Make sure super() post-init isn't called by default.
911 class B:
912 def __post_init__(self):
913 raise CustomError()
914
915 @dataclass
916 class C(B):
917 def __post_init__(self):
918 self.x = 5
919
920 self.assertEqual(C().x, 5)
921
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400922 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500923 @dataclass
924 class C(B):
925 def __post_init__(self):
926 super().__post_init__()
927
928 with self.assertRaises(CustomError):
929 C()
930
931 # Make sure post-init is called, even if not defined in our
932 # class.
933 @dataclass
934 class C(B):
935 pass
936
937 with self.assertRaises(CustomError):
938 C()
939
940 def test_post_init_staticmethod(self):
941 flag = False
942 @dataclass
943 class C:
944 x: int
945 y: int
946 @staticmethod
947 def __post_init__():
948 nonlocal flag
949 flag = True
950
951 self.assertFalse(flag)
952 c = C(3, 4)
953 self.assertEqual((c.x, c.y), (3, 4))
954 self.assertTrue(flag)
955
956 def test_post_init_classmethod(self):
957 @dataclass
958 class C:
959 flag = False
960 x: int
961 y: int
962 @classmethod
963 def __post_init__(cls):
964 cls.flag = True
965
966 self.assertFalse(C.flag)
967 c = C(3, 4)
968 self.assertEqual((c.x, c.y), (3, 4))
969 self.assertTrue(C.flag)
970
971 def test_class_var(self):
972 # Make sure ClassVars are ignored in __init__, __repr__, etc.
973 @dataclass
974 class C:
975 x: int
976 y: int = 10
977 z: ClassVar[int] = 1000
978 w: ClassVar[int] = 2000
979 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400980 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500981
982 c = C(5)
983 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400984 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400985 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500986 self.assertEqual(c.z, 1000)
987 self.assertEqual(c.w, 2000)
988 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400989 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500990 C.z += 1
991 self.assertEqual(c.z, 1001)
992 c = C(20)
993 self.assertEqual((c.x, c.y), (20, 10))
994 self.assertEqual(c.z, 1001)
995 self.assertEqual(c.w, 2000)
996 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400997 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500998
999 def test_class_var_no_default(self):
1000 # If a ClassVar has no default value, it should not be set on the class.
1001 @dataclass
1002 class C:
1003 x: ClassVar[int]
1004
1005 self.assertNotIn('x', C.__dict__)
1006
1007 def test_class_var_default_factory(self):
1008 # It makes no sense for a ClassVar to have a default factory. When
1009 # would it be called? Call it yourself, since it's class-wide.
1010 with self.assertRaisesRegex(TypeError,
1011 'cannot have a default factory'):
1012 @dataclass
1013 class C:
1014 x: ClassVar[int] = field(default_factory=int)
1015
1016 self.assertNotIn('x', C.__dict__)
1017
1018 def test_class_var_with_default(self):
1019 # If a ClassVar has a default value, it should be set on the class.
1020 @dataclass
1021 class C:
1022 x: ClassVar[int] = 10
1023 self.assertEqual(C.x, 10)
1024
1025 @dataclass
1026 class C:
1027 x: ClassVar[int] = field(default=10)
1028 self.assertEqual(C.x, 10)
1029
1030 def test_class_var_frozen(self):
1031 # Make sure ClassVars work even if we're frozen.
1032 @dataclass(frozen=True)
1033 class C:
1034 x: int
1035 y: int = 10
1036 z: ClassVar[int] = 1000
1037 w: ClassVar[int] = 2000
1038 t: ClassVar[int] = 3000
1039
1040 c = C(5)
1041 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1042 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1043 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1044 self.assertEqual(c.z, 1000)
1045 self.assertEqual(c.w, 2000)
1046 self.assertEqual(c.t, 3000)
1047 # We can still modify the ClassVar, it's only instances that are
1048 # frozen.
1049 C.z += 1
1050 self.assertEqual(c.z, 1001)
1051 c = C(20)
1052 self.assertEqual((c.x, c.y), (20, 10))
1053 self.assertEqual(c.z, 1001)
1054 self.assertEqual(c.w, 2000)
1055 self.assertEqual(c.t, 3000)
1056
1057 def test_init_var_no_default(self):
1058 # If an InitVar has no default value, it should not be set on the class.
1059 @dataclass
1060 class C:
1061 x: InitVar[int]
1062
1063 self.assertNotIn('x', C.__dict__)
1064
1065 def test_init_var_default_factory(self):
1066 # It makes no sense for an InitVar to have a default factory. When
1067 # would it be called? Call it yourself, since it's class-wide.
1068 with self.assertRaisesRegex(TypeError,
1069 'cannot have a default factory'):
1070 @dataclass
1071 class C:
1072 x: InitVar[int] = field(default_factory=int)
1073
1074 self.assertNotIn('x', C.__dict__)
1075
1076 def test_init_var_with_default(self):
1077 # If an InitVar has a default value, it should be set on the class.
1078 @dataclass
1079 class C:
1080 x: InitVar[int] = 10
1081 self.assertEqual(C.x, 10)
1082
1083 @dataclass
1084 class C:
1085 x: InitVar[int] = field(default=10)
1086 self.assertEqual(C.x, 10)
1087
1088 def test_init_var(self):
1089 @dataclass
1090 class C:
1091 x: int = None
1092 init_param: InitVar[int] = None
1093
1094 def __post_init__(self, init_param):
1095 if self.x is None:
1096 self.x = init_param*2
1097
1098 c = C(init_param=10)
1099 self.assertEqual(c.x, 20)
1100
Augusto Hack01ee12b2019-06-02 23:14:48 -03001101 def test_init_var_preserve_type(self):
1102 self.assertEqual(InitVar[int].type, int)
1103
1104 # Make sure the repr is correct.
1105 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001106 self.assertEqual(repr(InitVar[List[int]]),
1107 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001108
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001109 def test_init_var_inheritance(self):
1110 # Note that this deliberately tests that a dataclass need not
1111 # have a __post_init__ function if it has an InitVar field.
1112 # It could just be used in a derived class, as shown here.
1113 @dataclass
1114 class Base:
1115 x: int
1116 init_base: InitVar[int]
1117
1118 # We can instantiate by passing the InitVar, even though
1119 # it's not used.
1120 b = Base(0, 10)
1121 self.assertEqual(vars(b), {'x': 0})
1122
1123 @dataclass
1124 class C(Base):
1125 y: int
1126 init_derived: InitVar[int]
1127
1128 def __post_init__(self, init_base, init_derived):
1129 self.x = self.x + init_base
1130 self.y = self.y + init_derived
1131
1132 c = C(10, 11, 50, 51)
1133 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1134
1135 def test_default_factory(self):
1136 # Test a factory that returns a new list.
1137 @dataclass
1138 class C:
1139 x: int
1140 y: list = field(default_factory=list)
1141
1142 c0 = C(3)
1143 c1 = C(3)
1144 self.assertEqual(c0.x, 3)
1145 self.assertEqual(c0.y, [])
1146 self.assertEqual(c0, c1)
1147 self.assertIsNot(c0.y, c1.y)
1148 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1149
1150 # Test a factory that returns a shared list.
1151 l = []
1152 @dataclass
1153 class C:
1154 x: int
1155 y: list = field(default_factory=lambda: l)
1156
1157 c0 = C(3)
1158 c1 = C(3)
1159 self.assertEqual(c0.x, 3)
1160 self.assertEqual(c0.y, [])
1161 self.assertEqual(c0, c1)
1162 self.assertIs(c0.y, c1.y)
1163 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1164
1165 # Test various other field flags.
1166 # repr
1167 @dataclass
1168 class C:
1169 x: list = field(default_factory=list, repr=False)
1170 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1171 self.assertEqual(C().x, [])
1172
1173 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001174 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001175 class C:
1176 x: list = field(default_factory=list, hash=False)
1177 self.assertEqual(astuple(C()), ([],))
1178 self.assertEqual(hash(C()), hash(()))
1179
1180 # init (see also test_default_factory_with_no_init)
1181 @dataclass
1182 class C:
1183 x: list = field(default_factory=list, init=False)
1184 self.assertEqual(astuple(C()), ([],))
1185
1186 # compare
1187 @dataclass
1188 class C:
1189 x: list = field(default_factory=list, compare=False)
1190 self.assertEqual(C(), C([1]))
1191
1192 def test_default_factory_with_no_init(self):
1193 # We need a factory with a side effect.
1194 factory = Mock()
1195
1196 @dataclass
1197 class C:
1198 x: list = field(default_factory=factory, init=False)
1199
1200 # Make sure the default factory is called for each new instance.
1201 C().x
1202 self.assertEqual(factory.call_count, 1)
1203 C().x
1204 self.assertEqual(factory.call_count, 2)
1205
1206 def test_default_factory_not_called_if_value_given(self):
1207 # We need a factory that we can test if it's been called.
1208 factory = Mock()
1209
1210 @dataclass
1211 class C:
1212 x: int = field(default_factory=factory)
1213
1214 # Make sure that if a field has a default factory function,
1215 # it's not called if a value is specified.
1216 C().x
1217 self.assertEqual(factory.call_count, 1)
1218 self.assertEqual(C(10).x, 10)
1219 self.assertEqual(factory.call_count, 1)
1220 C().x
1221 self.assertEqual(factory.call_count, 2)
1222
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001223 def test_default_factory_derived(self):
1224 # See bpo-32896.
1225 @dataclass
1226 class Foo:
1227 x: dict = field(default_factory=dict)
1228
1229 @dataclass
1230 class Bar(Foo):
1231 y: int = 1
1232
1233 self.assertEqual(Foo().x, {})
1234 self.assertEqual(Bar().x, {})
1235 self.assertEqual(Bar().y, 1)
1236
1237 @dataclass
1238 class Baz(Foo):
1239 pass
1240 self.assertEqual(Baz().x, {})
1241
1242 def test_intermediate_non_dataclass(self):
1243 # Test that an intermediate class that defines
1244 # annotations does not define fields.
1245
1246 @dataclass
1247 class A:
1248 x: int
1249
1250 class B(A):
1251 y: int
1252
1253 @dataclass
1254 class C(B):
1255 z: int
1256
1257 c = C(1, 3)
1258 self.assertEqual((c.x, c.z), (1, 3))
1259
1260 # .y was not initialized.
1261 with self.assertRaisesRegex(AttributeError,
1262 'object has no attribute'):
1263 c.y
1264
1265 # And if we again derive a non-dataclass, no fields are added.
1266 class D(C):
1267 t: int
1268 d = D(4, 5)
1269 self.assertEqual((d.x, d.z), (4, 5))
1270
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001271 def test_classvar_default_factory(self):
1272 # It's an error for a ClassVar to have a factory function.
1273 with self.assertRaisesRegex(TypeError,
1274 'cannot have a default factory'):
1275 @dataclass
1276 class C:
1277 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001278
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001279 def test_is_dataclass(self):
1280 class NotDataClass:
1281 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001282
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001283 self.assertFalse(is_dataclass(0))
1284 self.assertFalse(is_dataclass(int))
1285 self.assertFalse(is_dataclass(NotDataClass))
1286 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001287
1288 @dataclass
1289 class C:
1290 x: int
1291
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001292 @dataclass
1293 class D:
1294 d: C
1295 e: int
1296
1297 c = C(10)
1298 d = D(c, 4)
1299
1300 self.assertTrue(is_dataclass(C))
1301 self.assertTrue(is_dataclass(c))
1302 self.assertFalse(is_dataclass(c.x))
1303 self.assertTrue(is_dataclass(d.d))
1304 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001305
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001306 def test_is_dataclass_when_getattr_always_returns(self):
1307 # See bpo-37868.
1308 class A:
1309 def __getattr__(self, key):
1310 return 0
1311 self.assertFalse(is_dataclass(A))
1312 a = A()
1313
1314 # Also test for an instance attribute.
1315 class B:
1316 pass
1317 b = B()
1318 b.__dataclass_fields__ = []
1319
1320 for obj in a, b:
1321 with self.subTest(obj=obj):
1322 self.assertFalse(is_dataclass(obj))
1323
1324 # Indirect tests for _is_dataclass_instance().
1325 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1326 asdict(obj)
1327 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1328 astuple(obj)
1329 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1330 replace(obj, x=0)
1331
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001332 def test_helper_fields_with_class_instance(self):
1333 # Check that we can call fields() on either a class or instance,
1334 # and get back the same thing.
1335 @dataclass
1336 class C:
1337 x: int
1338 y: float
1339
1340 self.assertEqual(fields(C), fields(C(0, 0.0)))
1341
1342 def test_helper_fields_exception(self):
1343 # Check that TypeError is raised if not passed a dataclass or
1344 # instance.
1345 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1346 fields(0)
1347
1348 class C: pass
1349 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1350 fields(C)
1351 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1352 fields(C())
1353
1354 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001355 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001356 @dataclass
1357 class C:
1358 x: int
1359 y: int
1360 c = C(1, 2)
1361
1362 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1363 self.assertEqual(asdict(c), asdict(c))
1364 self.assertIsNot(asdict(c), asdict(c))
1365 c.x = 42
1366 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1367 self.assertIs(type(asdict(c)), dict)
1368
1369 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001370 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001371 @dataclass
1372 class C:
1373 x: int
1374 y: int
1375 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1376 asdict(C)
1377 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1378 asdict(int)
1379
1380 def test_helper_asdict_copy_values(self):
1381 @dataclass
1382 class C:
1383 x: int
1384 y: List[int] = field(default_factory=list)
1385 initial = []
1386 c = C(1, initial)
1387 d = asdict(c)
1388 self.assertEqual(d['y'], initial)
1389 self.assertIsNot(d['y'], initial)
1390 c = C(1)
1391 d = asdict(c)
1392 d['y'].append(1)
1393 self.assertEqual(c.y, [])
1394
1395 def test_helper_asdict_nested(self):
1396 @dataclass
1397 class UserId:
1398 token: int
1399 group: int
1400 @dataclass
1401 class User:
1402 name: str
1403 id: UserId
1404 u = User('Joe', UserId(123, 1))
1405 d = asdict(u)
1406 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1407 self.assertIsNot(asdict(u), asdict(u))
1408 u.id.group = 2
1409 self.assertEqual(asdict(u), {'name': 'Joe',
1410 'id': {'token': 123, 'group': 2}})
1411
1412 def test_helper_asdict_builtin_containers(self):
1413 @dataclass
1414 class User:
1415 name: str
1416 id: int
1417 @dataclass
1418 class GroupList:
1419 id: int
1420 users: List[User]
1421 @dataclass
1422 class GroupTuple:
1423 id: int
1424 users: Tuple[User, ...]
1425 @dataclass
1426 class GroupDict:
1427 id: int
1428 users: Dict[str, User]
1429 a = User('Alice', 1)
1430 b = User('Bob', 2)
1431 gl = GroupList(0, [a, b])
1432 gt = GroupTuple(0, (a, b))
1433 gd = GroupDict(0, {'first': a, 'second': b})
1434 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1435 {'name': 'Bob', 'id': 2}]})
1436 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1437 {'name': 'Bob', 'id': 2})})
1438 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1439 'second': {'name': 'Bob', 'id': 2}}})
1440
Windson yangbe372d72019-04-23 02:45:34 +08001441 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001442 @dataclass
1443 class Child:
1444 d: object
1445
1446 @dataclass
1447 class Parent:
1448 child: Child
1449
1450 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1451 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1452
1453 def test_helper_asdict_factory(self):
1454 @dataclass
1455 class C:
1456 x: int
1457 y: int
1458 c = C(1, 2)
1459 d = asdict(c, dict_factory=OrderedDict)
1460 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1461 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1462 c.x = 42
1463 d = asdict(c, dict_factory=OrderedDict)
1464 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1465 self.assertIs(type(d), OrderedDict)
1466
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001467 def test_helper_asdict_namedtuple(self):
1468 T = namedtuple('T', 'a b c')
1469 @dataclass
1470 class C:
1471 x: str
1472 y: T
1473 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1474
1475 d = asdict(c)
1476 self.assertEqual(d, {'x': 'outer',
1477 'y': T(1,
1478 {'x': 'inner',
1479 'y': T(11, 12, 13)},
1480 2),
1481 }
1482 )
1483
1484 # Now with a dict_factory. OrderedDict is convenient, but
1485 # since it compares to dicts, we also need to have separate
1486 # assertIs tests.
1487 d = asdict(c, dict_factory=OrderedDict)
1488 self.assertEqual(d, {'x': 'outer',
1489 'y': T(1,
1490 {'x': 'inner',
1491 'y': T(11, 12, 13)},
1492 2),
1493 }
1494 )
1495
penguindustin96466302019-05-06 14:57:17 -04001496 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001497 self.assertIs(type(d), OrderedDict)
1498 self.assertIs(type(d['y'][1]), OrderedDict)
1499
1500 def test_helper_asdict_namedtuple_key(self):
1501 # Ensure that a field that contains a dict which has a
1502 # namedtuple as a key works with asdict().
1503
1504 @dataclass
1505 class C:
1506 f: dict
1507 T = namedtuple('T', 'a')
1508
1509 c = C({T('an a'): 0})
1510
1511 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1512
1513 def test_helper_asdict_namedtuple_derived(self):
1514 class T(namedtuple('Tbase', 'a')):
1515 def my_a(self):
1516 return self.a
1517
1518 @dataclass
1519 class C:
1520 f: T
1521
1522 t = T(6)
1523 c = C(t)
1524
1525 d = asdict(c)
1526 self.assertEqual(d, {'f': T(a=6)})
1527 # Make sure that t has been copied, not used directly.
1528 self.assertIsNot(d['f'], t)
1529 self.assertEqual(d['f'].my_a(), 6)
1530
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001531 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001532 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001533 @dataclass
1534 class C:
1535 x: int
1536 y: int = 0
1537 c = C(1)
1538
1539 self.assertEqual(astuple(c), (1, 0))
1540 self.assertEqual(astuple(c), astuple(c))
1541 self.assertIsNot(astuple(c), astuple(c))
1542 c.y = 42
1543 self.assertEqual(astuple(c), (1, 42))
1544 self.assertIs(type(astuple(c)), tuple)
1545
1546 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001547 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001548 @dataclass
1549 class C:
1550 x: int
1551 y: int
1552 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1553 astuple(C)
1554 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1555 astuple(int)
1556
1557 def test_helper_astuple_copy_values(self):
1558 @dataclass
1559 class C:
1560 x: int
1561 y: List[int] = field(default_factory=list)
1562 initial = []
1563 c = C(1, initial)
1564 t = astuple(c)
1565 self.assertEqual(t[1], initial)
1566 self.assertIsNot(t[1], initial)
1567 c = C(1)
1568 t = astuple(c)
1569 t[1].append(1)
1570 self.assertEqual(c.y, [])
1571
1572 def test_helper_astuple_nested(self):
1573 @dataclass
1574 class UserId:
1575 token: int
1576 group: int
1577 @dataclass
1578 class User:
1579 name: str
1580 id: UserId
1581 u = User('Joe', UserId(123, 1))
1582 t = astuple(u)
1583 self.assertEqual(t, ('Joe', (123, 1)))
1584 self.assertIsNot(astuple(u), astuple(u))
1585 u.id.group = 2
1586 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1587
1588 def test_helper_astuple_builtin_containers(self):
1589 @dataclass
1590 class User:
1591 name: str
1592 id: int
1593 @dataclass
1594 class GroupList:
1595 id: int
1596 users: List[User]
1597 @dataclass
1598 class GroupTuple:
1599 id: int
1600 users: Tuple[User, ...]
1601 @dataclass
1602 class GroupDict:
1603 id: int
1604 users: Dict[str, User]
1605 a = User('Alice', 1)
1606 b = User('Bob', 2)
1607 gl = GroupList(0, [a, b])
1608 gt = GroupTuple(0, (a, b))
1609 gd = GroupDict(0, {'first': a, 'second': b})
1610 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1611 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1612 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1613
Windson yangbe372d72019-04-23 02:45:34 +08001614 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001615 @dataclass
1616 class Child:
1617 d: object
1618
1619 @dataclass
1620 class Parent:
1621 child: Child
1622
1623 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1624 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1625
1626 def test_helper_astuple_factory(self):
1627 @dataclass
1628 class C:
1629 x: int
1630 y: int
1631 NT = namedtuple('NT', 'x y')
1632 def nt(lst):
1633 return NT(*lst)
1634 c = C(1, 2)
1635 t = astuple(c, tuple_factory=nt)
1636 self.assertEqual(t, NT(1, 2))
1637 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1638 c.x = 42
1639 t = astuple(c, tuple_factory=nt)
1640 self.assertEqual(t, NT(42, 2))
1641 self.assertIs(type(t), NT)
1642
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001643 def test_helper_astuple_namedtuple(self):
1644 T = namedtuple('T', 'a b c')
1645 @dataclass
1646 class C:
1647 x: str
1648 y: T
1649 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1650
1651 t = astuple(c)
1652 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1653
1654 # Now, using a tuple_factory. list is convenient here.
1655 t = astuple(c, tuple_factory=list)
1656 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1657
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001658 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001659 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001660 }
1661
1662 # Create the class.
1663 cls = type('C', (), cls_dict)
1664
1665 # Make it a dataclass.
1666 cls1 = dataclass(cls)
1667
1668 self.assertEqual(cls1, cls)
1669 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1670
1671 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001672 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001673 'y': field(default=5),
1674 }
1675
1676 # Create the class.
1677 cls = type('C', (), cls_dict)
1678
1679 # Make it a dataclass.
1680 cls1 = dataclass(cls)
1681
1682 self.assertEqual(cls1, cls)
1683 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1684
1685 def test_init_in_order(self):
1686 @dataclass
1687 class C:
1688 a: int
1689 b: int = field()
1690 c: list = field(default_factory=list, init=False)
1691 d: list = field(default_factory=list)
1692 e: int = field(default=4, init=False)
1693 f: int = 4
1694
1695 calls = []
1696 def setattr(self, name, value):
1697 calls.append((name, value))
1698
1699 C.__setattr__ = setattr
1700 c = C(0, 1)
1701 self.assertEqual(('a', 0), calls[0])
1702 self.assertEqual(('b', 1), calls[1])
1703 self.assertEqual(('c', []), calls[2])
1704 self.assertEqual(('d', []), calls[3])
1705 self.assertNotIn(('e', 4), calls)
1706 self.assertEqual(('f', 4), calls[4])
1707
1708 def test_items_in_dicts(self):
1709 @dataclass
1710 class C:
1711 a: int
1712 b: list = field(default_factory=list, init=False)
1713 c: list = field(default_factory=list)
1714 d: int = field(default=4, init=False)
1715 e: int = 0
1716
1717 c = C(0)
1718 # Class dict
1719 self.assertNotIn('a', C.__dict__)
1720 self.assertNotIn('b', C.__dict__)
1721 self.assertNotIn('c', C.__dict__)
1722 self.assertIn('d', C.__dict__)
1723 self.assertEqual(C.d, 4)
1724 self.assertIn('e', C.__dict__)
1725 self.assertEqual(C.e, 0)
1726 # Instance dict
1727 self.assertIn('a', c.__dict__)
1728 self.assertEqual(c.a, 0)
1729 self.assertIn('b', c.__dict__)
1730 self.assertEqual(c.b, [])
1731 self.assertIn('c', c.__dict__)
1732 self.assertEqual(c.c, [])
1733 self.assertNotIn('d', c.__dict__)
1734 self.assertIn('e', c.__dict__)
1735 self.assertEqual(c.e, 0)
1736
1737 def test_alternate_classmethod_constructor(self):
1738 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001739 # alternate constructor. This is mostly an example to show
1740 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001741 @dataclass
1742 class C:
1743 x: int
1744 @classmethod
1745 def from_file(cls, filename):
1746 # In a real example, create a new instance
1747 # and populate 'x' from contents of a file.
1748 value_in_file = 20
1749 return cls(value_in_file)
1750
1751 self.assertEqual(C.from_file('filename').x, 20)
1752
1753 def test_field_metadata_default(self):
1754 # Make sure the default metadata is read-only and of
1755 # zero length.
1756 @dataclass
1757 class C:
1758 i: int
1759
1760 self.assertFalse(fields(C)[0].metadata)
1761 self.assertEqual(len(fields(C)[0].metadata), 0)
1762 with self.assertRaisesRegex(TypeError,
1763 'does not support item assignment'):
1764 fields(C)[0].metadata['test'] = 3
1765
1766 def test_field_metadata_mapping(self):
1767 # Make sure only a mapping can be passed as metadata
1768 # zero length.
1769 with self.assertRaises(TypeError):
1770 @dataclass
1771 class C:
1772 i: int = field(metadata=0)
1773
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001774 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001775 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001776 @dataclass
1777 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001778 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001779 self.assertFalse(fields(C)[0].metadata)
1780 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001781 # Update should work (see bpo-35960).
1782 d['foo'] = 1
1783 self.assertEqual(len(fields(C)[0].metadata), 1)
1784 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001785 with self.assertRaisesRegex(TypeError,
1786 'does not support item assignment'):
1787 fields(C)[0].metadata['test'] = 3
1788
1789 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001790 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001791 @dataclass
1792 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001793 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001794 self.assertEqual(len(fields(C)[0].metadata), 3)
1795 self.assertEqual(fields(C)[0].metadata['test'], 10)
1796 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1797 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001798 # Update should work.
1799 d['foo'] = 1
1800 self.assertEqual(len(fields(C)[0].metadata), 4)
1801 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001802 with self.assertRaises(KeyError):
1803 # Non-existent key.
1804 fields(C)[0].metadata['baz']
1805 with self.assertRaisesRegex(TypeError,
1806 'does not support item assignment'):
1807 fields(C)[0].metadata['test'] = 3
1808
1809 def test_field_metadata_custom_mapping(self):
1810 # Try a custom mapping.
1811 class SimpleNameSpace:
1812 def __init__(self, **kw):
1813 self.__dict__.update(kw)
1814
1815 def __getitem__(self, item):
1816 if item == 'xyzzy':
1817 return 'plugh'
1818 return getattr(self, item)
1819
1820 def __len__(self):
1821 return self.__dict__.__len__()
1822
1823 @dataclass
1824 class C:
1825 i: int = field(metadata=SimpleNameSpace(a=10))
1826
1827 self.assertEqual(len(fields(C)[0].metadata), 1)
1828 self.assertEqual(fields(C)[0].metadata['a'], 10)
1829 with self.assertRaises(AttributeError):
1830 fields(C)[0].metadata['b']
1831 # Make sure we're still talking to our custom mapping.
1832 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1833
1834 def test_generic_dataclasses(self):
1835 T = TypeVar('T')
1836
1837 @dataclass
1838 class LabeledBox(Generic[T]):
1839 content: T
1840 label: str = '<unknown>'
1841
1842 box = LabeledBox(42)
1843 self.assertEqual(box.content, 42)
1844 self.assertEqual(box.label, '<unknown>')
1845
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001846 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001847 Alias = List[LabeledBox[int]]
1848
1849 def test_generic_extending(self):
1850 S = TypeVar('S')
1851 T = TypeVar('T')
1852
1853 @dataclass
1854 class Base(Generic[T, S]):
1855 x: T
1856 y: S
1857
1858 @dataclass
1859 class DataDerived(Base[int, T]):
1860 new_field: str
1861 Alias = DataDerived[str]
1862 c = Alias(0, 'test1', 'test2')
1863 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1864
1865 class NonDataDerived(Base[int, T]):
1866 def new_method(self):
1867 return self.y
1868 Alias = NonDataDerived[float]
1869 c = Alias(10, 1.0)
1870 self.assertEqual(c.new_method(), 1.0)
1871
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001872 def test_generic_dynamic(self):
1873 T = TypeVar('T')
1874
1875 @dataclass
1876 class Parent(Generic[T]):
1877 x: T
1878 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1879 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1880 self.assertIs(Child[int](1, 2).z, None)
1881 self.assertEqual(Child[int](1, 2, 3).z, 3)
1882 self.assertEqual(Child[int](1, 2, 3).other, 42)
1883 # Check that type aliases work correctly.
1884 Alias = Child[T]
1885 self.assertEqual(Alias[int](1, 2).x, 1)
1886 # Check MRO resolution.
1887 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1888
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001889 def test_dataclassses_pickleable(self):
1890 global P, Q, R
1891 @dataclass
1892 class P:
1893 x: int
1894 y: int = 0
1895 @dataclass
1896 class Q:
1897 x: int
1898 y: int = field(default=0, init=False)
1899 @dataclass
1900 class R:
1901 x: int
1902 y: List[int] = field(default_factory=list)
1903 q = Q(1)
1904 q.y = 2
1905 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1906 for sample in samples:
1907 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1908 with self.subTest(sample=sample, proto=proto):
1909 new_sample = pickle.loads(pickle.dumps(sample, proto))
1910 self.assertEqual(sample.x, new_sample.x)
1911 self.assertEqual(sample.y, new_sample.y)
1912 self.assertIsNot(sample, new_sample)
1913 new_sample.x = 42
1914 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1915 self.assertEqual(new_sample.x, another_new_sample.x)
1916 self.assertEqual(sample.y, another_new_sample.y)
1917
Eric V. Smithea8fc522018-01-27 19:07:40 -05001918
class TestFieldNoAnnotation(unittest.TestCase):
    """A field() assignment needs an annotation in the same class body."""

    def _assert_unannotated_field_rejected(self, base):
        # Define an un-annotated field() on a subclass of *base*.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(base):
                f = field()

    def test_field_without_annotation(self):
        self._assert_unannotated_field_rejected(object)

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: the base class's annotation must not be used.
        self._assert_unannotated_field_rejected(B)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, but the base is not a dataclass.
        class B:
            f: int

        self._assert_unannotated_field_rejected(B)
1951 f = field()
1952
1953
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001954class TestDocString(unittest.TestCase):
1955 def assertDocStrEqual(self, a, b):
1956 # Because 3.6 and 3.7 differ in how inspect.signature work
1957 # (see bpo #32108), for the time being just compare them with
1958 # whitespace stripped.
1959 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1960
1961 def test_existing_docstring_not_overridden(self):
1962 @dataclass
1963 class C:
1964 """Lorem ipsum"""
1965 x: int
1966
1967 self.assertEqual(C.__doc__, "Lorem ipsum")
1968
1969 def test_docstring_no_fields(self):
1970 @dataclass
1971 class C:
1972 pass
1973
1974 self.assertDocStrEqual(C.__doc__, "C()")
1975
1976 def test_docstring_one_field(self):
1977 @dataclass
1978 class C:
1979 x: int
1980
1981 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1982
1983 def test_docstring_two_fields(self):
1984 @dataclass
1985 class C:
1986 x: int
1987 y: int
1988
1989 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1990
1991 def test_docstring_three_fields(self):
1992 @dataclass
1993 class C:
1994 x: int
1995 y: int
1996 z: str
1997
1998 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1999
2000 def test_docstring_one_field_with_default(self):
2001 @dataclass
2002 class C:
2003 x: int = 3
2004
2005 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2006
2007 def test_docstring_one_field_with_default_none(self):
2008 @dataclass
2009 class C:
2010 x: Union[int, type(None)] = None
2011
2012 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2013
2014 def test_docstring_list_field(self):
2015 @dataclass
2016 class C:
2017 x: List[int]
2018
2019 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2020
2021 def test_docstring_list_field_with_default_factory(self):
2022 @dataclass
2023 class C:
2024 x: List[int] = field(default_factory=list)
2025
2026 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2027
2028 def test_docstring_deque_field(self):
2029 @dataclass
2030 class C:
2031 x: deque
2032
2033 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2034
2035 def test_docstring_deque_field_with_default_factory(self):
2036 @dataclass
2037 class C:
2038 x: deque = field(default_factory=deque)
2039
2040 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2041
2042
class TestInit(unittest.TestCase):
    """Tests for how @dataclass generates, skips, or defers to __init__."""

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise.  A base-class __init__ does
        # not stop the dataclass from adding its own, and the generated
        # __init__ does not call B.__init__, so z is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied (note the '@'): a bare
        # ``dataclass(init=False)`` statement builds a decorator and throws
        # it away, so the class would never become a dataclass at all.
        @dataclass(init=False)
        class C:
            i: int = 0
        # No __init__ is generated, so the class-level default is seen.
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # A user-written __init__ is kept and used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2106 self.assertEqual(C(5).x, 10)
2107
2108
class TestRepr(unittest.TestCase):
    """Tests for the generated __repr__ and for repr=False."""

    def test_repr(self):
        # The generated repr is prefixed with the class __qualname__ and
        # lists fields in definition order, base-class fields first.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        inst = C(4)
        self.assertEqual(repr(inst), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses report the full dotted qualname.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ of its own: object's repr applies.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_repr)

        # repr=False with a __repr__ in the class body: that one is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ written in the class body always wins, whatever
        # value (or absence) of repr= is passed to the decorator.
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorate
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2178
2179
class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and for eq=False."""

    def test_no_eq(self):
        # eq=False and no __eq__ of its own: object's identity-based
        # equality applies.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with an __eq__ in the class body: that one is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ written in the class body always wins, whatever
        # value (or absence) of eq= is passed to the decorator.
        for decorate, magic in ((dataclass, 3),
                                (dataclass(eq=True), 4),
                                (dataclass(eq=False), 5)):
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == magic
            self.assertEqual(C(1), magic)
            self.assertNotEqual(C(1), 1)
2225
2226
class TestOrdering(unittest.TestCase):
    """Tests for order= and for user-defined comparison methods."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added with order=False.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)
        # ...so instances cannot be ordered at all.
        with self.assertRaises(TypeError):
            C(0) < C(1)

        # Test that a user-defined __lt__ is kept and is the one called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIs(C(0) < C(1), False)
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any comparison method that the
        # class body already defines.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self, other):
                    pass
2302
class TestHash(unittest.TestCase):
    """Tests for when @dataclass adds, keeps, or removes __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        # The generated hash equals the hash of the tuple of field values.
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # The decorator left __hash__ alone.
                    if with_hash:
                        # The class's own __hash__ (which returns 0)
                        # must still be the one in effect.
                        self.assertEqual(hash(C()), 0)
                    else:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash is derived from the instance's
                    # identity, so there is no fixed value to compare
                    # hash() against.  Instead, check that the function
                    # in use is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2521
Eric V. Smithea8fc522018-01-27 19:07:40 -05002522
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True and for frozen/non-frozen inheritance rules."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        frozen_obj = C(10)
        self.assertEqual(frozen_obj.i, 10)
        # Assignment is rejected and leaves the value untouched.
        with self.assertRaises(FrozenInstanceError):
            frozen_obj.i = 5
        self.assertEqual(frozen_obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        child = D(0, 10)
        # Both the inherited field and the new one are read-only.
        for attr, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(child, attr, new_value)
        self.assertEqual(child.i, 0)
        self.assertEqual(child.j, 10)

    # The two inheritance tests below run both with an ordinary
    # (non-dataclass) class in between and with direct inheritance.
    def test_inherit_nonfrozen_from_frozen(self):
        for use_intermediate in [True, False]:
            with self.subTest(intermediate_class=use_intermediate):
                @dataclass(frozen=True)
                class C:
                    i: int

                if not use_intermediate:
                    Parent = C
                else:
                    class Parent(C): pass

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(Parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for use_intermediate in [True, False]:
            with self.subTest(intermediate_class=use_intermediate):
                @dataclass
                class C:
                    i: int

                if not use_intermediate:
                    Parent = C
                else:
                    class Parent(C): pass

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(Parent):
                        pass

    def test_inherit_from_normal_class(self):
        for use_intermediate in [True, False]:
            with self.subTest(intermediate_class=use_intermediate):
                class C:
                    pass

                if not use_intermediate:
                    Parent = C
                else:
                    class Parent(C): pass

                # A frozen dataclass may sit on top of plain classes.
                @dataclass(frozen=True)
                class D(Parent):
                    i: int

                frozen_obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    frozen_obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        sub = S(3)
        self.assertEqual((sub.x, sub.y), (3, 10))
        # The plain subclass accepts a new, non-field attribute...
        sub.cached = True

        # ...but the inherited frozen fields still can't be changed.
        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(sub, attr, 5)
        self.assertEqual((sub.x, sub.y, sub.cached), (3, 10, True))

    def test_overwriting_frozen(self):
        # frozen=True generates __setattr__ and __delattr__, so the class
        # body may not define either.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False a custom __setattr__ is allowed and used.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # A hashable field value hashes without error...
        hash(C(3))

        # ...while an unhashable one makes hash() fail.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2671
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002672
class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__ explicitly."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # The slot puts a member descriptor on the class.  A past bug
        # mistook that descriptor for a default value; x must stay
        # required.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # Construction works and the slot can be reassigned...
        obj = C(10)
        self.assertEqual(obj.x, 10)
        obj.x = 5
        self.assertEqual(obj.x, 5)

        # ...but no attribute outside __slots__ can be set.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            obj.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__ of its own, so arbitrary new
        # attributes are accepted.
        derived.z = 10
2714
class TestDescriptors(unittest.TestCase):
    """Tests for objects with __set_name__ used as field defaults."""

    def test_set_name(self):
        # See bpo-33141.

        # A non-data descriptor that records the attribute name it is
        # bound to, with an 'x' suffix so the test can tell it ran.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is None:
                    return self
                return 1

        # Ordinary descriptor behaviour: no dataclass code is involved
        # in initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Default value plus init=False, which is the only case where
        # this really matters: otherwise the instance attribute set by
        # __init__ would hide the descriptor anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # D defines __set_name__ but no __get__/__set__/__delete__,
        # so it is not a descriptor.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        # __set_name__ stored on the default *instance* must be ignored.
        class D:
            pass

        default_value = D()
        default_value.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=default_value, init=False)

        self.assertEqual(default_value.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        # __set_name__ stored on the default's *type* is called once.
        class D:
            pass
        set_name_mock = Mock()
        D.__set_name__ = set_name_mock

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(set_name_mock.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002785
Eric V. Smith7389fd92018-03-19 21:07:51 -04002786
class TestStringAnnotations(unittest.TestCase):
    """Tests for detecting ClassVar/InitVar in string annotations."""

    def test_classvar(self):
        # Some expressions recognized as ClassVar really aren't.  But
        # if you're using string annotations, it's not an exact
        # science.
        # These tests assume that both "import typing" and
        # "from typing import ClassVar" have been run in this file.
        # Note: every entry needs a trailing comma; otherwise adjacent
        # string literals are silently concatenated into one case.
        for typestr in ('ClassVar[int]',
                        'ClassVar [int]',
                        ' ClassVar [int]',
                        'ClassVar',
                        ' ClassVar ',
                        'typing.ClassVar[int]',
                        'typing.ClassVar[str]',
                        ' typing.ClassVar[str]',
                        'typing .ClassVar[str]',
                        'typing. ClassVar[str]',
                        'typing.ClassVar [str]',
                        'typing.ClassVar [ str]',

                        # Not syntactically valid, but these will
                        # be treated as ClassVars.
                        'typing.ClassVar.[int]',
                        'typing.ClassVar+',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is a ClassVar, so C() takes no args.
                C()

                # And it won't appear in the class's dict because it doesn't
                # have a default.
                self.assertNotIn('x', C.__dict__)

    def test_isnt_classvar(self):
        for typestr in ('CV',
                        't.ClassVar',
                        't.ClassVar[int]',
                        'typing..ClassVar[int]',
                        'Classvar',
                        'Classvar[int]',
                        'typing.ClassVarx[int]',
                        'typong.ClassVar[int]',
                        'dataclasses.ClassVar[int]',
                        'typingxClassVar[str]',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is not a ClassVar, so C() takes one arg.
                self.assertEqual(C(10).x, 10)

    def test_initvar(self):
        # These tests assume that both "import dataclasses" and "from
        # dataclasses import *" have been run in this file.
        # Note: every entry needs a trailing comma; otherwise adjacent
        # string literals are silently concatenated into one case.
        for typestr in ('InitVar[int]',
                        'InitVar [int]',
                        ' InitVar [int]',
                        'InitVar',
                        ' InitVar ',
                        'dataclasses.InitVar[int]',
                        'dataclasses.InitVar[str]',
                        ' dataclasses.InitVar[str]',
                        'dataclasses .InitVar[str]',
                        'dataclasses. InitVar[str]',
                        'dataclasses.InitVar [str]',
                        'dataclasses.InitVar [ str]',

                        # Not syntactically valid, but these will
                        # be treated as InitVars.
                        'dataclasses.InitVar.[int]',
                        'dataclasses.InitVar+',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is an InitVar, so doesn't create a member.
                with self.assertRaisesRegex(AttributeError,
                                            "object has no attribute 'x'"):
                    C(1).x

    def test_isnt_initvar(self):
        for typestr in ('IV',
                        'dc.InitVar',
                        'xdataclasses.xInitVar',
                        'typing.xInitVar[int]',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is not an InitVar, so there will be a member x.
                self.assertEqual(C(10).x, 10)

    def test_classvar_module_level_import(self):
        from test import dataclass_module_1
        from test import dataclass_module_1_str
        from test import dataclass_module_2
        from test import dataclass_module_2_str

        for m in (dataclass_module_1, dataclass_module_1_str,
                  dataclass_module_2, dataclass_module_2_str,
                  ):
            with self.subTest(m=m):
                # There's a difference in how the ClassVars are
                # interpreted when using string annotations or
                # not.  See the imported modules for details.
                if m.USING_STRINGS:
                    c = m.CV(10)
                else:
                    c = m.CV()
                self.assertEqual(c.cv0, 20)

                # There's a difference in how the InitVars are
                # interpreted when using string annotations or
                # not.  See the imported modules for details.
                c = m.IV(0, 1, 2, 3, 4)

                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
                    with self.subTest(field_name=field_name):
                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
                            # Since field_name is an InitVar, it's
                            # not an instance field.
                            getattr(c, field_name)

                if m.USING_STRINGS:
                    # iv4 is interpreted as a normal field.
                    self.assertIn('not_iv4', c.__dict__)
                    self.assertEqual(c.not_iv4, 4)
                else:
                    # iv4 is interpreted as an InitVar, so it
                    # won't exist on the instance.
                    self.assertNotIn('not_iv4', c.__dict__)

    def test_text_annotations(self):
        from test import dataclass_textanno

        self.assertEqual(
            get_type_hints(dataclass_textanno.Bar),
            {'foo': dataclass_textanno.Foo})
        self.assertEqual(
            get_type_hints(dataclass_textanno.Bar.__init__),
            {'foo': dataclass_textanno.Foo,
             'return': type(None)})
2940
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002941
class TestMakeDataclass(unittest.TestCase):
    """Tests for creating dataclasses programmatically via make_dataclass()."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # x is inherited from Base1, so one argument is not enough.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # The loop variables below are deliberately not called "field":
    # that name would shadow dataclasses.field inside these methods.

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3105
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The recursive-repr tests near the end of this class exercise __repr__
    rather than replace().  Their expected strings embed the qualname
    'TestReplace.<test name>.<locals>', so moving them to another class
    requires updating those strings as well.
    """

    def test(self):
        # replace() overrides the given field and copies the others.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be passed to replace(), either together
        # with init=True fields or on their own.  Each call needs its own
        # assertRaises block: any statement after the first raising call in
        # the same block would never execute.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        # An unknown name ends up rejected by the generated __init__.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # replace() needs a dataclass *instance*: neither the class itself
        # nor an unrelated object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        # Replacing a regular field still works, and the ClassVar remains
        # the single shared class attribute.
        c1 = replace(c, x=5)
        self.assertEqual(c1.x, 5)
        self.assertIs(c1.y, c.y)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # An InitVar without a default has no stored value to copy, so it
        # must be passed explicitly; __post_init__ then runs with it.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    "specified with replace()"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")

    # NOTE(review): disabled test.  As written, its final replace(c, x=5)
    # omits the InitVar 'y', which test_initvar_is_specified shows raises
    # ValueError; decide on the intended semantics before re-enabling.
    ## def test_initvar(self):
    ##     @dataclass
    ##     class C:
    ##         x: int
    ##         y: InitVar[int]

    ##     c = C(1, 10)
    ##     d = C(2, 20)

    ##     # In our case, replacing an InitVar is a no-op
    ##     self.assertEqual(c, replace(c, y=5))

    ##     replace(c, x=5)
3315
Eric V. Smith4e812962018-05-16 11:31:29 -04003316
# Allow running this module's tests directly as a script.
if __name__ == '__main__':
    unittest.main()