blob: b20103bdce51cb78d904a97007ee98b8a6853a00 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +03009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Yury Selivanovd219cc42019-12-09 09:54:20 -050013from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050015from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050016
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040017import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
18import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
19
class CustomError(Exception):
    """An arbitrary exception type that tests raise and then catch."""
22
23class TestCase(unittest.TestCase):
24 def test_no_fields(self):
25 @dataclass
26 class C:
27 pass
28
29 o = C()
30 self.assertEqual(len(fields(C)), 0)
31
Eric V. Smith56970b82018-03-22 16:28:48 -040032 def test_no_fields_but_member_variable(self):
33 @dataclass
34 class C:
35 i = 0
36
37 o = C()
38 self.assertEqual(len(fields(C)), 0)
39
Eric V. Smithf0db54a2017-12-04 16:58:55 -050040 def test_one_field_no_default(self):
41 @dataclass
42 class C:
43 x: int
44
45 o = C(42)
46 self.assertEqual(o.x, 42)
47
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053048 def test_field_default_default_factory_error(self):
49 msg = "cannot specify both default and default_factory"
50 with self.assertRaisesRegex(ValueError, msg):
51 @dataclass
52 class C:
53 x: int = field(default=1, default_factory=int)
54
55 def test_field_repr(self):
56 int_field = field(default=1, init=True, repr=False)
57 int_field.name = "id"
58 repr_output = repr(int_field)
59 expected_output = "Field(name='id',type=None," \
60 f"default=1,default_factory={MISSING!r}," \
61 "init=True,repr=False,hash=None," \
62 "compare=True,metadata=mappingproxy({})," \
63 "_field_type=None)"
64
65 self.assertEqual(repr_output, expected_output)
66
Eric V. Smithf0db54a2017-12-04 16:58:55 -050067 def test_named_init_params(self):
68 @dataclass
69 class C:
70 x: int
71
72 o = C(x=32)
73 self.assertEqual(o.x, 32)
74
75 def test_two_fields_one_default(self):
76 @dataclass
77 class C:
78 x: int
79 y: int = 0
80
81 o = C(3)
82 self.assertEqual((o.x, o.y), (3, 0))
83
84 # Non-defaults following defaults.
85 with self.assertRaisesRegex(TypeError,
86 "non-default argument 'y' follows "
87 "default argument"):
88 @dataclass
89 class C:
90 x: int = 0
91 y: int
92
93 # A derived class adds a non-default field after a default one.
94 with self.assertRaisesRegex(TypeError,
95 "non-default argument 'y' follows "
96 "default argument"):
97 @dataclass
98 class B:
99 x: int = 0
100
101 @dataclass
102 class C(B):
103 y: int
104
105 # Override a base class field and add a default to
106 # a field which didn't use to have a default.
107 with self.assertRaisesRegex(TypeError,
108 "non-default argument 'y' follows "
109 "default argument"):
110 @dataclass
111 class B:
112 x: int
113 y: int
114
115 @dataclass
116 class C(B):
117 x: int = 0
118
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500119 def test_overwrite_hash(self):
120 # Test that declaring this class isn't an error. It should
121 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500122 @dataclass(frozen=True)
123 class C:
124 x: int
125 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500126 return 301
127 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500128
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500129 # Test that declaring this class isn't an error. It should
130 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500131 @dataclass(frozen=True)
132 class C:
133 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500134 def __eq__(self, other):
135 return False
136 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500137
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500138 # But this one should generate an exception, because with
139 # unsafe_hash=True, it's an error to have a __hash__ defined.
140 with self.assertRaisesRegex(TypeError,
141 'Cannot overwrite attribute __hash__'):
142 @dataclass(unsafe_hash=True)
143 class C:
144 def __hash__(self):
145 pass
146
147 # Creating this class should not generate an exception,
148 # because even though __hash__ exists before @dataclass is
149 # called, (due to __eq__ being defined), since it's None
150 # that's okay.
151 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500152 class C:
153 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500154 def __eq__(self):
155 pass
156 # The generated hash function works as we'd expect.
157 self.assertEqual(hash(C(10)), hash((10,)))
158
159 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400160 # __hash__ exists and is not None, which it would be if it
161 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500162 with self.assertRaisesRegex(TypeError,
163 'Cannot overwrite attribute __hash__'):
164 @dataclass(unsafe_hash=True)
165 class C:
166 x: int
167 def __eq__(self):
168 pass
169 def __hash__(self):
170 pass
171
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500172 def test_overwrite_fields_in_derived_class(self):
173 # Note that x from C1 replaces x in Base, but the order remains
174 # the same as defined in Base.
175 @dataclass
176 class Base:
177 x: Any = 15.0
178 y: int = 0
179
180 @dataclass
181 class C1(Base):
182 z: int = 10
183 x: int = 15
184
185 o = Base()
186 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
187
188 o = C1()
189 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
190
191 o = C1(x=5)
192 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
193
194 def test_field_named_self(self):
195 @dataclass
196 class C:
197 self: str
198 c=C('foo')
199 self.assertEqual(c.self, 'foo')
200
201 # Make sure the first parameter is not named 'self'.
202 sig = inspect.signature(C.__init__)
203 first = next(iter(sig.parameters))
204 self.assertNotEqual('self', first)
205
206 # But we do use 'self' if no field named self.
207 @dataclass
208 class C:
209 selfx: str
210
211 # Make sure the first parameter is named 'self'.
212 sig = inspect.signature(C.__init__)
213 first = next(iter(sig.parameters))
214 self.assertEqual('self', first)
215
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300216 def test_field_named_object(self):
217 @dataclass
218 class C:
219 object: str
220 c = C('foo')
221 self.assertEqual(c.object, 'foo')
222
223 def test_field_named_object_frozen(self):
224 @dataclass(frozen=True)
225 class C:
226 object: str
227 c = C('foo')
228 self.assertEqual(c.object, 'foo')
229
230 def test_field_named_like_builtin(self):
231 # Attribute names can shadow built-in names
232 # since code generation is used.
233 # Ensure that this is not happening.
234 exclusions = {'None', 'True', 'False'}
235 builtins_names = sorted(
236 b for b in builtins.__dict__.keys()
237 if not b.startswith('__') and b not in exclusions
238 )
239 attributes = [(name, str) for name in builtins_names]
240 C = make_dataclass('C', attributes)
241
242 c = C(*[name for name in builtins_names])
243
244 for name in builtins_names:
245 self.assertEqual(getattr(c, name), name)
246
247 def test_field_named_like_builtin_frozen(self):
248 # Attribute names can shadow built-in names
249 # since code generation is used.
250 # Ensure that this is not happening
251 # for frozen data classes.
252 exclusions = {'None', 'True', 'False'}
253 builtins_names = sorted(
254 b for b in builtins.__dict__.keys()
255 if not b.startswith('__') and b not in exclusions
256 )
257 attributes = [(name, str) for name in builtins_names]
258 C = make_dataclass('C', attributes, frozen=True)
259
260 c = C(*[name for name in builtins_names])
261
262 for name in builtins_names:
263 self.assertEqual(getattr(c, name), name)
264
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500265 def test_0_field_compare(self):
266 # Ensure that order=False is the default.
267 @dataclass
268 class C0:
269 pass
270
271 @dataclass(order=False)
272 class C1:
273 pass
274
275 for cls in [C0, C1]:
276 with self.subTest(cls=cls):
277 self.assertEqual(cls(), cls())
278 for idx, fn in enumerate([lambda a, b: a < b,
279 lambda a, b: a <= b,
280 lambda a, b: a > b,
281 lambda a, b: a >= b]):
282 with self.subTest(idx=idx):
283 with self.assertRaisesRegex(TypeError,
284 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
285 fn(cls(), cls())
286
287 @dataclass(order=True)
288 class C:
289 pass
290 self.assertLessEqual(C(), C())
291 self.assertGreaterEqual(C(), C())
292
293 def test_1_field_compare(self):
294 # Ensure that order=False is the default.
295 @dataclass
296 class C0:
297 x: int
298
299 @dataclass(order=False)
300 class C1:
301 x: int
302
303 for cls in [C0, C1]:
304 with self.subTest(cls=cls):
305 self.assertEqual(cls(1), cls(1))
306 self.assertNotEqual(cls(0), cls(1))
307 for idx, fn in enumerate([lambda a, b: a < b,
308 lambda a, b: a <= b,
309 lambda a, b: a > b,
310 lambda a, b: a >= b]):
311 with self.subTest(idx=idx):
312 with self.assertRaisesRegex(TypeError,
313 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
314 fn(cls(0), cls(0))
315
316 @dataclass(order=True)
317 class C:
318 x: int
319 self.assertLess(C(0), C(1))
320 self.assertLessEqual(C(0), C(1))
321 self.assertLessEqual(C(1), C(1))
322 self.assertGreater(C(1), C(0))
323 self.assertGreaterEqual(C(1), C(0))
324 self.assertGreaterEqual(C(1), C(1))
325
326 def test_simple_compare(self):
327 # Ensure that order=False is the default.
328 @dataclass
329 class C0:
330 x: int
331 y: int
332
333 @dataclass(order=False)
334 class C1:
335 x: int
336 y: int
337
338 for cls in [C0, C1]:
339 with self.subTest(cls=cls):
340 self.assertEqual(cls(0, 0), cls(0, 0))
341 self.assertEqual(cls(1, 2), cls(1, 2))
342 self.assertNotEqual(cls(1, 0), cls(0, 0))
343 self.assertNotEqual(cls(1, 0), cls(1, 1))
344 for idx, fn in enumerate([lambda a, b: a < b,
345 lambda a, b: a <= b,
346 lambda a, b: a > b,
347 lambda a, b: a >= b]):
348 with self.subTest(idx=idx):
349 with self.assertRaisesRegex(TypeError,
350 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
351 fn(cls(0, 0), cls(0, 0))
352
353 @dataclass(order=True)
354 class C:
355 x: int
356 y: int
357
358 for idx, fn in enumerate([lambda a, b: a == b,
359 lambda a, b: a <= b,
360 lambda a, b: a >= b]):
361 with self.subTest(idx=idx):
362 self.assertTrue(fn(C(0, 0), C(0, 0)))
363
364 for idx, fn in enumerate([lambda a, b: a < b,
365 lambda a, b: a <= b,
366 lambda a, b: a != b]):
367 with self.subTest(idx=idx):
368 self.assertTrue(fn(C(0, 0), C(0, 1)))
369 self.assertTrue(fn(C(0, 1), C(1, 0)))
370 self.assertTrue(fn(C(1, 0), C(1, 1)))
371
372 for idx, fn in enumerate([lambda a, b: a > b,
373 lambda a, b: a >= b,
374 lambda a, b: a != b]):
375 with self.subTest(idx=idx):
376 self.assertTrue(fn(C(0, 1), C(0, 0)))
377 self.assertTrue(fn(C(1, 0), C(0, 1)))
378 self.assertTrue(fn(C(1, 1), C(1, 0)))
379
380 def test_compare_subclasses(self):
381 # Comparisons fail for subclasses, even if no fields
382 # are added.
383 @dataclass
384 class B:
385 i: int
386
387 @dataclass
388 class C(B):
389 pass
390
391 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
392 (lambda a, b: a != b, True)]):
393 with self.subTest(idx=idx):
394 self.assertEqual(fn(B(0), C(0)), expected)
395
396 for idx, fn in enumerate([lambda a, b: a < b,
397 lambda a, b: a <= b,
398 lambda a, b: a > b,
399 lambda a, b: a >= b]):
400 with self.subTest(idx=idx):
401 with self.assertRaisesRegex(TypeError,
402 "not supported between instances of 'B' and 'C'"):
403 fn(B(0), C(0))
404
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500405 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500406 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500407 for (eq, order, result ) in [
408 (False, False, 'neither'),
409 (False, True, 'exception'),
410 (True, False, 'eq_only'),
411 (True, True, 'both'),
412 ]:
413 with self.subTest(eq=eq, order=order):
414 if result == 'exception':
415 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
416 @dataclass(eq=eq, order=order)
417 class C:
418 pass
419 else:
420 @dataclass(eq=eq, order=order)
421 class C:
422 pass
423
424 if result == 'neither':
425 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500426 self.assertNotIn('__lt__', C.__dict__)
427 self.assertNotIn('__le__', C.__dict__)
428 self.assertNotIn('__gt__', C.__dict__)
429 self.assertNotIn('__ge__', C.__dict__)
430 elif result == 'both':
431 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500432 self.assertIn('__lt__', C.__dict__)
433 self.assertIn('__le__', C.__dict__)
434 self.assertIn('__gt__', C.__dict__)
435 self.assertIn('__ge__', C.__dict__)
436 elif result == 'eq_only':
437 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500438 self.assertNotIn('__lt__', C.__dict__)
439 self.assertNotIn('__le__', C.__dict__)
440 self.assertNotIn('__gt__', C.__dict__)
441 self.assertNotIn('__ge__', C.__dict__)
442 else:
443 assert False, f'unknown result {result!r}'
444
445 def test_field_no_default(self):
446 @dataclass
447 class C:
448 x: int = field()
449
450 self.assertEqual(C(5).x, 5)
451
452 with self.assertRaisesRegex(TypeError,
453 r"__init__\(\) missing 1 required "
454 "positional argument: 'x'"):
455 C()
456
457 def test_field_default(self):
458 default = object()
459 @dataclass
460 class C:
461 x: object = field(default=default)
462
463 self.assertIs(C.x, default)
464 c = C(10)
465 self.assertEqual(c.x, 10)
466
467 # If we delete the instance attribute, we should then see the
468 # class attribute.
469 del c.x
470 self.assertIs(c.x, default)
471
472 self.assertIs(C().x, default)
473
474 def test_not_in_repr(self):
475 @dataclass
476 class C:
477 x: int = field(repr=False)
478 with self.assertRaises(TypeError):
479 C()
480 c = C(10)
481 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
482
483 @dataclass
484 class C:
485 x: int = field(repr=False)
486 y: int
487 c = C(10, 20)
488 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
489
490 def test_not_in_compare(self):
491 @dataclass
492 class C:
493 x: int = 0
494 y: int = field(compare=False, default=4)
495
496 self.assertEqual(C(), C(0, 20))
497 self.assertEqual(C(1, 10), C(1, 20))
498 self.assertNotEqual(C(3), C(4, 10))
499 self.assertNotEqual(C(3, 10), C(4, 10))
500
501 def test_hash_field_rules(self):
502 # Test all 6 cases of:
503 # hash=True/False/None
504 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500505 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500506 (True, False, 'field' ),
507 (True, True, 'field' ),
508 (False, False, 'absent'),
509 (False, True, 'absent'),
510 (None, False, 'absent'),
511 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500512 ]:
513 with self.subTest(hash=hash_, compare=compare):
514 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500515 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500516 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500517
518 if result == 'field':
519 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500520 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500521 elif result == 'absent':
522 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500523 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500524 else:
525 assert False, f'unknown result {result!r}'
526
527 def test_init_false_no_default(self):
528 # If init=False and no default value, then the field won't be
529 # present in the instance.
530 @dataclass
531 class C:
532 x: int = field(init=False)
533
534 self.assertNotIn('x', C().__dict__)
535
536 @dataclass
537 class C:
538 x: int
539 y: int = 0
540 z: int = field(init=False)
541 t: int = 10
542
543 self.assertNotIn('z', C(0).__dict__)
544 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
545
546 def test_class_marker(self):
547 @dataclass
548 class C:
549 x: int
550 y: str = field(init=False, default=None)
551 z: str = field(repr=False)
552
553 the_fields = fields(C)
554 # the_fields is a tuple of 3 items, each value
555 # is in __annotations__.
556 self.assertIsInstance(the_fields, tuple)
557 for f in the_fields:
558 self.assertIs(type(f), Field)
559 self.assertIn(f.name, C.__annotations__)
560
561 self.assertEqual(len(the_fields), 3)
562
563 self.assertEqual(the_fields[0].name, 'x')
564 self.assertEqual(the_fields[0].type, int)
565 self.assertFalse(hasattr(C, 'x'))
566 self.assertTrue (the_fields[0].init)
567 self.assertTrue (the_fields[0].repr)
568 self.assertEqual(the_fields[1].name, 'y')
569 self.assertEqual(the_fields[1].type, str)
570 self.assertIsNone(getattr(C, 'y'))
571 self.assertFalse(the_fields[1].init)
572 self.assertTrue (the_fields[1].repr)
573 self.assertEqual(the_fields[2].name, 'z')
574 self.assertEqual(the_fields[2].type, str)
575 self.assertFalse(hasattr(C, 'z'))
576 self.assertTrue (the_fields[2].init)
577 self.assertFalse(the_fields[2].repr)
578
579 def test_field_order(self):
580 @dataclass
581 class B:
582 a: str = 'B:a'
583 b: str = 'B:b'
584 c: str = 'B:c'
585
586 @dataclass
587 class C(B):
588 b: str = 'C:b'
589
590 self.assertEqual([(f.name, f.default) for f in fields(C)],
591 [('a', 'B:a'),
592 ('b', 'C:b'),
593 ('c', 'B:c')])
594
595 @dataclass
596 class D(B):
597 c: str = 'D:c'
598
599 self.assertEqual([(f.name, f.default) for f in fields(D)],
600 [('a', 'B:a'),
601 ('b', 'B:b'),
602 ('c', 'D:c')])
603
604 @dataclass
605 class E(D):
606 a: str = 'E:a'
607 d: str = 'E:d'
608
609 self.assertEqual([(f.name, f.default) for f in fields(E)],
610 [('a', 'E:a'),
611 ('b', 'B:b'),
612 ('c', 'D:c'),
613 ('d', 'E:d')])
614
615 def test_class_attrs(self):
616 # We only have a class attribute if a default value is
617 # specified, either directly or via a field with a default.
618 default = object()
619 @dataclass
620 class C:
621 x: int
622 y: int = field(repr=False)
623 z: object = default
624 t: int = field(default=100)
625
626 self.assertFalse(hasattr(C, 'x'))
627 self.assertFalse(hasattr(C, 'y'))
628 self.assertIs (C.z, default)
629 self.assertEqual(C.t, 100)
630
631 def test_disallowed_mutable_defaults(self):
632 # For the known types, don't allow mutable default values.
633 for typ, empty, non_empty in [(list, [], [1]),
634 (dict, {}, {0:1}),
635 (set, set(), set([1])),
636 ]:
637 with self.subTest(typ=typ):
638 # Can't use a zero-length value.
639 with self.assertRaisesRegex(ValueError,
640 f'mutable default {typ} for field '
641 'x is not allowed'):
642 @dataclass
643 class Point:
644 x: typ = empty
645
646
647 # Nor a non-zero-length value
648 with self.assertRaisesRegex(ValueError,
649 f'mutable default {typ} for field '
650 'y is not allowed'):
651 @dataclass
652 class Point:
653 y: typ = non_empty
654
655 # Check subtypes also fail.
656 class Subclass(typ): pass
657
658 with self.assertRaisesRegex(ValueError,
659 f"mutable default .*Subclass'>"
660 ' for field z is not allowed'
661 ):
662 @dataclass
663 class Point:
664 z: typ = Subclass()
665
666 # Because this is a ClassVar, it can be mutable.
667 @dataclass
668 class C:
669 z: ClassVar[typ] = typ()
670
671 # Because this is a ClassVar, it can be mutable.
672 @dataclass
673 class C:
674 x: ClassVar[typ] = Subclass()
675
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500676 def test_deliberately_mutable_defaults(self):
677 # If a mutable default isn't in the known list of
678 # (list, dict, set), then it's okay.
679 class Mutable:
680 def __init__(self):
681 self.l = []
682
683 @dataclass
684 class C:
685 x: Mutable
686
687 # These 2 instances will share this value of x.
688 lst = Mutable()
689 o1 = C(lst)
690 o2 = C(lst)
691 self.assertEqual(o1, o2)
692 o1.x.l.extend([1, 2])
693 self.assertEqual(o1, o2)
694 self.assertEqual(o1.x.l, [1, 2])
695 self.assertIs(o1.x, o2.x)
696
697 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400698 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500699 @dataclass()
700 class C:
701 x: int
702
703 self.assertEqual(C(42).x, 42)
704
705 def test_not_tuple(self):
706 # Make sure we can't be compared to a tuple.
707 @dataclass
708 class Point:
709 x: int
710 y: int
711 self.assertNotEqual(Point(1, 2), (1, 2))
712
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400713 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500714 @dataclass
715 class C:
716 x: int
717 y: int
718 self.assertNotEqual(Point(1, 3), C(1, 3))
719
Windson yangbe372d72019-04-23 02:45:34 +0800720 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500721 # Test that some of the problems with namedtuple don't happen
722 # here.
723 @dataclass
724 class Point3D:
725 x: int
726 y: int
727 z: int
728
729 @dataclass
730 class Date:
731 year: int
732 month: int
733 day: int
734
735 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
736 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
737
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400738 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200739 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500740 x, y, z = Point3D(4, 5, 6)
741
Eric V. Smith7c99e932018-01-28 19:18:55 -0500742 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500743 # equal.
744 @dataclass
745 class Point3Dv1:
746 x: int = 0
747 y: int = 0
748 z: int = 0
749 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
750
751 def test_function_annotations(self):
752 # Some dummy class and instance to use as a default.
753 class F:
754 pass
755 f = F()
756
757 def validate_class(cls):
758 # First, check __annotations__, even though they're not
759 # function annotations.
760 self.assertEqual(cls.__annotations__['i'], int)
761 self.assertEqual(cls.__annotations__['j'], str)
762 self.assertEqual(cls.__annotations__['k'], F)
763 self.assertEqual(cls.__annotations__['l'], float)
764 self.assertEqual(cls.__annotations__['z'], complex)
765
766 # Verify __init__.
767
768 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400769 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500770 self.assertIs(signature.return_annotation, None)
771
772 # Check each parameter.
773 params = iter(signature.parameters.values())
774 param = next(params)
775 # This is testing an internal name, and probably shouldn't be tested.
776 self.assertEqual(param.name, 'self')
777 param = next(params)
778 self.assertEqual(param.name, 'i')
779 self.assertIs (param.annotation, int)
780 self.assertEqual(param.default, inspect.Parameter.empty)
781 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
782 param = next(params)
783 self.assertEqual(param.name, 'j')
784 self.assertIs (param.annotation, str)
785 self.assertEqual(param.default, inspect.Parameter.empty)
786 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
787 param = next(params)
788 self.assertEqual(param.name, 'k')
789 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400790 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500791 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
792 param = next(params)
793 self.assertEqual(param.name, 'l')
794 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400795 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500796 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
797 self.assertRaises(StopIteration, next, params)
798
799
800 @dataclass
801 class C:
802 i: int
803 j: str
804 k: F = f
805 l: float=field(default=None)
806 z: complex=field(default=3+4j, init=False)
807
808 validate_class(C)
809
810 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500811 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500812 class C:
813 i: int
814 j: str
815 k: F = f
816 l: float=field(default=None)
817 z: complex=field(default=3+4j, init=False)
818
819 validate_class(C)
820
Eric V. Smith03220fd2017-12-29 13:59:58 -0500821 def test_missing_default(self):
822 # Test that MISSING works the same as a default not being
823 # specified.
824 @dataclass
825 class C:
826 x: int=field(default=MISSING)
827 with self.assertRaisesRegex(TypeError,
828 r'__init__\(\) missing 1 required '
829 'positional argument'):
830 C()
831 self.assertNotIn('x', C.__dict__)
832
833 @dataclass
834 class D:
835 x: int
836 with self.assertRaisesRegex(TypeError,
837 r'__init__\(\) missing 1 required '
838 'positional argument'):
839 D()
840 self.assertNotIn('x', D.__dict__)
841
842 def test_missing_default_factory(self):
843 # Test that MISSING works the same as a default factory not
844 # being specified (which is really the same as a default not
845 # being specified, too).
846 @dataclass
847 class C:
848 x: int=field(default_factory=MISSING)
849 with self.assertRaisesRegex(TypeError,
850 r'__init__\(\) missing 1 required '
851 'positional argument'):
852 C()
853 self.assertNotIn('x', C.__dict__)
854
855 @dataclass
856 class D:
857 x: int=field(default=MISSING, default_factory=MISSING)
858 with self.assertRaisesRegex(TypeError,
859 r'__init__\(\) missing 1 required '
860 'positional argument'):
861 D()
862 self.assertNotIn('x', D.__dict__)
863
864 def test_missing_repr(self):
865 self.assertIn('MISSING_TYPE object', repr(MISSING))
866
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500867 def test_dont_include_other_annotations(self):
868 @dataclass
869 class C:
870 i: int
871 def foo(self) -> int:
872 return 4
873 @property
874 def bar(self) -> int:
875 return 5
876 self.assertEqual(list(C.__annotations__), ['i'])
877 self.assertEqual(C(10).foo(), 4)
878 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400879 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500880
881 def test_post_init(self):
882 # Just make sure it gets called
883 @dataclass
884 class C:
885 def __post_init__(self):
886 raise CustomError()
887 with self.assertRaises(CustomError):
888 C()
889
890 @dataclass
891 class C:
892 i: int = 10
893 def __post_init__(self):
894 if self.i == 10:
895 raise CustomError()
896 with self.assertRaises(CustomError):
897 C()
898 # post-init gets called, but doesn't raise. This is just
899 # checking that self is used correctly.
900 C(5)
901
902 # If there's not an __init__, then post-init won't get called.
903 @dataclass(init=False)
904 class C:
905 def __post_init__(self):
906 raise CustomError()
907 # Creating the class won't raise
908 C()
909
910 @dataclass
911 class C:
912 x: int = 0
913 def __post_init__(self):
914 self.x *= 2
915 self.assertEqual(C().x, 0)
916 self.assertEqual(C(2).x, 4)
917
Mike53f7a7c2017-12-14 14:04:53 +0300918 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500919 # attributes.
920 @dataclass(frozen=True)
921 class C:
922 x: int = 0
923 def __post_init__(self):
924 self.x *= 2
925 with self.assertRaises(FrozenInstanceError):
926 C()
927
928 def test_post_init_super(self):
929 # Make sure super() post-init isn't called by default.
930 class B:
931 def __post_init__(self):
932 raise CustomError()
933
934 @dataclass
935 class C(B):
936 def __post_init__(self):
937 self.x = 5
938
939 self.assertEqual(C().x, 5)
940
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400941 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500942 @dataclass
943 class C(B):
944 def __post_init__(self):
945 super().__post_init__()
946
947 with self.assertRaises(CustomError):
948 C()
949
950 # Make sure post-init is called, even if not defined in our
951 # class.
952 @dataclass
953 class C(B):
954 pass
955
956 with self.assertRaises(CustomError):
957 C()
958
959 def test_post_init_staticmethod(self):
960 flag = False
961 @dataclass
962 class C:
963 x: int
964 y: int
965 @staticmethod
966 def __post_init__():
967 nonlocal flag
968 flag = True
969
970 self.assertFalse(flag)
971 c = C(3, 4)
972 self.assertEqual((c.x, c.y), (3, 4))
973 self.assertTrue(flag)
974
975 def test_post_init_classmethod(self):
976 @dataclass
977 class C:
978 flag = False
979 x: int
980 y: int
981 @classmethod
982 def __post_init__(cls):
983 cls.flag = True
984
985 self.assertFalse(C.flag)
986 c = C(3, 4)
987 self.assertEqual((c.x, c.y), (3, 4))
988 self.assertTrue(C.flag)
989
990 def test_class_var(self):
991 # Make sure ClassVars are ignored in __init__, __repr__, etc.
992 @dataclass
993 class C:
994 x: int
995 y: int = 10
996 z: ClassVar[int] = 1000
997 w: ClassVar[int] = 2000
998 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400999 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001000
1001 c = C(5)
1002 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001003 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001004 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001005 self.assertEqual(c.z, 1000)
1006 self.assertEqual(c.w, 2000)
1007 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001008 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001009 C.z += 1
1010 self.assertEqual(c.z, 1001)
1011 c = C(20)
1012 self.assertEqual((c.x, c.y), (20, 10))
1013 self.assertEqual(c.z, 1001)
1014 self.assertEqual(c.w, 2000)
1015 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001016 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001017
1018 def test_class_var_no_default(self):
1019 # If a ClassVar has no default value, it should not be set on the class.
1020 @dataclass
1021 class C:
1022 x: ClassVar[int]
1023
1024 self.assertNotIn('x', C.__dict__)
1025
1026 def test_class_var_default_factory(self):
1027 # It makes no sense for a ClassVar to have a default factory. When
1028 # would it be called? Call it yourself, since it's class-wide.
1029 with self.assertRaisesRegex(TypeError,
1030 'cannot have a default factory'):
1031 @dataclass
1032 class C:
1033 x: ClassVar[int] = field(default_factory=int)
1034
1035 self.assertNotIn('x', C.__dict__)
1036
1037 def test_class_var_with_default(self):
1038 # If a ClassVar has a default value, it should be set on the class.
1039 @dataclass
1040 class C:
1041 x: ClassVar[int] = 10
1042 self.assertEqual(C.x, 10)
1043
1044 @dataclass
1045 class C:
1046 x: ClassVar[int] = field(default=10)
1047 self.assertEqual(C.x, 10)
1048
def test_class_var_frozen(self):
    # Make sure ClassVars work even if we're frozen.
    # ClassVars are excluded from fields() and repr(), but stay readable
    # through instances and remain writable on the class itself.
    @dataclass(frozen=True)
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    c = C(5)
    self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)                 # We have 2 fields
    self.assertEqual(len(C.__annotations__), 5)         # And 3 ClassVars
    self.assertEqual(c.z, 1000)
    self.assertEqual(c.w, 2000)
    self.assertEqual(c.t, 3000)
    # We can still modify the ClassVar, it's only instances that are
    # frozen.
    C.z += 1
    self.assertEqual(c.z, 1001)
    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    self.assertEqual(c.z, 1001)
    self.assertEqual(c.w, 2000)
    self.assertEqual(c.t, 3000)
1075
def test_init_var_no_default(self):
    # If an InitVar has no default value, it should not be set on the class.
    # Only the annotation exists, so 'x' must be absent from C.__dict__.
    @dataclass
    class C:
        x: InitVar[int]

    self.assertNotIn('x', C.__dict__)
1083
1084 def test_init_var_default_factory(self):
1085 # It makes no sense for an InitVar to have a default factory. When
1086 # would it be called? Call it yourself, since it's class-wide.
1087 with self.assertRaisesRegex(TypeError,
1088 'cannot have a default factory'):
1089 @dataclass
1090 class C:
1091 x: InitVar[int] = field(default_factory=int)
1092
1093 self.assertNotIn('x', C.__dict__)
1094
def test_init_var_with_default(self):
    # If an InitVar has a default value, it should be set on the class.
    # Checked both for a plain default and for field(default=...).
    @dataclass
    class C:
        x: InitVar[int] = 10
    self.assertEqual(C.x, 10)

    @dataclass
    class C:
        x: InitVar[int] = field(default=10)
    self.assertEqual(C.x, 10)
1106
def test_init_var(self):
    # An InitVar value is handed to __post_init__; here it is used to
    # compute the real field x when x was not supplied.
    @dataclass
    class C:
        x: int = None
        init_param: InitVar[int] = None

        def __post_init__(self, init_param):
            if self.x is None:
                self.x = init_param*2

    c = C(init_param=10)
    self.assertEqual(c.x, 20)
1119
Augusto Hack01ee12b2019-06-02 23:14:48 -03001120 def test_init_var_preserve_type(self):
1121 self.assertEqual(InitVar[int].type, int)
1122
1123 # Make sure the repr is correct.
1124 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001125 self.assertEqual(repr(InitVar[List[int]]),
1126 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001127
def test_init_var_inheritance(self):
    # Note that this deliberately tests that a dataclass need not
    # have a __post_init__ function if it has an InitVar field.
    # It could just be used in a derived class, as shown here.
    @dataclass
    class Base:
        x: int
        init_base: InitVar[int]

    # We can instantiate by passing the InitVar, even though
    # it's not used. It is not stored on the instance.
    b = Base(0, 10)
    self.assertEqual(vars(b), {'x': 0})

    @dataclass
    class C(Base):
        y: int
        init_derived: InitVar[int]

        def __post_init__(self, init_base, init_derived):
            self.x = self.x + init_base
            self.y = self.y + init_derived

    # Positional order is x, init_base, y, init_derived; the derived
    # __post_init__ receives the InitVars of the whole hierarchy.
    c = C(10, 11, 50, 51)
    self.assertEqual(vars(c), {'x': 21, 'y': 101})
1153
def test_default_factory(self):
    # Test a factory that returns a new list.
    # Each instance must get its own list object.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    c0 = C(3)
    c1 = C(3)
    self.assertEqual(c0.x, 3)
    self.assertEqual(c0.y, [])
    self.assertEqual(c0, c1)
    self.assertIsNot(c0.y, c1.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test a factory that returns a shared list.
    # Here every instance ends up holding the very same list object.
    l = []
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: l)

    c0 = C(3)
    c1 = C(3)
    self.assertEqual(c0.x, 3)
    self.assertEqual(c0.y, [])
    self.assertEqual(c0, c1)
    self.assertIs(c0.y, c1.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test various other field flags.
    # repr
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash
    # With hash=False the only field is excluded, so the hash equals
    # that of an empty tuple.
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init (see also test_default_factory_with_no_init)
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare
    # With compare=False, x is ignored by __eq__.
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1210
def test_default_factory_with_no_init(self):
    # We need a factory with a side effect.
    # Mock counts its calls, which makes the side effect observable.
    factory = Mock()

    @dataclass
    class C:
        x: list = field(default_factory=factory, init=False)

    # Make sure the default factory is called for each new instance,
    # even though x is not an __init__ parameter.
    C().x
    self.assertEqual(factory.call_count, 1)
    C().x
    self.assertEqual(factory.call_count, 2)
1224
def test_default_factory_not_called_if_value_given(self):
    # We need a factory that we can test if it's been called.
    # Mock counts its calls.
    factory = Mock()

    @dataclass
    class C:
        x: int = field(default_factory=factory)

    # Make sure that if a field has a default factory function,
    # it's not called if a value is specified.
    C().x
    self.assertEqual(factory.call_count, 1)
    self.assertEqual(C(10).x, 10)
    self.assertEqual(factory.call_count, 1)
    C().x
    self.assertEqual(factory.call_count, 2)
1241
def test_default_factory_derived(self):
    # See bpo-32896.
    # A default_factory declared in a base dataclass must still be used by
    # derived dataclasses, whether or not they add fields of their own.
    @dataclass
    class Foo:
        x: dict = field(default_factory=dict)

    @dataclass
    class Bar(Foo):
        y: int = 1

    self.assertEqual(Foo().x, {})
    self.assertEqual(Bar().x, {})
    self.assertEqual(Bar().y, 1)

    @dataclass
    class Baz(Foo):
        pass
    self.assertEqual(Baz().x, {})
1260
def test_intermediate_non_dataclass(self):
    # Test that an intermediate class that defines
    # annotations does not define fields.

    @dataclass
    class A:
        x: int

    # Not a dataclass: its annotation y is not collected as a field.
    class B(A):
        y: int

    @dataclass
    class C(B):
        z: int

    c = C(1, 3)
    self.assertEqual((c.x, c.z), (1, 3))

    # .y was not initialized.
    with self.assertRaisesRegex(AttributeError,
                                'object has no attribute'):
        c.y

    # And if we again derive a non-dataclass, no fields are added.
    class D(C):
        t: int
    d = D(4, 5)
    self.assertEqual((d.x, d.z), (4, 5))
1289
def test_classvar_default_factory(self):
    # It's an error for a ClassVar to have a factory function.
    # The decorator raises while processing C, so C is never bound.
    with self.assertRaisesRegex(TypeError,
                                'cannot have a default factory'):
        @dataclass
        class C:
            x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001297
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001298 def test_is_dataclass(self):
1299 class NotDataClass:
1300 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001301
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001302 self.assertFalse(is_dataclass(0))
1303 self.assertFalse(is_dataclass(int))
1304 self.assertFalse(is_dataclass(NotDataClass))
1305 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001306
1307 @dataclass
1308 class C:
1309 x: int
1310
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001311 @dataclass
1312 class D:
1313 d: C
1314 e: int
1315
1316 c = C(10)
1317 d = D(c, 4)
1318
1319 self.assertTrue(is_dataclass(C))
1320 self.assertTrue(is_dataclass(c))
1321 self.assertFalse(is_dataclass(c.x))
1322 self.assertTrue(is_dataclass(d.d))
1323 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001324
def test_is_dataclass_when_getattr_always_returns(self):
    # See bpo-37868.
    # A __getattr__ that answers every lookup must not make a class or
    # its instances look like dataclasses.
    class A:
        def __getattr__(self, key):
            return 0
    self.assertFalse(is_dataclass(A))
    a = A()

    # Also test for an instance attribute.
    # __dataclass_fields__ is set on the instance only, not on its class.
    class B:
        pass
    b = B()
    b.__dataclass_fields__ = []

    for obj in a, b:
        with self.subTest(obj=obj):
            self.assertFalse(is_dataclass(obj))

            # Indirect tests for _is_dataclass_instance().
            with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
                asdict(obj)
            with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
                astuple(obj)
            with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
                replace(obj, x=0)
1350
def test_helper_fields_with_class_instance(self):
    # Check that we can call fields() on either a class or instance,
    # and get back the same thing.
    @dataclass
    class C:
        x: int
        y: float

    self.assertEqual(fields(C), fields(C(0, 0.0)))
1360
1361 def test_helper_fields_exception(self):
1362 # Check that TypeError is raised if not passed a dataclass or
1363 # instance.
1364 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1365 fields(0)
1366
1367 class C: pass
1368 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1369 fields(C)
1370 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1371 fields(C())
1372
def test_helper_asdict(self):
    # Basic tests for asdict(), it should return a new dictionary.
    @dataclass
    class C:
        x: int
        y: int
    c = C(1, 2)

    self.assertEqual(asdict(c), {'x': 1, 'y': 2})
    self.assertEqual(asdict(c), asdict(c))
    self.assertIsNot(asdict(c), asdict(c))
    # The result reflects the instance's current values.
    c.x = 42
    self.assertEqual(asdict(c), {'x': 42, 'y': 2})
    self.assertIs(type(asdict(c)), dict)
1387
1388 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001389 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001390 @dataclass
1391 class C:
1392 x: int
1393 y: int
1394 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1395 asdict(C)
1396 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1397 asdict(int)
1398
def test_helper_asdict_copy_values(self):
    # asdict() must copy mutable field values rather than share them.
    @dataclass
    class C:
        x: int
        y: List[int] = field(default_factory=list)
    initial = []
    c = C(1, initial)
    d = asdict(c)
    self.assertEqual(d['y'], initial)
    self.assertIsNot(d['y'], initial)
    c = C(1)
    d = asdict(c)
    # Mutating the returned list must not affect the instance.
    d['y'].append(1)
    self.assertEqual(c.y, [])
1413
def test_helper_asdict_nested(self):
    # A dataclass-valued field is itself converted to a dict.
    @dataclass
    class UserId:
        token: int
        group: int
    @dataclass
    class User:
        name: str
        id: UserId
    u = User('Joe', UserId(123, 1))
    d = asdict(u)
    self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
    self.assertIsNot(asdict(u), asdict(u))
    # A change to the nested instance shows up in a fresh asdict() call.
    u.id.group = 2
    self.assertEqual(asdict(u), {'name': 'Joe',
                                 'id': {'token': 123, 'group': 2}})
1430
def test_helper_asdict_builtin_containers(self):
    # Dataclasses held inside list, tuple and dict values are converted
    # too, and each container keeps its own kind (list vs tuple vs dict).
    @dataclass
    class User:
        name: str
        id: int
    @dataclass
    class GroupList:
        id: int
        users: List[User]
    @dataclass
    class GroupTuple:
        id: int
        users: Tuple[User, ...]
    @dataclass
    class GroupDict:
        id: int
        users: Dict[str, User]
    a = User('Alice', 1)
    b = User('Bob', 2)
    gl = GroupList(0, [a, b])
    gt = GroupTuple(0, (a, b))
    gd = GroupDict(0, {'first': a, 'second': b})
    self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
                                                     {'name': 'Bob', 'id': 2}]})
    self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
                                                     {'name': 'Bob', 'id': 2})})
    self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
                                                     'second': {'name': 'Bob', 'id': 2}}})
1459
def test_helper_asdict_builtin_object_containers(self):
    # Containers stored in a field annotated as plain `object` (a list
    # and a dict here) come through asdict() unchanged in value.
    @dataclass
    class Child:
        d: object

    @dataclass
    class Parent:
        child: Child

    self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
    self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1471
def test_helper_asdict_factory(self):
    # dict_factory controls the type of mapping that asdict() builds.
    @dataclass
    class C:
        x: int
        y: int
    c = C(1, 2)
    d = asdict(c, dict_factory=OrderedDict)
    self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
    self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
    c.x = 42
    d = asdict(c, dict_factory=OrderedDict)
    self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
    self.assertIs(type(d), OrderedDict)
1485
def test_helper_asdict_namedtuple(self):
    # Dataclasses nested inside a namedtuple field are converted to
    # dicts as well.
    T = namedtuple('T', 'a b c')
    @dataclass
    class C:
        x: str
        y: T
    c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))

    d = asdict(c)
    self.assertEqual(d, {'x': 'outer',
                         'y': T(1,
                                {'x': 'inner',
                                 'y': T(11, 12, 13)},
                                2),
                         }
                     )

    # Now with a dict_factory.  OrderedDict is convenient, but
    # since it compares to dicts, we also need to have separate
    # assertIs tests.
    d = asdict(c, dict_factory=OrderedDict)
    self.assertEqual(d, {'x': 'outer',
                         'y': T(1,
                                {'x': 'inner',
                                 'y': T(11, 12, 13)},
                                2),
                         }
                     )

    # Make sure that the returned dicts are actually OrderedDicts,
    # including the one built for the dataclass inside the namedtuple.
    self.assertIs(type(d), OrderedDict)
    self.assertIs(type(d['y'][1]), OrderedDict)
1518
def test_helper_asdict_namedtuple_key(self):
    # Ensure that a field that contains a dict which has a
    # namedtuple as a key works with asdict().

    @dataclass
    class C:
        f: dict
    T = namedtuple('T', 'a')

    c = C({T('an a'): 0})

    self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1531
def test_helper_asdict_namedtuple_derived(self):
    # A subclass of a namedtuple that adds a method must come out of
    # asdict() as an instance on which that method still works.
    class T(namedtuple('Tbase', 'a')):
        def my_a(self):
            return self.a

    @dataclass
    class C:
        f: T

    t = T(6)
    c = C(t)

    d = asdict(c)
    self.assertEqual(d, {'f': T(a=6)})
    # Make sure that t has been copied, not used directly.
    self.assertIsNot(d['f'], t)
    self.assertEqual(d['f'].my_a(), 6)
1549
def test_helper_astuple(self):
    # Basic tests for astuple(), it should return a new tuple.
    @dataclass
    class C:
        x: int
        y: int = 0
    c = C(1)

    self.assertEqual(astuple(c), (1, 0))
    self.assertEqual(astuple(c), astuple(c))
    self.assertIsNot(astuple(c), astuple(c))
    # The result reflects the instance's current values.
    c.y = 42
    self.assertEqual(astuple(c), (1, 42))
    self.assertIs(type(astuple(c)), tuple)
1564
1565 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001566 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001567 @dataclass
1568 class C:
1569 x: int
1570 y: int
1571 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1572 astuple(C)
1573 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1574 astuple(int)
1575
def test_helper_astuple_copy_values(self):
    # astuple() must copy mutable field values rather than share them.
    @dataclass
    class C:
        x: int
        y: List[int] = field(default_factory=list)
    initial = []
    c = C(1, initial)
    t = astuple(c)
    self.assertEqual(t[1], initial)
    self.assertIsNot(t[1], initial)
    c = C(1)
    t = astuple(c)
    # Mutating the returned list must not affect the instance.
    t[1].append(1)
    self.assertEqual(c.y, [])
1590
def test_helper_astuple_nested(self):
    # A dataclass-valued field is itself converted to a tuple.
    @dataclass
    class UserId:
        token: int
        group: int
    @dataclass
    class User:
        name: str
        id: UserId
    u = User('Joe', UserId(123, 1))
    t = astuple(u)
    self.assertEqual(t, ('Joe', (123, 1)))
    self.assertIsNot(astuple(u), astuple(u))
    # A change to the nested instance shows up in a fresh astuple() call.
    u.id.group = 2
    self.assertEqual(astuple(u), ('Joe', (123, 2)))
1606
def test_helper_astuple_builtin_containers(self):
    # Dataclasses held inside list, tuple and dict values are converted
    # too, and each container keeps its own kind (list vs tuple vs dict).
    @dataclass
    class User:
        name: str
        id: int
    @dataclass
    class GroupList:
        id: int
        users: List[User]
    @dataclass
    class GroupTuple:
        id: int
        users: Tuple[User, ...]
    @dataclass
    class GroupDict:
        id: int
        users: Dict[str, User]
    a = User('Alice', 1)
    b = User('Bob', 2)
    gl = GroupList(0, [a, b])
    gt = GroupTuple(0, (a, b))
    gd = GroupDict(0, {'first': a, 'second': b})
    self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
    self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
    self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1632
def test_helper_astuple_builtin_object_containers(self):
    # Containers stored in a field annotated as plain `object` (a list
    # and a dict here) come through astuple() unchanged in value.
    @dataclass
    class Child:
        d: object

    @dataclass
    class Parent:
        child: Child

    self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
    self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1644
def test_helper_astuple_factory(self):
    # tuple_factory receives the field values and its result is what
    # astuple() returns; here it builds a namedtuple.
    @dataclass
    class C:
        x: int
        y: int
    NT = namedtuple('NT', 'x y')
    def nt(lst):
        return NT(*lst)
    c = C(1, 2)
    t = astuple(c, tuple_factory=nt)
    self.assertEqual(t, NT(1, 2))
    self.assertIsNot(t, astuple(c, tuple_factory=nt))
    c.x = 42
    t = astuple(c, tuple_factory=nt)
    self.assertEqual(t, NT(42, 2))
    self.assertIs(type(t), NT)
1661
def test_helper_astuple_namedtuple(self):
    # Dataclasses nested inside a namedtuple field are converted as well.
    T = namedtuple('T', 'a b c')
    @dataclass
    class C:
        x: str
        y: T
    c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))

    t = astuple(c)
    self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))

    # Now, using a tuple_factory.  list is convenient here.
    # Only the dataclass levels become lists; the namedtuples stay
    # tuples (a list would not compare equal to T(...)).
    t = astuple(c, tuple_factory=list)
    self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1676
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001677 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001678 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001679 }
1680
1681 # Create the class.
1682 cls = type('C', (), cls_dict)
1683
1684 # Make it a dataclass.
1685 cls1 = dataclass(cls)
1686
1687 self.assertEqual(cls1, cls)
1688 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1689
1690 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001691 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001692 'y': field(default=5),
1693 }
1694
1695 # Create the class.
1696 cls = type('C', (), cls_dict)
1697
1698 # Make it a dataclass.
1699 cls1 = dataclass(cls)
1700
1701 self.assertEqual(cls1, cls)
1702 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1703
def test_init_in_order(self):
    # The generated __init__ assigns fields in declaration order, and does
    # not assign a field that has init=False with a plain default (e).
    @dataclass
    class C:
        a: int
        b: int = field()
        c: list = field(default_factory=list, init=False)
        d: list = field(default_factory=list)
        e: int = field(default=4, init=False)
        f: int = 4

    calls = []
    def setattr(self, name, value):
        calls.append((name, value))

    # Record each assignment made by __init__ instead of performing it.
    C.__setattr__ = setattr
    c = C(0, 1)
    self.assertEqual(('a', 0), calls[0])
    self.assertEqual(('b', 1), calls[1])
    self.assertEqual(('c', []), calls[2])
    self.assertEqual(('d', []), calls[3])
    self.assertNotIn(('e', 4), calls)
    self.assertEqual(('f', 4), calls[4])
1726
def test_items_in_dicts(self):
    # Check which names land in the class __dict__ (fields with a plain
    # default) and which land in the instance __dict__ (fields set by
    # __init__, including factory-built init=False fields).
    @dataclass
    class C:
        a: int
        b: list = field(default_factory=list, init=False)
        c: list = field(default_factory=list)
        d: int = field(default=4, init=False)
        e: int = 0

    c = C(0)
    # Class dict
    self.assertNotIn('a', C.__dict__)
    self.assertNotIn('b', C.__dict__)
    self.assertNotIn('c', C.__dict__)
    self.assertIn('d', C.__dict__)
    self.assertEqual(C.d, 4)
    self.assertIn('e', C.__dict__)
    self.assertEqual(C.e, 0)
    # Instance dict
    self.assertIn('a', c.__dict__)
    self.assertEqual(c.a, 0)
    self.assertIn('b', c.__dict__)
    self.assertEqual(c.b, [])
    self.assertIn('c', c.__dict__)
    self.assertEqual(c.c, [])
    # d (init=False with a plain default) is only found on the class.
    self.assertNotIn('d', c.__dict__)
    self.assertIn('e', c.__dict__)
    self.assertEqual(c.e, 0)
1755
def test_alternate_classmethod_constructor(self):
    # __post_init__ only receives InitVar pseudo-fields (see
    # test_init_var), so a classmethod alternate constructor is another
    # way to build instances from other inputs.  This is mostly an
    # example to show how to use this technique.
    @dataclass
    class C:
        x: int
        @classmethod
        def from_file(cls, filename):
            # In a real example, create a new instance
            #  and populate 'x' from contents of a file.
            value_in_file = 20
            return cls(value_in_file)

    self.assertEqual(C.from_file('filename').x, 20)
1771
def test_field_metadata_default(self):
    # Make sure the default metadata is read-only and of
    # zero length.
    @dataclass
    class C:
        i: int

    self.assertFalse(fields(C)[0].metadata)
    self.assertEqual(len(fields(C)[0].metadata), 0)
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3
1784
def test_field_metadata_mapping(self):
    # Make sure only a mapping can be passed as metadata.
    with self.assertRaises(TypeError):
        @dataclass
        class C:
            i: int = field(metadata=0)

    # Make sure an empty dict works.
    d = {}
    @dataclass
    class C:
        i: int = field(metadata=d)
    self.assertFalse(fields(C)[0].metadata)
    self.assertEqual(len(fields(C)[0].metadata), 0)
    # Update should work (see bpo-35960).
    # The metadata is a read-only view of d: changes to d show through,
    # but assignment through the metadata is rejected.
    d['foo'] = 1
    self.assertEqual(len(fields(C)[0].metadata), 1)
    self.assertEqual(fields(C)[0].metadata['foo'], 1)
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3

    # Make sure a non-empty dict works.
    d = {'test': 10, 'bar': '42', 3: 'three'}
    @dataclass
    class C:
        i: int = field(metadata=d)
    self.assertEqual(len(fields(C)[0].metadata), 3)
    self.assertEqual(fields(C)[0].metadata['test'], 10)
    self.assertEqual(fields(C)[0].metadata['bar'], '42')
    self.assertEqual(fields(C)[0].metadata[3], 'three')
    # Update should work.
    d['foo'] = 1
    self.assertEqual(len(fields(C)[0].metadata), 4)
    self.assertEqual(fields(C)[0].metadata['foo'], 1)
    with self.assertRaises(KeyError):
        # Non-existent key.
        fields(C)[0].metadata['baz']
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3
1827
def test_field_metadata_custom_mapping(self):
    # Try a custom mapping.
    # Lookups must be forwarded to it, so its own quirks (the 'xyzzy'
    # special case and AttributeError for missing keys) remain visible.
    class SimpleNameSpace:
        def __init__(self, **kw):
            self.__dict__.update(kw)

        def __getitem__(self, item):
            if item == 'xyzzy':
                return 'plugh'
            return getattr(self, item)

        def __len__(self):
            return self.__dict__.__len__()

    @dataclass
    class C:
        i: int = field(metadata=SimpleNameSpace(a=10))

    self.assertEqual(len(fields(C)[0].metadata), 1)
    self.assertEqual(fields(C)[0].metadata['a'], 10)
    with self.assertRaises(AttributeError):
        fields(C)[0].metadata['b']
    # Make sure we're still talking to our custom mapping.
    self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1852
def test_generic_dataclasses(self):
    # A dataclass can also be a typing.Generic.
    T = TypeVar('T')

    @dataclass
    class LabeledBox(Generic[T]):
        content: T
        label: str = '<unknown>'

    box = LabeledBox(42)
    self.assertEqual(box.content, 42)
    self.assertEqual(box.label, '<unknown>')

    # Subscripting the resulting class should work, etc.
    # Alias is intentionally unused: evaluating it without error is the test.
    Alias = List[LabeledBox[int]]
1867
def test_generic_extending(self):
    # Subclasses of a parametrized generic dataclass work both when they
    # are dataclasses themselves and when they are plain classes.
    S = TypeVar('S')
    T = TypeVar('T')

    @dataclass
    class Base(Generic[T, S]):
        x: T
        y: S

    # Dataclass subclass: inherits x and y, adds new_field after them.
    @dataclass
    class DataDerived(Base[int, T]):
        new_field: str
    Alias = DataDerived[str]
    c = Alias(0, 'test1', 'test2')
    self.assertEqual(astuple(c), (0, 'test1', 'test2'))

    # Plain subclass: reuses the base's generated __init__.
    class NonDataDerived(Base[int, T]):
        def new_method(self):
            return self.y
    Alias = NonDataDerived[float]
    c = Alias(10, 1.0)
    self.assertEqual(c.new_method(), 1.0)
1890
def test_generic_dynamic(self):
    # make_dataclass() can build a generic dataclass with generic bases,
    # a defaulted field, and extra namespace entries.
    T = TypeVar('T')

    @dataclass
    class Parent(Generic[T]):
        x: T
    Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
                           bases=(Parent[int], Generic[T]), namespace={'other': 42})
    self.assertIs(Child[int](1, 2).z, None)
    self.assertEqual(Child[int](1, 2, 3).z, 3)
    self.assertEqual(Child[int](1, 2, 3).other, 42)
    # Check that type aliases work correctly.
    Alias = Child[T]
    self.assertEqual(Alias[int](1, 2).x, 1)
    # Check MRO resolution.
    self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1907
def test_dataclassses_pickleable(self):
    # The classes are made module globals because pickle stores classes
    # by reference to their importable name; function-local classes
    # could not be found again when unpickling.
    global P, Q, R
    @dataclass
    class P:
        x: int
        y: int = 0
    @dataclass
    class Q:
        x: int
        y: int = field(default=0, init=False)
    @dataclass
    class R:
        x: int
        y: List[int] = field(default_factory=list)
    # q.y is set after construction, so a non-init field that differs
    # from its default is also round-tripped.
    q = Q(1)
    q.y = 2
    samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
    for sample in samples:
        for proto in range(pickle.HIGHEST_PROTOCOL + 1):
            with self.subTest(sample=sample, proto=proto):
                new_sample = pickle.loads(pickle.dumps(sample, proto))
                self.assertEqual(sample.x, new_sample.x)
                self.assertEqual(sample.y, new_sample.y)
                self.assertIsNot(sample, new_sample)
                new_sample.x = 42
                another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
                self.assertEqual(new_sample.x, another_new_sample.x)
                self.assertEqual(sample.y, another_new_sample.y)
1936
Eric V. Smithea8fc522018-01-27 19:07:40 -05001937
class TestFieldNoAnnotation(unittest.TestCase):
    # A field() assigned to a name that is not annotated in the class
    # being decorated must be rejected, even when a base class annotates
    # the same name.

    def test_field_without_annotation(self):
        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class Base:
            f: int

        expected = "'f' is a field but has no type annotation"
        # Still an error: the annotation inherited from the base class
        # must not be picked up.
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class Derived(Base):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same scenario, but the annotated base is an ordinary class.
        class Base:
            f: int

        expected = "'f' is a field but has no type annotation"
        # Still an error: the annotation inherited from the base class
        # must not be picked up.
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class Derived(Base):
                f = field()
1971
1972
class TestDocString(unittest.TestCase):
    # Tests for the docstring dataclass() generates from the __init__
    # signature when the class has no docstring of its own.

    def assertDocStrEqual(self, a, b):
        # Because 3.6 and 3.7 differ in how inspect.signature works
        # (see bpo #32108), for the time being just compare them with
        # all space characters removed.
        self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))

    def test_existing_docstring_not_overridden(self):
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory is rendered as the placeholder <factory>.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2060
2061
class TestInit(unittest.TestCase):
    # Tests for when dataclass() generates __init__ and when it leaves an
    # existing (own or inherited) __init__ alone.

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        #  The issue is that we can't override __init__ in our class,
        #  but it should be okay to add __init__ to us if our base has
        #  an __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call the base __init__.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        #  gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must really be applied ("@"): a bare
        # dataclass(init=False) expression statement would leave C an
        # ordinary class and make these checks vacuous.

        # No __init__ is generated, so object.__init__ is used and i
        # comes from the class-level default.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # A user-written __init__ is kept as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        #  init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2126
2127
class TestRepr(unittest.TestCase):
    """Tests for the generated __repr__ and for keeping a user-defined one."""

    def test_repr(self):
        # The repr starts with the class __qualname__ and lists every
        # field, including fields inherited from a dataclass base.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        self.assertEqual(repr(C(4)), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring the inherited x keeps it first but gives it a default.
        @dataclass
        class D(C):
            x: int = 20

        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses show their full dotted qualname.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int

            @dataclass
            class E:
                pass

        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no user __repr__: object's default repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_text = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_text)

        # repr=False with a user __repr__: the user's method is kept.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A user-defined __repr__ wins whatever repr= is set to: default,
        # explicitly True, or False.
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            C = decorate(C)
            self.assertEqual(repr(C(0)), 'x')
2197
2198
class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and for keeping a user-defined one."""

    def test_no_eq(self):
        # eq=False and no user __eq__: equality falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ wins whatever eq= is set to.  Each
        # decorator is paired with the single value its class compares
        # equal to.
        for decorate, target in ((dataclass, 3),
                                 (dataclass(eq=True), 4),
                                 (dataclass(eq=False), 5)):
            class C:
                x: int
                def __eq__(self, other):
                    # 'target' is read during this iteration's assertions,
                    # so the closure always sees the intended value.
                    return other == target
            C = decorate(C)
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2244
2245
class TestOrdering(unittest.TestCase):
    """Tests for order= and how it interacts with user comparison methods."""

    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to derive the remaining
        # comparisons from a user-written __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the checks below can only
                # pass if this method is the one being used.
                return self.x >= other

        zero = C(0)
        for check, other in ((self.assertLess, -1),
                             (self.assertLessEqual, -1),
                             (self.assertGreater, 1),
                             (self.assertGreaterEqual, 1)):
            check(zero, other)

    def test_no_order(self):
        # order=False adds none of the four ordering methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # With a user __lt__ and order=False, the other three methods
        # still aren't added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True generates all four comparisons, so a hand-written
        # one conflicts.  The error points the user at total_ordering.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self): pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self): pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self): pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self): pass
2321
class TestHash(unittest.TestCase):
    """Tests for when @dataclass adds __hash__, sets it to None, or leaves it."""

    def test_unsafe_hash(self):
        # With unsafe_hash=True the instance hashes like the tuple of its
        # field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        """Check the __hash__ decision table for unsafe_hash/eq/frozen.

        Each table row gives the expected outcome both when the class
        body does not define __hash__ and when it does.  Outcome codes:
          'fn'        -- a non-None __hash__ is in the class __dict__
          ''          -- nothing is added (only checked when the class
                         body does not define __hash__ itself)
          'none'      -- __hash__ is in the class __dict__ and is None
          'exception' -- creating the class raises TypeError
        """
        def non_bool(value):
            # Map a bool to a non-bool value with the same truthiness
            # ((3,) for true, 0 for false); None passes through unchanged.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a class with the given flags (optionally defining its
            # own __hash__) and check it against the expected outcome.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                # For 'exception' the class is created inside
                # assertRaisesRegex below instead of here.
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.  When the class
                    # defined its own __hash__, it is of course present, so
                    # nothing is checked in that case.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        # Columns: unsafe_hash, eq, frozen, outcome without a user
        # __hash__, outcome with one.  Cases are numbered from 1.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '', ''),
                (False, False, True, '', ''),
                (False, True, False, 'none', ''),
                (False, True, True, 'fn', ''),
                (True, False, False, 'fn', 'exception'),
                (True, False, True, 'fn', 'exception'),
                (True, True, False, 'fn', 'exception'),
                (True, True, True, 'fn', 'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo=32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                # Only instances holding 3 compare equal, which the
                # generated field-wise __eq__ would not do.
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less class hashes like the empty tuple, whether hashing
        # comes from frozen=True or unsafe_hash=True.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A one-field class hashes like the 1-tuple of that field.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        # Expected outcomes: 'unhashable' -- hash() raises TypeError;
        # 'base' -- Base.__hash__ is used; 'object' -- object.__hash__
        # is used; 'tuple' -- the generated field-tuple hash is used.
        for frozen, eq, base, expected in [
                (None, None, object, 'unhashable'),
                (None, None, Base, 'unhashable'),
                (None, False, object, 'object'),
                (None, False, Base, 'base'),
                (None, True, object, 'unhashable'),
                (None, True, Base, 'unhashable'),
                (False, None, object, 'unhashable'),
                (False, None, Base, 'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base, 'base'),
                (False, True, object, 'unhashable'),
                (False, True, Base, 'unhashable'),
                (True, None, object, 'tuple'),
                (True, None, Base, 'tuple'),
                (True, False, object, 'object'),
                (True, False, Base, 'base'),
                (True, True, object, 'tuple'),
                (True, True, Base, 'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's default hash depends on the instance's
                    # identity, so there is no fixed value to compare
                    # against.  Instead, check that the function used is
                    # object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2540
Eric V. Smithea8fc522018-01-27 19:07:40 -05002541
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True: blocked assignment, inheritance rules, hashing."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        # Assignment is rejected and leaves the value unchanged.
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        # Both the inherited field and the new field are frozen.
        for attr, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, attr, new_value)
        self.assertEqual((obj.i, obj.j), (0, 10))

    # Each inheritance test runs twice: once with an intermediate plain
    # (non-dataclass) class between base and derived, once without.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if not intermediate_class:
                    I = C
                else:
                    class I(C): pass

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if not intermediate_class:
                    I = C
                else:
                    class I(C): pass

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from ordinary classes.
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if not intermediate_class:
                    I = C
                else:
                    class I(C): pass

                @dataclass(frozen=True)
                class D(I):
                    i: int

                frozen_obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    frozen_obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953: an ordinary subclass of a frozen dataclass may set
        # new attributes, but the inherited fields stay frozen.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        s.cached = True

        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, attr, 5)
        self.assertEqual((s.x, s.y, s.cached), (3, 10, True))

    def test_overwriting_frozen(self):
        # frozen=True installs __setattr__ and __delattr__, so the class
        # body may not define either one itself.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False, a user __setattr__ is honored, including the
        # assignment made by the generated __init__.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing succeeds for an immutable field value and fails for a
        # mutable one.
        hash(C(3))
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2690
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002691
class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__ by hand."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # A slot's member descriptor on the class must not be mistaken
        # for a default value, so x remains a required argument.
        with self.assertRaisesRegex(TypeError,
                r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        inst = C(10)
        self.assertEqual(inst.x, 10)
        inst.x = 5
        self.assertEqual(inst.x, 5)

        # Only the declared slot is assignable.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            inst.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        inst = Derived(1, 2)
        self.assertEqual((inst.x, inst.y), (1, 2))

        # Derived declares no __slots__ of its own, so its instances
        # accept arbitrary new attributes.
        inst.z = 10
2733
class TestDescriptors(unittest.TestCase):
    """Tests for __set_name__ on field defaults (bpo-33141, bpo-33175)."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that records its bound name with an 'x' suffix and
        # yields 1 when read through an instance.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Plain descriptor behavior: no dataclass code is involved in
        # initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # A default with init=False is the case where this really matters.
        # Without init=False the descriptor would be overwritten anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ must also be called on objects that
        # are not descriptors.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175: a __set_name__ stored on the instance (rather
        # than its type) must not be called.
        class D:
            pass

        marker = D()
        marker.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=marker, init=False)

        self.assertEqual(marker.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175: a __set_name__ defined on the type is called
        # exactly once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002804
Eric V. Smith7389fd92018-03-19 21:07:51 -04002805
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002806class TestStringAnnotations(unittest.TestCase):
2807 def test_classvar(self):
2808 # Some expressions recognized as ClassVar really aren't. But
2809 # if you're using string annotations, it's not an exact
2810 # science.
2811 # These tests assume that both "import typing" and "from
2812 # typing import *" have been run in this file.
2813 for typestr in ('ClassVar[int]',
2814 'ClassVar [int]'
2815 ' ClassVar [int]',
2816 'ClassVar',
2817 ' ClassVar ',
2818 'typing.ClassVar[int]',
2819 'typing.ClassVar[str]',
2820 ' typing.ClassVar[str]',
2821 'typing .ClassVar[str]',
2822 'typing. ClassVar[str]',
2823 'typing.ClassVar [str]',
2824 'typing.ClassVar [ str]',
2825
2826 # Not syntactically valid, but these will
2827 # be treated as ClassVars.
2828 'typing.ClassVar.[int]',
2829 'typing.ClassVar+',
2830 ):
2831 with self.subTest(typestr=typestr):
2832 @dataclass
2833 class C:
2834 x: typestr
2835
2836 # x is a ClassVar, so C() takes no args.
2837 C()
2838
2839 # And it won't appear in the class's dict because it doesn't
2840 # have a default.
2841 self.assertNotIn('x', C.__dict__)
2842
2843 def test_isnt_classvar(self):
2844 for typestr in ('CV',
2845 't.ClassVar',
2846 't.ClassVar[int]',
2847 'typing..ClassVar[int]',
2848 'Classvar',
2849 'Classvar[int]',
2850 'typing.ClassVarx[int]',
2851 'typong.ClassVar[int]',
2852 'dataclasses.ClassVar[int]',
2853 'typingxClassVar[str]',
2854 ):
2855 with self.subTest(typestr=typestr):
2856 @dataclass
2857 class C:
2858 x: typestr
2859
2860 # x is not a ClassVar, so C() takes one arg.
2861 self.assertEqual(C(10).x, 10)
2862
2863 def test_initvar(self):
2864 # These tests assume that both "import dataclasses" and "from
2865 # dataclasses import *" have been run in this file.
2866 for typestr in ('InitVar[int]',
2867 'InitVar [int]'
2868 ' InitVar [int]',
2869 'InitVar',
2870 ' InitVar ',
2871 'dataclasses.InitVar[int]',
2872 'dataclasses.InitVar[str]',
2873 ' dataclasses.InitVar[str]',
2874 'dataclasses .InitVar[str]',
2875 'dataclasses. InitVar[str]',
2876 'dataclasses.InitVar [str]',
2877 'dataclasses.InitVar [ str]',
2878
2879 # Not syntactically valid, but these will
2880 # be treated as InitVars.
2881 'dataclasses.InitVar.[int]',
2882 'dataclasses.InitVar+',
2883 ):
2884 with self.subTest(typestr=typestr):
2885 @dataclass
2886 class C:
2887 x: typestr
2888
2889 # x is an InitVar, so doesn't create a member.
2890 with self.assertRaisesRegex(AttributeError,
2891 "object has no attribute 'x'"):
2892 C(1).x
2893
2894 def test_isnt_initvar(self):
2895 for typestr in ('IV',
2896 'dc.InitVar',
2897 'xdataclasses.xInitVar',
2898 'typing.xInitVar[int]',
2899 ):
2900 with self.subTest(typestr=typestr):
2901 @dataclass
2902 class C:
2903 x: typestr
2904
2905 # x is not an InitVar, so there will be a member x.
2906 self.assertEqual(C(10).x, 10)
2907
2908 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002909 from test import dataclass_module_1
2910 from test import dataclass_module_1_str
2911 from test import dataclass_module_2
2912 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002913
2914 for m in (dataclass_module_1, dataclass_module_1_str,
2915 dataclass_module_2, dataclass_module_2_str,
2916 ):
2917 with self.subTest(m=m):
2918 # There's a difference in how the ClassVars are
2919 # interpreted when using string annotations or
2920 # not. See the imported modules for details.
2921 if m.USING_STRINGS:
2922 c = m.CV(10)
2923 else:
2924 c = m.CV()
2925 self.assertEqual(c.cv0, 20)
2926
2927
2928 # There's a difference in how the InitVars are
2929 # interpreted when using string annotations or
2930 # not. See the imported modules for details.
2931 c = m.IV(0, 1, 2, 3, 4)
2932
2933 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2934 with self.subTest(field_name=field_name):
2935 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2936 # Since field_name is an InitVar, it's
2937 # not an instance field.
2938 getattr(c, field_name)
2939
2940 if m.USING_STRINGS:
2941 # iv4 is interpreted as a normal field.
2942 self.assertIn('not_iv4', c.__dict__)
2943 self.assertEqual(c.not_iv4, 4)
2944 else:
2945 # iv4 is interpreted as an InitVar, so it
2946 # won't exist on the instance.
2947 self.assertNotIn('not_iv4', c.__dict__)
2948
Yury Selivanovd219cc42019-12-09 09:54:20 -05002949 def test_text_annotations(self):
2950 from test import dataclass_textanno
2951
2952 self.assertEqual(
2953 get_type_hints(dataclass_textanno.Bar),
2954 {'foo': dataclass_textanno.Foo})
2955 self.assertEqual(
2956 get_type_hints(dataclass_textanno.Bar.__init__),
2957 {'foo': dataclass_textanno.Foo,
2958 'return': type(None)})
2959
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002960
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), the functional dataclass constructor."""

    def test_simple(self):
        extras = {'add_one': lambda self: self.x + 1}
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace=extras)
        obj = C(10)
        self.assertEqual((obj.x, obj.y), (10, 5))
        self.assertEqual(obj.add_one(), 11)

    def test_no_mutate_namespace(self):
        # The caller's namespace dict must come back untouched.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C', [('x', int)], bases=(Base1, Base2))
        obj = C(2)
        for cls in (C, Base1, Base2):
            self.assertIsInstance(obj, cls)

    def test_base_dataclass(self):
        # Fields inherited from a dataclass base come first in __init__,
        # so a single argument is not enough.
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C', [('y', int)], bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        obj = C(1, 2)
        for cls in (C, Base1, Base2):
            self.assertIsInstance(obj, cls)
        self.assertEqual((obj.x, obj.y), (1, 2))

    def test_init_var(self):
        # An InitVar reaches __post_init__ but does not become a field.
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int), ('y', InitVar[int])],
                           namespace={'__post_init__': post_init})
        obj = C(2, 3)
        self.assertEqual(vars(obj), {'x': 6})
        self.assertEqual(len(fields(obj)), 1)

    def test_class_var(self):
        # ClassVar entries become class attributes, not fields.
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))])
        obj = C(1)
        self.assertEqual(vars(obj), {'x': 1})
        self.assertEqual(len(fields(obj)), 1)
        self.assertEqual((C.y, C.z), (10, 20))

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))],
                           init=False)
        # init=False suppresses __init__; __repr__ is still generated.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C', [], xxinit=False)

    def test_no_types(self):
        # Bare names get the string annotation 'typing.Any'.
        C = make_dataclass('Point', ['x', 'y', 'z'])
        self.assertEqual(vars(C(1, 2, 3)), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        # Typed and untyped entries can be mixed.
        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        self.assertEqual(vars(C(1, 2, 3)), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        # Tuple specs must have exactly 2 or 3 items.
        for bad_field in ((), (1, 2, 3, 4)):
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # Specs with no len() at all.
        for bad_field in (float, lambda x: x):
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        for name in ('a', 'ab'):
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        # Keywords are rejected wherever they appear in the list.
        for name in ('for', 'async', 'await', 'as'):
            with self.subTest(field=name):
                for field_list in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                        make_dataclass('C', field_list)

    def test_non_identifier_field_names(self):
        # Non-identifiers are rejected wherever they appear in the list.
        for name in ('()', 'x,y', '*', '2@3', '', 'little johnny tables'):
            with self.subTest(field=name):
                for field_list in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                        make_dataclass('C', field_list)

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ('()', 'x,y', '*', '2@3', ''):
            with self.subTest(classname=classname):
                made = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(made.__name__, classname)
3124
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The test_recursive_repr* methods at the end exercise the generated
    __repr__ rather than replace(); their expected strings embed this
    class's qualified name ("TestReplace...."), so moving them elsewhere
    requires updating those strings too.
    """

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields cannot be passed to replace(), whether mixed
        # with init=True fields or on their own.  Every call gets its own
        # context manager: a second call inside one block would never run,
        # because the first call already raises.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, t=50)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError,
                                    "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        # Unknown names are forwarded to __init__, which rejects them.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # Neither a dataclass *class* nor a non-dataclass object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # The InitVar is never stored on the instance, so replace() has
        # nothing to carry over and requires it explicitly.  The parentheses
        # are escaped: unescaped, '()' is an empty regex group and the
        # pattern would not check for them at all.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        # __post_init__ runs again on the new instance: 3 * 5 == 15.
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
Eric V. Smith4e812962018-05-16 11:31:29 -04003335
if __name__ == '__main__':
    # Executed as a script rather than imported: let unittest discover and
    # run the TestCase classes defined in this module.
    unittest.main()