blob: f35f466125d1c1c956a5ba22e819d8830468bead [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050011import unittest
12from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010013from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Yury Selivanovd219cc42019-12-09 09:54:20 -050014from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050016from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050017
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040018import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
19import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
20
# Any custom exception type works here; the tests raise this one on purpose
# so that assertRaises(CustomError) cannot be satisfied by an unrelated error.
class CustomError(Exception):
    """Exception raised deliberately by tests so it can be caught precisely."""
24class TestCase(unittest.TestCase):
25 def test_no_fields(self):
26 @dataclass
27 class C:
28 pass
29
30 o = C()
31 self.assertEqual(len(fields(C)), 0)
32
Eric V. Smith56970b82018-03-22 16:28:48 -040033 def test_no_fields_but_member_variable(self):
34 @dataclass
35 class C:
36 i = 0
37
38 o = C()
39 self.assertEqual(len(fields(C)), 0)
40
Eric V. Smithf0db54a2017-12-04 16:58:55 -050041 def test_one_field_no_default(self):
42 @dataclass
43 class C:
44 x: int
45
46 o = C(42)
47 self.assertEqual(o.x, 42)
48
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053049 def test_field_default_default_factory_error(self):
50 msg = "cannot specify both default and default_factory"
51 with self.assertRaisesRegex(ValueError, msg):
52 @dataclass
53 class C:
54 x: int = field(default=1, default_factory=int)
55
56 def test_field_repr(self):
57 int_field = field(default=1, init=True, repr=False)
58 int_field.name = "id"
59 repr_output = repr(int_field)
60 expected_output = "Field(name='id',type=None," \
61 f"default=1,default_factory={MISSING!r}," \
62 "init=True,repr=False,hash=None," \
63 "compare=True,metadata=mappingproxy({})," \
64 "_field_type=None)"
65
66 self.assertEqual(repr_output, expected_output)
67
Eric V. Smithf0db54a2017-12-04 16:58:55 -050068 def test_named_init_params(self):
69 @dataclass
70 class C:
71 x: int
72
73 o = C(x=32)
74 self.assertEqual(o.x, 32)
75
76 def test_two_fields_one_default(self):
77 @dataclass
78 class C:
79 x: int
80 y: int = 0
81
82 o = C(3)
83 self.assertEqual((o.x, o.y), (3, 0))
84
85 # Non-defaults following defaults.
86 with self.assertRaisesRegex(TypeError,
87 "non-default argument 'y' follows "
88 "default argument"):
89 @dataclass
90 class C:
91 x: int = 0
92 y: int
93
94 # A derived class adds a non-default field after a default one.
95 with self.assertRaisesRegex(TypeError,
96 "non-default argument 'y' follows "
97 "default argument"):
98 @dataclass
99 class B:
100 x: int = 0
101
102 @dataclass
103 class C(B):
104 y: int
105
106 # Override a base class field and add a default to
107 # a field which didn't use to have a default.
108 with self.assertRaisesRegex(TypeError,
109 "non-default argument 'y' follows "
110 "default argument"):
111 @dataclass
112 class B:
113 x: int
114 y: int
115
116 @dataclass
117 class C(B):
118 x: int = 0
119
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500120 def test_overwrite_hash(self):
121 # Test that declaring this class isn't an error. It should
122 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500123 @dataclass(frozen=True)
124 class C:
125 x: int
126 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500127 return 301
128 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500129
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500130 # Test that declaring this class isn't an error. It should
131 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500132 @dataclass(frozen=True)
133 class C:
134 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500135 def __eq__(self, other):
136 return False
137 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500138
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500139 # But this one should generate an exception, because with
140 # unsafe_hash=True, it's an error to have a __hash__ defined.
141 with self.assertRaisesRegex(TypeError,
142 'Cannot overwrite attribute __hash__'):
143 @dataclass(unsafe_hash=True)
144 class C:
145 def __hash__(self):
146 pass
147
148 # Creating this class should not generate an exception,
149 # because even though __hash__ exists before @dataclass is
150 # called, (due to __eq__ being defined), since it's None
151 # that's okay.
152 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500153 class C:
154 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500155 def __eq__(self):
156 pass
157 # The generated hash function works as we'd expect.
158 self.assertEqual(hash(C(10)), hash((10,)))
159
160 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400161 # __hash__ exists and is not None, which it would be if it
162 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500163 with self.assertRaisesRegex(TypeError,
164 'Cannot overwrite attribute __hash__'):
165 @dataclass(unsafe_hash=True)
166 class C:
167 x: int
168 def __eq__(self):
169 pass
170 def __hash__(self):
171 pass
172
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500173 def test_overwrite_fields_in_derived_class(self):
174 # Note that x from C1 replaces x in Base, but the order remains
175 # the same as defined in Base.
176 @dataclass
177 class Base:
178 x: Any = 15.0
179 y: int = 0
180
181 @dataclass
182 class C1(Base):
183 z: int = 10
184 x: int = 15
185
186 o = Base()
187 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
188
189 o = C1()
190 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
191
192 o = C1(x=5)
193 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
194
195 def test_field_named_self(self):
196 @dataclass
197 class C:
198 self: str
199 c=C('foo')
200 self.assertEqual(c.self, 'foo')
201
202 # Make sure the first parameter is not named 'self'.
203 sig = inspect.signature(C.__init__)
204 first = next(iter(sig.parameters))
205 self.assertNotEqual('self', first)
206
207 # But we do use 'self' if no field named self.
208 @dataclass
209 class C:
210 selfx: str
211
212 # Make sure the first parameter is named 'self'.
213 sig = inspect.signature(C.__init__)
214 first = next(iter(sig.parameters))
215 self.assertEqual('self', first)
216
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300217 def test_field_named_object(self):
218 @dataclass
219 class C:
220 object: str
221 c = C('foo')
222 self.assertEqual(c.object, 'foo')
223
224 def test_field_named_object_frozen(self):
225 @dataclass(frozen=True)
226 class C:
227 object: str
228 c = C('foo')
229 self.assertEqual(c.object, 'foo')
230
231 def test_field_named_like_builtin(self):
232 # Attribute names can shadow built-in names
233 # since code generation is used.
234 # Ensure that this is not happening.
235 exclusions = {'None', 'True', 'False'}
236 builtins_names = sorted(
237 b for b in builtins.__dict__.keys()
238 if not b.startswith('__') and b not in exclusions
239 )
240 attributes = [(name, str) for name in builtins_names]
241 C = make_dataclass('C', attributes)
242
243 c = C(*[name for name in builtins_names])
244
245 for name in builtins_names:
246 self.assertEqual(getattr(c, name), name)
247
248 def test_field_named_like_builtin_frozen(self):
249 # Attribute names can shadow built-in names
250 # since code generation is used.
251 # Ensure that this is not happening
252 # for frozen data classes.
253 exclusions = {'None', 'True', 'False'}
254 builtins_names = sorted(
255 b for b in builtins.__dict__.keys()
256 if not b.startswith('__') and b not in exclusions
257 )
258 attributes = [(name, str) for name in builtins_names]
259 C = make_dataclass('C', attributes, frozen=True)
260
261 c = C(*[name for name in builtins_names])
262
263 for name in builtins_names:
264 self.assertEqual(getattr(c, name), name)
265
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500266 def test_0_field_compare(self):
267 # Ensure that order=False is the default.
268 @dataclass
269 class C0:
270 pass
271
272 @dataclass(order=False)
273 class C1:
274 pass
275
276 for cls in [C0, C1]:
277 with self.subTest(cls=cls):
278 self.assertEqual(cls(), cls())
279 for idx, fn in enumerate([lambda a, b: a < b,
280 lambda a, b: a <= b,
281 lambda a, b: a > b,
282 lambda a, b: a >= b]):
283 with self.subTest(idx=idx):
284 with self.assertRaisesRegex(TypeError,
285 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
286 fn(cls(), cls())
287
288 @dataclass(order=True)
289 class C:
290 pass
291 self.assertLessEqual(C(), C())
292 self.assertGreaterEqual(C(), C())
293
294 def test_1_field_compare(self):
295 # Ensure that order=False is the default.
296 @dataclass
297 class C0:
298 x: int
299
300 @dataclass(order=False)
301 class C1:
302 x: int
303
304 for cls in [C0, C1]:
305 with self.subTest(cls=cls):
306 self.assertEqual(cls(1), cls(1))
307 self.assertNotEqual(cls(0), cls(1))
308 for idx, fn in enumerate([lambda a, b: a < b,
309 lambda a, b: a <= b,
310 lambda a, b: a > b,
311 lambda a, b: a >= b]):
312 with self.subTest(idx=idx):
313 with self.assertRaisesRegex(TypeError,
314 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
315 fn(cls(0), cls(0))
316
317 @dataclass(order=True)
318 class C:
319 x: int
320 self.assertLess(C(0), C(1))
321 self.assertLessEqual(C(0), C(1))
322 self.assertLessEqual(C(1), C(1))
323 self.assertGreater(C(1), C(0))
324 self.assertGreaterEqual(C(1), C(0))
325 self.assertGreaterEqual(C(1), C(1))
326
327 def test_simple_compare(self):
328 # Ensure that order=False is the default.
329 @dataclass
330 class C0:
331 x: int
332 y: int
333
334 @dataclass(order=False)
335 class C1:
336 x: int
337 y: int
338
339 for cls in [C0, C1]:
340 with self.subTest(cls=cls):
341 self.assertEqual(cls(0, 0), cls(0, 0))
342 self.assertEqual(cls(1, 2), cls(1, 2))
343 self.assertNotEqual(cls(1, 0), cls(0, 0))
344 self.assertNotEqual(cls(1, 0), cls(1, 1))
345 for idx, fn in enumerate([lambda a, b: a < b,
346 lambda a, b: a <= b,
347 lambda a, b: a > b,
348 lambda a, b: a >= b]):
349 with self.subTest(idx=idx):
350 with self.assertRaisesRegex(TypeError,
351 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
352 fn(cls(0, 0), cls(0, 0))
353
354 @dataclass(order=True)
355 class C:
356 x: int
357 y: int
358
359 for idx, fn in enumerate([lambda a, b: a == b,
360 lambda a, b: a <= b,
361 lambda a, b: a >= b]):
362 with self.subTest(idx=idx):
363 self.assertTrue(fn(C(0, 0), C(0, 0)))
364
365 for idx, fn in enumerate([lambda a, b: a < b,
366 lambda a, b: a <= b,
367 lambda a, b: a != b]):
368 with self.subTest(idx=idx):
369 self.assertTrue(fn(C(0, 0), C(0, 1)))
370 self.assertTrue(fn(C(0, 1), C(1, 0)))
371 self.assertTrue(fn(C(1, 0), C(1, 1)))
372
373 for idx, fn in enumerate([lambda a, b: a > b,
374 lambda a, b: a >= b,
375 lambda a, b: a != b]):
376 with self.subTest(idx=idx):
377 self.assertTrue(fn(C(0, 1), C(0, 0)))
378 self.assertTrue(fn(C(1, 0), C(0, 1)))
379 self.assertTrue(fn(C(1, 1), C(1, 0)))
380
381 def test_compare_subclasses(self):
382 # Comparisons fail for subclasses, even if no fields
383 # are added.
384 @dataclass
385 class B:
386 i: int
387
388 @dataclass
389 class C(B):
390 pass
391
392 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
393 (lambda a, b: a != b, True)]):
394 with self.subTest(idx=idx):
395 self.assertEqual(fn(B(0), C(0)), expected)
396
397 for idx, fn in enumerate([lambda a, b: a < b,
398 lambda a, b: a <= b,
399 lambda a, b: a > b,
400 lambda a, b: a >= b]):
401 with self.subTest(idx=idx):
402 with self.assertRaisesRegex(TypeError,
403 "not supported between instances of 'B' and 'C'"):
404 fn(B(0), C(0))
405
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500407 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500408 for (eq, order, result ) in [
409 (False, False, 'neither'),
410 (False, True, 'exception'),
411 (True, False, 'eq_only'),
412 (True, True, 'both'),
413 ]:
414 with self.subTest(eq=eq, order=order):
415 if result == 'exception':
416 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
417 @dataclass(eq=eq, order=order)
418 class C:
419 pass
420 else:
421 @dataclass(eq=eq, order=order)
422 class C:
423 pass
424
425 if result == 'neither':
426 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500427 self.assertNotIn('__lt__', C.__dict__)
428 self.assertNotIn('__le__', C.__dict__)
429 self.assertNotIn('__gt__', C.__dict__)
430 self.assertNotIn('__ge__', C.__dict__)
431 elif result == 'both':
432 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500433 self.assertIn('__lt__', C.__dict__)
434 self.assertIn('__le__', C.__dict__)
435 self.assertIn('__gt__', C.__dict__)
436 self.assertIn('__ge__', C.__dict__)
437 elif result == 'eq_only':
438 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500439 self.assertNotIn('__lt__', C.__dict__)
440 self.assertNotIn('__le__', C.__dict__)
441 self.assertNotIn('__gt__', C.__dict__)
442 self.assertNotIn('__ge__', C.__dict__)
443 else:
444 assert False, f'unknown result {result!r}'
445
446 def test_field_no_default(self):
447 @dataclass
448 class C:
449 x: int = field()
450
451 self.assertEqual(C(5).x, 5)
452
453 with self.assertRaisesRegex(TypeError,
454 r"__init__\(\) missing 1 required "
455 "positional argument: 'x'"):
456 C()
457
458 def test_field_default(self):
459 default = object()
460 @dataclass
461 class C:
462 x: object = field(default=default)
463
464 self.assertIs(C.x, default)
465 c = C(10)
466 self.assertEqual(c.x, 10)
467
468 # If we delete the instance attribute, we should then see the
469 # class attribute.
470 del c.x
471 self.assertIs(c.x, default)
472
473 self.assertIs(C().x, default)
474
475 def test_not_in_repr(self):
476 @dataclass
477 class C:
478 x: int = field(repr=False)
479 with self.assertRaises(TypeError):
480 C()
481 c = C(10)
482 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
483
484 @dataclass
485 class C:
486 x: int = field(repr=False)
487 y: int
488 c = C(10, 20)
489 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
490
491 def test_not_in_compare(self):
492 @dataclass
493 class C:
494 x: int = 0
495 y: int = field(compare=False, default=4)
496
497 self.assertEqual(C(), C(0, 20))
498 self.assertEqual(C(1, 10), C(1, 20))
499 self.assertNotEqual(C(3), C(4, 10))
500 self.assertNotEqual(C(3, 10), C(4, 10))
501
502 def test_hash_field_rules(self):
503 # Test all 6 cases of:
504 # hash=True/False/None
505 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500506 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500507 (True, False, 'field' ),
508 (True, True, 'field' ),
509 (False, False, 'absent'),
510 (False, True, 'absent'),
511 (None, False, 'absent'),
512 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500513 ]:
514 with self.subTest(hash=hash_, compare=compare):
515 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500516 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500517 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500518
519 if result == 'field':
520 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500521 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500522 elif result == 'absent':
523 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500524 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500525 else:
526 assert False, f'unknown result {result!r}'
527
528 def test_init_false_no_default(self):
529 # If init=False and no default value, then the field won't be
530 # present in the instance.
531 @dataclass
532 class C:
533 x: int = field(init=False)
534
535 self.assertNotIn('x', C().__dict__)
536
537 @dataclass
538 class C:
539 x: int
540 y: int = 0
541 z: int = field(init=False)
542 t: int = 10
543
544 self.assertNotIn('z', C(0).__dict__)
545 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
546
547 def test_class_marker(self):
548 @dataclass
549 class C:
550 x: int
551 y: str = field(init=False, default=None)
552 z: str = field(repr=False)
553
554 the_fields = fields(C)
555 # the_fields is a tuple of 3 items, each value
556 # is in __annotations__.
557 self.assertIsInstance(the_fields, tuple)
558 for f in the_fields:
559 self.assertIs(type(f), Field)
560 self.assertIn(f.name, C.__annotations__)
561
562 self.assertEqual(len(the_fields), 3)
563
564 self.assertEqual(the_fields[0].name, 'x')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100565 self.assertEqual(the_fields[0].type, int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500566 self.assertFalse(hasattr(C, 'x'))
567 self.assertTrue (the_fields[0].init)
568 self.assertTrue (the_fields[0].repr)
569 self.assertEqual(the_fields[1].name, 'y')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100570 self.assertEqual(the_fields[1].type, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500571 self.assertIsNone(getattr(C, 'y'))
572 self.assertFalse(the_fields[1].init)
573 self.assertTrue (the_fields[1].repr)
574 self.assertEqual(the_fields[2].name, 'z')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100575 self.assertEqual(the_fields[2].type, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500576 self.assertFalse(hasattr(C, 'z'))
577 self.assertTrue (the_fields[2].init)
578 self.assertFalse(the_fields[2].repr)
579
580 def test_field_order(self):
581 @dataclass
582 class B:
583 a: str = 'B:a'
584 b: str = 'B:b'
585 c: str = 'B:c'
586
587 @dataclass
588 class C(B):
589 b: str = 'C:b'
590
591 self.assertEqual([(f.name, f.default) for f in fields(C)],
592 [('a', 'B:a'),
593 ('b', 'C:b'),
594 ('c', 'B:c')])
595
596 @dataclass
597 class D(B):
598 c: str = 'D:c'
599
600 self.assertEqual([(f.name, f.default) for f in fields(D)],
601 [('a', 'B:a'),
602 ('b', 'B:b'),
603 ('c', 'D:c')])
604
605 @dataclass
606 class E(D):
607 a: str = 'E:a'
608 d: str = 'E:d'
609
610 self.assertEqual([(f.name, f.default) for f in fields(E)],
611 [('a', 'E:a'),
612 ('b', 'B:b'),
613 ('c', 'D:c'),
614 ('d', 'E:d')])
615
616 def test_class_attrs(self):
617 # We only have a class attribute if a default value is
618 # specified, either directly or via a field with a default.
619 default = object()
620 @dataclass
621 class C:
622 x: int
623 y: int = field(repr=False)
624 z: object = default
625 t: int = field(default=100)
626
627 self.assertFalse(hasattr(C, 'x'))
628 self.assertFalse(hasattr(C, 'y'))
629 self.assertIs (C.z, default)
630 self.assertEqual(C.t, 100)
631
632 def test_disallowed_mutable_defaults(self):
633 # For the known types, don't allow mutable default values.
634 for typ, empty, non_empty in [(list, [], [1]),
635 (dict, {}, {0:1}),
636 (set, set(), set([1])),
637 ]:
638 with self.subTest(typ=typ):
639 # Can't use a zero-length value.
640 with self.assertRaisesRegex(ValueError,
641 f'mutable default {typ} for field '
642 'x is not allowed'):
643 @dataclass
644 class Point:
645 x: typ = empty
646
647
648 # Nor a non-zero-length value
649 with self.assertRaisesRegex(ValueError,
650 f'mutable default {typ} for field '
651 'y is not allowed'):
652 @dataclass
653 class Point:
654 y: typ = non_empty
655
656 # Check subtypes also fail.
657 class Subclass(typ): pass
658
659 with self.assertRaisesRegex(ValueError,
660 f"mutable default .*Subclass'>"
661 ' for field z is not allowed'
662 ):
663 @dataclass
664 class Point:
665 z: typ = Subclass()
666
667 # Because this is a ClassVar, it can be mutable.
668 @dataclass
669 class C:
670 z: ClassVar[typ] = typ()
671
672 # Because this is a ClassVar, it can be mutable.
673 @dataclass
674 class C:
675 x: ClassVar[typ] = Subclass()
676
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500677 def test_deliberately_mutable_defaults(self):
678 # If a mutable default isn't in the known list of
679 # (list, dict, set), then it's okay.
680 class Mutable:
681 def __init__(self):
682 self.l = []
683
684 @dataclass
685 class C:
686 x: Mutable
687
688 # These 2 instances will share this value of x.
689 lst = Mutable()
690 o1 = C(lst)
691 o2 = C(lst)
692 self.assertEqual(o1, o2)
693 o1.x.l.extend([1, 2])
694 self.assertEqual(o1, o2)
695 self.assertEqual(o1.x.l, [1, 2])
696 self.assertIs(o1.x, o2.x)
697
698 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400699 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500700 @dataclass()
701 class C:
702 x: int
703
704 self.assertEqual(C(42).x, 42)
705
706 def test_not_tuple(self):
707 # Make sure we can't be compared to a tuple.
708 @dataclass
709 class Point:
710 x: int
711 y: int
712 self.assertNotEqual(Point(1, 2), (1, 2))
713
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400714 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500715 @dataclass
716 class C:
717 x: int
718 y: int
719 self.assertNotEqual(Point(1, 3), C(1, 3))
720
Windson yangbe372d72019-04-23 02:45:34 +0800721 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500722 # Test that some of the problems with namedtuple don't happen
723 # here.
724 @dataclass
725 class Point3D:
726 x: int
727 y: int
728 z: int
729
730 @dataclass
731 class Date:
732 year: int
733 month: int
734 day: int
735
736 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
737 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
738
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400739 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200740 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500741 x, y, z = Point3D(4, 5, 6)
742
Eric V. Smith7c99e932018-01-28 19:18:55 -0500743 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500744 # equal.
745 @dataclass
746 class Point3Dv1:
747 x: int = 0
748 y: int = 0
749 z: int = 0
750 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
751
752 def test_function_annotations(self):
753 # Some dummy class and instance to use as a default.
754 class F:
755 pass
756 f = F()
757
758 def validate_class(cls):
759 # First, check __annotations__, even though they're not
760 # function annotations.
Pablo Galindob0544ba2021-04-21 12:41:19 +0100761 self.assertEqual(cls.__annotations__['i'], int)
762 self.assertEqual(cls.__annotations__['j'], str)
763 self.assertEqual(cls.__annotations__['k'], F)
764 self.assertEqual(cls.__annotations__['l'], float)
765 self.assertEqual(cls.__annotations__['z'], complex)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500766
767 # Verify __init__.
768
769 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400770 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertIs(signature.return_annotation, None)
772
773 # Check each parameter.
774 params = iter(signature.parameters.values())
775 param = next(params)
776 # This is testing an internal name, and probably shouldn't be tested.
777 self.assertEqual(param.name, 'self')
778 param = next(params)
779 self.assertEqual(param.name, 'i')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100780 self.assertIs (param.annotation, int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500781 self.assertEqual(param.default, inspect.Parameter.empty)
782 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
783 param = next(params)
784 self.assertEqual(param.name, 'j')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100785 self.assertIs (param.annotation, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500786 self.assertEqual(param.default, inspect.Parameter.empty)
787 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
788 param = next(params)
789 self.assertEqual(param.name, 'k')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100790 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400791 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
793 param = next(params)
794 self.assertEqual(param.name, 'l')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100795 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400796 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500797 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
798 self.assertRaises(StopIteration, next, params)
799
800
801 @dataclass
802 class C:
803 i: int
804 j: str
805 k: F = f
806 l: float=field(default=None)
807 z: complex=field(default=3+4j, init=False)
808
809 validate_class(C)
810
811 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500812 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500813 class C:
814 i: int
815 j: str
816 k: F = f
817 l: float=field(default=None)
818 z: complex=field(default=3+4j, init=False)
819
820 validate_class(C)
821
Eric V. Smith03220fd2017-12-29 13:59:58 -0500822 def test_missing_default(self):
823 # Test that MISSING works the same as a default not being
824 # specified.
825 @dataclass
826 class C:
827 x: int=field(default=MISSING)
828 with self.assertRaisesRegex(TypeError,
829 r'__init__\(\) missing 1 required '
830 'positional argument'):
831 C()
832 self.assertNotIn('x', C.__dict__)
833
834 @dataclass
835 class D:
836 x: int
837 with self.assertRaisesRegex(TypeError,
838 r'__init__\(\) missing 1 required '
839 'positional argument'):
840 D()
841 self.assertNotIn('x', D.__dict__)
842
843 def test_missing_default_factory(self):
844 # Test that MISSING works the same as a default factory not
845 # being specified (which is really the same as a default not
846 # being specified, too).
847 @dataclass
848 class C:
849 x: int=field(default_factory=MISSING)
850 with self.assertRaisesRegex(TypeError,
851 r'__init__\(\) missing 1 required '
852 'positional argument'):
853 C()
854 self.assertNotIn('x', C.__dict__)
855
856 @dataclass
857 class D:
858 x: int=field(default=MISSING, default_factory=MISSING)
859 with self.assertRaisesRegex(TypeError,
860 r'__init__\(\) missing 1 required '
861 'positional argument'):
862 D()
863 self.assertNotIn('x', D.__dict__)
864
865 def test_missing_repr(self):
866 self.assertIn('MISSING_TYPE object', repr(MISSING))
867
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500868 def test_dont_include_other_annotations(self):
869 @dataclass
870 class C:
871 i: int
872 def foo(self) -> int:
873 return 4
874 @property
875 def bar(self) -> int:
876 return 5
877 self.assertEqual(list(C.__annotations__), ['i'])
878 self.assertEqual(C(10).foo(), 4)
879 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400880 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500881
def test_post_init(self):
    # Just make sure it gets called.
    @dataclass
    class Raising:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        Raising()

    # __post_init__ sees the initialised fields through self.
    @dataclass
    class Conditional:
        i: int = 10
        def __post_init__(self):
            if self.i != 10:
                return
            raise CustomError()
    with self.assertRaises(CustomError):
        Conditional()
    # post-init gets called, but doesn't raise.
    Conditional(5)

    # If there's not an __init__, then post-init won't get called, so
    # creating an instance does not raise.
    @dataclass(init=False)
    class NoInit:
        def __post_init__(self):
            raise CustomError()
    NoInit()

    # __post_init__ can adjust field values after __init__ assigns them.
    @dataclass
    class Doubling:
        x: int = 0
        def __post_init__(self):
            self.x = self.x * 2
    self.assertEqual(Doubling().x, 0)
    self.assertEqual(Doubling(2).x, 4)

    # Make sure that if we're frozen, post-init can't set
    # attributes.
    @dataclass(frozen=True)
    class FrozenDoubling:
        x: int = 0
        def __post_init__(self):
            self.x = self.x * 2
    with self.assertRaises(FrozenInstanceError):
        FrozenDoubling()
def test_post_init_super(self):
    # A plain (non-dataclass) base whose __post_init__ always raises.
    class B:
        def __post_init__(self):
            raise CustomError()

    # Make sure super() post-init isn't called by default.
    @dataclass
    class Overriding(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(Overriding().x, 5)

    # Now call super(), and it will raise.
    @dataclass
    class Chaining(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        Chaining()

    # Make sure post-init is called, even if not defined in our
    # class.
    @dataclass
    class Inheriting(B):
        pass

    with self.assertRaises(CustomError):
        Inheriting()
960 def test_post_init_staticmethod(self):
961 flag = False
962 @dataclass
963 class C:
964 x: int
965 y: int
966 @staticmethod
967 def __post_init__():
968 nonlocal flag
969 flag = True
970
971 self.assertFalse(flag)
972 c = C(3, 4)
973 self.assertEqual((c.x, c.y), (3, 4))
974 self.assertTrue(flag)
975
976 def test_post_init_classmethod(self):
977 @dataclass
978 class C:
979 flag = False
980 x: int
981 y: int
982 @classmethod
983 def __post_init__(cls):
984 cls.flag = True
985
986 self.assertFalse(C.flag)
987 c = C(3, 4)
988 self.assertEqual((c.x, c.y), (3, 4))
989 self.assertTrue(C.flag)
990
991 def test_class_var(self):
992 # Make sure ClassVars are ignored in __init__, __repr__, etc.
993 @dataclass
994 class C:
995 x: int
996 y: int = 10
997 z: ClassVar[int] = 1000
998 w: ClassVar[int] = 2000
999 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001000 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001001
1002 c = C(5)
1003 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001004 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001005 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001006 self.assertEqual(c.z, 1000)
1007 self.assertEqual(c.w, 2000)
1008 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001009 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001010 C.z += 1
1011 self.assertEqual(c.z, 1001)
1012 c = C(20)
1013 self.assertEqual((c.x, c.y), (20, 10))
1014 self.assertEqual(c.z, 1001)
1015 self.assertEqual(c.w, 2000)
1016 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001017 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001018
1019 def test_class_var_no_default(self):
1020 # If a ClassVar has no default value, it should not be set on the class.
1021 @dataclass
1022 class C:
1023 x: ClassVar[int]
1024
1025 self.assertNotIn('x', C.__dict__)
1026
1027 def test_class_var_default_factory(self):
1028 # It makes no sense for a ClassVar to have a default factory. When
1029 # would it be called? Call it yourself, since it's class-wide.
1030 with self.assertRaisesRegex(TypeError,
1031 'cannot have a default factory'):
1032 @dataclass
1033 class C:
1034 x: ClassVar[int] = field(default_factory=int)
1035
1036 self.assertNotIn('x', C.__dict__)
1037
1038 def test_class_var_with_default(self):
1039 # If a ClassVar has a default value, it should be set on the class.
1040 @dataclass
1041 class C:
1042 x: ClassVar[int] = 10
1043 self.assertEqual(C.x, 10)
1044
1045 @dataclass
1046 class C:
1047 x: ClassVar[int] = field(default=10)
1048 self.assertEqual(C.x, 10)
1049
1050 def test_class_var_frozen(self):
1051 # Make sure ClassVars work even if we're frozen.
1052 @dataclass(frozen=True)
1053 class C:
1054 x: int
1055 y: int = 10
1056 z: ClassVar[int] = 1000
1057 w: ClassVar[int] = 2000
1058 t: ClassVar[int] = 3000
1059
1060 c = C(5)
1061 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1062 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1063 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1064 self.assertEqual(c.z, 1000)
1065 self.assertEqual(c.w, 2000)
1066 self.assertEqual(c.t, 3000)
1067 # We can still modify the ClassVar, it's only instances that are
1068 # frozen.
1069 C.z += 1
1070 self.assertEqual(c.z, 1001)
1071 c = C(20)
1072 self.assertEqual((c.x, c.y), (20, 10))
1073 self.assertEqual(c.z, 1001)
1074 self.assertEqual(c.w, 2000)
1075 self.assertEqual(c.t, 3000)
1076
1077 def test_init_var_no_default(self):
1078 # If an InitVar has no default value, it should not be set on the class.
1079 @dataclass
1080 class C:
1081 x: InitVar[int]
1082
1083 self.assertNotIn('x', C.__dict__)
1084
1085 def test_init_var_default_factory(self):
1086 # It makes no sense for an InitVar to have a default factory. When
1087 # would it be called? Call it yourself, since it's class-wide.
1088 with self.assertRaisesRegex(TypeError,
1089 'cannot have a default factory'):
1090 @dataclass
1091 class C:
1092 x: InitVar[int] = field(default_factory=int)
1093
1094 self.assertNotIn('x', C.__dict__)
1095
1096 def test_init_var_with_default(self):
1097 # If an InitVar has a default value, it should be set on the class.
1098 @dataclass
1099 class C:
1100 x: InitVar[int] = 10
1101 self.assertEqual(C.x, 10)
1102
1103 @dataclass
1104 class C:
1105 x: InitVar[int] = field(default=10)
1106 self.assertEqual(C.x, 10)
1107
1108 def test_init_var(self):
1109 @dataclass
1110 class C:
1111 x: int = None
1112 init_param: InitVar[int] = None
1113
1114 def __post_init__(self, init_param):
1115 if self.x is None:
1116 self.x = init_param*2
1117
1118 c = C(init_param=10)
1119 self.assertEqual(c.x, 20)
1120
Augusto Hack01ee12b2019-06-02 23:14:48 -03001121 def test_init_var_preserve_type(self):
1122 self.assertEqual(InitVar[int].type, int)
1123
1124 # Make sure the repr is correct.
1125 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001126 self.assertEqual(repr(InitVar[List[int]]),
1127 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001128
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001129 def test_init_var_inheritance(self):
1130 # Note that this deliberately tests that a dataclass need not
1131 # have a __post_init__ function if it has an InitVar field.
1132 # It could just be used in a derived class, as shown here.
1133 @dataclass
1134 class Base:
1135 x: int
1136 init_base: InitVar[int]
1137
1138 # We can instantiate by passing the InitVar, even though
1139 # it's not used.
1140 b = Base(0, 10)
1141 self.assertEqual(vars(b), {'x': 0})
1142
1143 @dataclass
1144 class C(Base):
1145 y: int
1146 init_derived: InitVar[int]
1147
1148 def __post_init__(self, init_base, init_derived):
1149 self.x = self.x + init_base
1150 self.y = self.y + init_derived
1151
1152 c = C(10, 11, 50, 51)
1153 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1154
1155 def test_default_factory(self):
1156 # Test a factory that returns a new list.
1157 @dataclass
1158 class C:
1159 x: int
1160 y: list = field(default_factory=list)
1161
1162 c0 = C(3)
1163 c1 = C(3)
1164 self.assertEqual(c0.x, 3)
1165 self.assertEqual(c0.y, [])
1166 self.assertEqual(c0, c1)
1167 self.assertIsNot(c0.y, c1.y)
1168 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1169
1170 # Test a factory that returns a shared list.
1171 l = []
1172 @dataclass
1173 class C:
1174 x: int
1175 y: list = field(default_factory=lambda: l)
1176
1177 c0 = C(3)
1178 c1 = C(3)
1179 self.assertEqual(c0.x, 3)
1180 self.assertEqual(c0.y, [])
1181 self.assertEqual(c0, c1)
1182 self.assertIs(c0.y, c1.y)
1183 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1184
1185 # Test various other field flags.
1186 # repr
1187 @dataclass
1188 class C:
1189 x: list = field(default_factory=list, repr=False)
1190 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1191 self.assertEqual(C().x, [])
1192
1193 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001194 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001195 class C:
1196 x: list = field(default_factory=list, hash=False)
1197 self.assertEqual(astuple(C()), ([],))
1198 self.assertEqual(hash(C()), hash(()))
1199
1200 # init (see also test_default_factory_with_no_init)
1201 @dataclass
1202 class C:
1203 x: list = field(default_factory=list, init=False)
1204 self.assertEqual(astuple(C()), ([],))
1205
1206 # compare
1207 @dataclass
1208 class C:
1209 x: list = field(default_factory=list, compare=False)
1210 self.assertEqual(C(), C([1]))
1211
1212 def test_default_factory_with_no_init(self):
1213 # We need a factory with a side effect.
1214 factory = Mock()
1215
1216 @dataclass
1217 class C:
1218 x: list = field(default_factory=factory, init=False)
1219
1220 # Make sure the default factory is called for each new instance.
1221 C().x
1222 self.assertEqual(factory.call_count, 1)
1223 C().x
1224 self.assertEqual(factory.call_count, 2)
1225
1226 def test_default_factory_not_called_if_value_given(self):
1227 # We need a factory that we can test if it's been called.
1228 factory = Mock()
1229
1230 @dataclass
1231 class C:
1232 x: int = field(default_factory=factory)
1233
1234 # Make sure that if a field has a default factory function,
1235 # it's not called if a value is specified.
1236 C().x
1237 self.assertEqual(factory.call_count, 1)
1238 self.assertEqual(C(10).x, 10)
1239 self.assertEqual(factory.call_count, 1)
1240 C().x
1241 self.assertEqual(factory.call_count, 2)
1242
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001243 def test_default_factory_derived(self):
1244 # See bpo-32896.
1245 @dataclass
1246 class Foo:
1247 x: dict = field(default_factory=dict)
1248
1249 @dataclass
1250 class Bar(Foo):
1251 y: int = 1
1252
1253 self.assertEqual(Foo().x, {})
1254 self.assertEqual(Bar().x, {})
1255 self.assertEqual(Bar().y, 1)
1256
1257 @dataclass
1258 class Baz(Foo):
1259 pass
1260 self.assertEqual(Baz().x, {})
1261
1262 def test_intermediate_non_dataclass(self):
1263 # Test that an intermediate class that defines
1264 # annotations does not define fields.
1265
1266 @dataclass
1267 class A:
1268 x: int
1269
1270 class B(A):
1271 y: int
1272
1273 @dataclass
1274 class C(B):
1275 z: int
1276
1277 c = C(1, 3)
1278 self.assertEqual((c.x, c.z), (1, 3))
1279
1280 # .y was not initialized.
1281 with self.assertRaisesRegex(AttributeError,
1282 'object has no attribute'):
1283 c.y
1284
1285 # And if we again derive a non-dataclass, no fields are added.
1286 class D(C):
1287 t: int
1288 d = D(4, 5)
1289 self.assertEqual((d.x, d.z), (4, 5))
1290
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001291 def test_classvar_default_factory(self):
1292 # It's an error for a ClassVar to have a factory function.
1293 with self.assertRaisesRegex(TypeError,
1294 'cannot have a default factory'):
1295 @dataclass
1296 class C:
1297 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001298
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001299 def test_is_dataclass(self):
1300 class NotDataClass:
1301 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001302
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001303 self.assertFalse(is_dataclass(0))
1304 self.assertFalse(is_dataclass(int))
1305 self.assertFalse(is_dataclass(NotDataClass))
1306 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001307
1308 @dataclass
1309 class C:
1310 x: int
1311
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001312 @dataclass
1313 class D:
1314 d: C
1315 e: int
1316
1317 c = C(10)
1318 d = D(c, 4)
1319
1320 self.assertTrue(is_dataclass(C))
1321 self.assertTrue(is_dataclass(c))
1322 self.assertFalse(is_dataclass(c.x))
1323 self.assertTrue(is_dataclass(d.d))
1324 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001325
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001326 def test_is_dataclass_when_getattr_always_returns(self):
1327 # See bpo-37868.
1328 class A:
1329 def __getattr__(self, key):
1330 return 0
1331 self.assertFalse(is_dataclass(A))
1332 a = A()
1333
1334 # Also test for an instance attribute.
1335 class B:
1336 pass
1337 b = B()
1338 b.__dataclass_fields__ = []
1339
1340 for obj in a, b:
1341 with self.subTest(obj=obj):
1342 self.assertFalse(is_dataclass(obj))
1343
1344 # Indirect tests for _is_dataclass_instance().
1345 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1346 asdict(obj)
1347 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1348 astuple(obj)
1349 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1350 replace(obj, x=0)
1351
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001352 def test_helper_fields_with_class_instance(self):
1353 # Check that we can call fields() on either a class or instance,
1354 # and get back the same thing.
1355 @dataclass
1356 class C:
1357 x: int
1358 y: float
1359
1360 self.assertEqual(fields(C), fields(C(0, 0.0)))
1361
1362 def test_helper_fields_exception(self):
1363 # Check that TypeError is raised if not passed a dataclass or
1364 # instance.
1365 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1366 fields(0)
1367
1368 class C: pass
1369 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1370 fields(C)
1371 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1372 fields(C())
1373
1374 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001375 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001376 @dataclass
1377 class C:
1378 x: int
1379 y: int
1380 c = C(1, 2)
1381
1382 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1383 self.assertEqual(asdict(c), asdict(c))
1384 self.assertIsNot(asdict(c), asdict(c))
1385 c.x = 42
1386 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1387 self.assertIs(type(asdict(c)), dict)
1388
1389 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001390 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001391 @dataclass
1392 class C:
1393 x: int
1394 y: int
1395 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1396 asdict(C)
1397 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1398 asdict(int)
1399
1400 def test_helper_asdict_copy_values(self):
1401 @dataclass
1402 class C:
1403 x: int
1404 y: List[int] = field(default_factory=list)
1405 initial = []
1406 c = C(1, initial)
1407 d = asdict(c)
1408 self.assertEqual(d['y'], initial)
1409 self.assertIsNot(d['y'], initial)
1410 c = C(1)
1411 d = asdict(c)
1412 d['y'].append(1)
1413 self.assertEqual(c.y, [])
1414
1415 def test_helper_asdict_nested(self):
1416 @dataclass
1417 class UserId:
1418 token: int
1419 group: int
1420 @dataclass
1421 class User:
1422 name: str
1423 id: UserId
1424 u = User('Joe', UserId(123, 1))
1425 d = asdict(u)
1426 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1427 self.assertIsNot(asdict(u), asdict(u))
1428 u.id.group = 2
1429 self.assertEqual(asdict(u), {'name': 'Joe',
1430 'id': {'token': 123, 'group': 2}})
1431
1432 def test_helper_asdict_builtin_containers(self):
1433 @dataclass
1434 class User:
1435 name: str
1436 id: int
1437 @dataclass
1438 class GroupList:
1439 id: int
1440 users: List[User]
1441 @dataclass
1442 class GroupTuple:
1443 id: int
1444 users: Tuple[User, ...]
1445 @dataclass
1446 class GroupDict:
1447 id: int
1448 users: Dict[str, User]
1449 a = User('Alice', 1)
1450 b = User('Bob', 2)
1451 gl = GroupList(0, [a, b])
1452 gt = GroupTuple(0, (a, b))
1453 gd = GroupDict(0, {'first': a, 'second': b})
1454 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1455 {'name': 'Bob', 'id': 2}]})
1456 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1457 {'name': 'Bob', 'id': 2})})
1458 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1459 'second': {'name': 'Bob', 'id': 2}}})
1460
Windson yangbe372d72019-04-23 02:45:34 +08001461 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001462 @dataclass
1463 class Child:
1464 d: object
1465
1466 @dataclass
1467 class Parent:
1468 child: Child
1469
1470 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1471 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1472
1473 def test_helper_asdict_factory(self):
1474 @dataclass
1475 class C:
1476 x: int
1477 y: int
1478 c = C(1, 2)
1479 d = asdict(c, dict_factory=OrderedDict)
1480 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1481 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1482 c.x = 42
1483 d = asdict(c, dict_factory=OrderedDict)
1484 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1485 self.assertIs(type(d), OrderedDict)
1486
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001487 def test_helper_asdict_namedtuple(self):
1488 T = namedtuple('T', 'a b c')
1489 @dataclass
1490 class C:
1491 x: str
1492 y: T
1493 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1494
1495 d = asdict(c)
1496 self.assertEqual(d, {'x': 'outer',
1497 'y': T(1,
1498 {'x': 'inner',
1499 'y': T(11, 12, 13)},
1500 2),
1501 }
1502 )
1503
1504 # Now with a dict_factory. OrderedDict is convenient, but
1505 # since it compares to dicts, we also need to have separate
1506 # assertIs tests.
1507 d = asdict(c, dict_factory=OrderedDict)
1508 self.assertEqual(d, {'x': 'outer',
1509 'y': T(1,
1510 {'x': 'inner',
1511 'y': T(11, 12, 13)},
1512 2),
1513 }
1514 )
1515
penguindustin96466302019-05-06 14:57:17 -04001516 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001517 self.assertIs(type(d), OrderedDict)
1518 self.assertIs(type(d['y'][1]), OrderedDict)
1519
1520 def test_helper_asdict_namedtuple_key(self):
1521 # Ensure that a field that contains a dict which has a
1522 # namedtuple as a key works with asdict().
1523
1524 @dataclass
1525 class C:
1526 f: dict
1527 T = namedtuple('T', 'a')
1528
1529 c = C({T('an a'): 0})
1530
1531 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1532
1533 def test_helper_asdict_namedtuple_derived(self):
1534 class T(namedtuple('Tbase', 'a')):
1535 def my_a(self):
1536 return self.a
1537
1538 @dataclass
1539 class C:
1540 f: T
1541
1542 t = T(6)
1543 c = C(t)
1544
1545 d = asdict(c)
1546 self.assertEqual(d, {'f': T(a=6)})
1547 # Make sure that t has been copied, not used directly.
1548 self.assertIsNot(d['f'], t)
1549 self.assertEqual(d['f'].my_a(), 6)
1550
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001551 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001552 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001553 @dataclass
1554 class C:
1555 x: int
1556 y: int = 0
1557 c = C(1)
1558
1559 self.assertEqual(astuple(c), (1, 0))
1560 self.assertEqual(astuple(c), astuple(c))
1561 self.assertIsNot(astuple(c), astuple(c))
1562 c.y = 42
1563 self.assertEqual(astuple(c), (1, 42))
1564 self.assertIs(type(astuple(c)), tuple)
1565
1566 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001567 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001568 @dataclass
1569 class C:
1570 x: int
1571 y: int
1572 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1573 astuple(C)
1574 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1575 astuple(int)
1576
1577 def test_helper_astuple_copy_values(self):
1578 @dataclass
1579 class C:
1580 x: int
1581 y: List[int] = field(default_factory=list)
1582 initial = []
1583 c = C(1, initial)
1584 t = astuple(c)
1585 self.assertEqual(t[1], initial)
1586 self.assertIsNot(t[1], initial)
1587 c = C(1)
1588 t = astuple(c)
1589 t[1].append(1)
1590 self.assertEqual(c.y, [])
1591
1592 def test_helper_astuple_nested(self):
1593 @dataclass
1594 class UserId:
1595 token: int
1596 group: int
1597 @dataclass
1598 class User:
1599 name: str
1600 id: UserId
1601 u = User('Joe', UserId(123, 1))
1602 t = astuple(u)
1603 self.assertEqual(t, ('Joe', (123, 1)))
1604 self.assertIsNot(astuple(u), astuple(u))
1605 u.id.group = 2
1606 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1607
1608 def test_helper_astuple_builtin_containers(self):
1609 @dataclass
1610 class User:
1611 name: str
1612 id: int
1613 @dataclass
1614 class GroupList:
1615 id: int
1616 users: List[User]
1617 @dataclass
1618 class GroupTuple:
1619 id: int
1620 users: Tuple[User, ...]
1621 @dataclass
1622 class GroupDict:
1623 id: int
1624 users: Dict[str, User]
1625 a = User('Alice', 1)
1626 b = User('Bob', 2)
1627 gl = GroupList(0, [a, b])
1628 gt = GroupTuple(0, (a, b))
1629 gd = GroupDict(0, {'first': a, 'second': b})
1630 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1631 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1632 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1633
Windson yangbe372d72019-04-23 02:45:34 +08001634 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001635 @dataclass
1636 class Child:
1637 d: object
1638
1639 @dataclass
1640 class Parent:
1641 child: Child
1642
1643 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1644 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1645
1646 def test_helper_astuple_factory(self):
1647 @dataclass
1648 class C:
1649 x: int
1650 y: int
1651 NT = namedtuple('NT', 'x y')
1652 def nt(lst):
1653 return NT(*lst)
1654 c = C(1, 2)
1655 t = astuple(c, tuple_factory=nt)
1656 self.assertEqual(t, NT(1, 2))
1657 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1658 c.x = 42
1659 t = astuple(c, tuple_factory=nt)
1660 self.assertEqual(t, NT(42, 2))
1661 self.assertIs(type(t), NT)
1662
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001663 def test_helper_astuple_namedtuple(self):
1664 T = namedtuple('T', 'a b c')
1665 @dataclass
1666 class C:
1667 x: str
1668 y: T
1669 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1670
1671 t = astuple(c)
1672 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1673
1674 # Now, using a tuple_factory. list is convenient here.
1675 t = astuple(c, tuple_factory=list)
1676 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1677
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001678 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001679 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001680 }
1681
1682 # Create the class.
1683 cls = type('C', (), cls_dict)
1684
1685 # Make it a dataclass.
1686 cls1 = dataclass(cls)
1687
1688 self.assertEqual(cls1, cls)
1689 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1690
1691 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001692 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001693 'y': field(default=5),
1694 }
1695
1696 # Create the class.
1697 cls = type('C', (), cls_dict)
1698
1699 # Make it a dataclass.
1700 cls1 = dataclass(cls)
1701
1702 self.assertEqual(cls1, cls)
1703 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1704
1705 def test_init_in_order(self):
1706 @dataclass
1707 class C:
1708 a: int
1709 b: int = field()
1710 c: list = field(default_factory=list, init=False)
1711 d: list = field(default_factory=list)
1712 e: int = field(default=4, init=False)
1713 f: int = 4
1714
1715 calls = []
1716 def setattr(self, name, value):
1717 calls.append((name, value))
1718
1719 C.__setattr__ = setattr
1720 c = C(0, 1)
1721 self.assertEqual(('a', 0), calls[0])
1722 self.assertEqual(('b', 1), calls[1])
1723 self.assertEqual(('c', []), calls[2])
1724 self.assertEqual(('d', []), calls[3])
1725 self.assertNotIn(('e', 4), calls)
1726 self.assertEqual(('f', 4), calls[4])
1727
1728 def test_items_in_dicts(self):
1729 @dataclass
1730 class C:
1731 a: int
1732 b: list = field(default_factory=list, init=False)
1733 c: list = field(default_factory=list)
1734 d: int = field(default=4, init=False)
1735 e: int = 0
1736
1737 c = C(0)
1738 # Class dict
1739 self.assertNotIn('a', C.__dict__)
1740 self.assertNotIn('b', C.__dict__)
1741 self.assertNotIn('c', C.__dict__)
1742 self.assertIn('d', C.__dict__)
1743 self.assertEqual(C.d, 4)
1744 self.assertIn('e', C.__dict__)
1745 self.assertEqual(C.e, 0)
1746 # Instance dict
1747 self.assertIn('a', c.__dict__)
1748 self.assertEqual(c.a, 0)
1749 self.assertIn('b', c.__dict__)
1750 self.assertEqual(c.b, [])
1751 self.assertIn('c', c.__dict__)
1752 self.assertEqual(c.c, [])
1753 self.assertNotIn('d', c.__dict__)
1754 self.assertIn('e', c.__dict__)
1755 self.assertEqual(c.e, 0)
1756
1757 def test_alternate_classmethod_constructor(self):
1758 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001759 # alternate constructor. This is mostly an example to show
1760 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001761 @dataclass
1762 class C:
1763 x: int
1764 @classmethod
1765 def from_file(cls, filename):
1766 # In a real example, create a new instance
1767 # and populate 'x' from contents of a file.
1768 value_in_file = 20
1769 return cls(value_in_file)
1770
1771 self.assertEqual(C.from_file('filename').x, 20)
1772
1773 def test_field_metadata_default(self):
1774 # Make sure the default metadata is read-only and of
1775 # zero length.
1776 @dataclass
1777 class C:
1778 i: int
1779
1780 self.assertFalse(fields(C)[0].metadata)
1781 self.assertEqual(len(fields(C)[0].metadata), 0)
1782 with self.assertRaisesRegex(TypeError,
1783 'does not support item assignment'):
1784 fields(C)[0].metadata['test'] = 3
1785
1786 def test_field_metadata_mapping(self):
1787 # Make sure only a mapping can be passed as metadata
1788 # zero length.
1789 with self.assertRaises(TypeError):
1790 @dataclass
1791 class C:
1792 i: int = field(metadata=0)
1793
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001794 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001795 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001796 @dataclass
1797 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001798 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001799 self.assertFalse(fields(C)[0].metadata)
1800 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001801 # Update should work (see bpo-35960).
1802 d['foo'] = 1
1803 self.assertEqual(len(fields(C)[0].metadata), 1)
1804 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001805 with self.assertRaisesRegex(TypeError,
1806 'does not support item assignment'):
1807 fields(C)[0].metadata['test'] = 3
1808
1809 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001810 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001811 @dataclass
1812 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001813 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001814 self.assertEqual(len(fields(C)[0].metadata), 3)
1815 self.assertEqual(fields(C)[0].metadata['test'], 10)
1816 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1817 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001818 # Update should work.
1819 d['foo'] = 1
1820 self.assertEqual(len(fields(C)[0].metadata), 4)
1821 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001822 with self.assertRaises(KeyError):
1823 # Non-existent key.
1824 fields(C)[0].metadata['baz']
1825 with self.assertRaisesRegex(TypeError,
1826 'does not support item assignment'):
1827 fields(C)[0].metadata['test'] = 3
1828
1829 def test_field_metadata_custom_mapping(self):
1830 # Try a custom mapping.
1831 class SimpleNameSpace:
1832 def __init__(self, **kw):
1833 self.__dict__.update(kw)
1834
1835 def __getitem__(self, item):
1836 if item == 'xyzzy':
1837 return 'plugh'
1838 return getattr(self, item)
1839
1840 def __len__(self):
1841 return self.__dict__.__len__()
1842
1843 @dataclass
1844 class C:
1845 i: int = field(metadata=SimpleNameSpace(a=10))
1846
1847 self.assertEqual(len(fields(C)[0].metadata), 1)
1848 self.assertEqual(fields(C)[0].metadata['a'], 10)
1849 with self.assertRaises(AttributeError):
1850 fields(C)[0].metadata['b']
1851 # Make sure we're still talking to our custom mapping.
1852 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1853
1854 def test_generic_dataclasses(self):
1855 T = TypeVar('T')
1856
1857 @dataclass
1858 class LabeledBox(Generic[T]):
1859 content: T
1860 label: str = '<unknown>'
1861
1862 box = LabeledBox(42)
1863 self.assertEqual(box.content, 42)
1864 self.assertEqual(box.label, '<unknown>')
1865
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001866 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001867 Alias = List[LabeledBox[int]]
1868
1869 def test_generic_extending(self):
1870 S = TypeVar('S')
1871 T = TypeVar('T')
1872
1873 @dataclass
1874 class Base(Generic[T, S]):
1875 x: T
1876 y: S
1877
1878 @dataclass
1879 class DataDerived(Base[int, T]):
1880 new_field: str
1881 Alias = DataDerived[str]
1882 c = Alias(0, 'test1', 'test2')
1883 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1884
1885 class NonDataDerived(Base[int, T]):
1886 def new_method(self):
1887 return self.y
1888 Alias = NonDataDerived[float]
1889 c = Alias(10, 1.0)
1890 self.assertEqual(c.new_method(), 1.0)
1891
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001892 def test_generic_dynamic(self):
1893 T = TypeVar('T')
1894
1895 @dataclass
1896 class Parent(Generic[T]):
1897 x: T
1898 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1899 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1900 self.assertIs(Child[int](1, 2).z, None)
1901 self.assertEqual(Child[int](1, 2, 3).z, 3)
1902 self.assertEqual(Child[int](1, 2, 3).other, 42)
1903 # Check that type aliases work correctly.
1904 Alias = Child[T]
1905 self.assertEqual(Alias[int](1, 2).x, 1)
1906 # Check MRO resolution.
1907 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1908
    def test_dataclassses_pickleable(self):
        # Round-trip several dataclass instances through every pickle protocol.
        # The classes are declared global so that pickle, which stores
        # classes by reference, can find them again by name when loading.
        global P, Q, R
        @dataclass
        class P:
            x: int
            y: int = 0
        @dataclass
        class Q:
            x: int
            y: int = field(default=0, init=False)
        @dataclass
        class R:
            x: int
            y: List[int] = field(default_factory=list)
        # Q.y is init=False, so set it after construction to get a non-default value.
        q = Q(1)
        q.y = 2
        samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
        for sample in samples:
            for proto in range(pickle.HIGHEST_PROTOCOL + 1):
                with self.subTest(sample=sample, proto=proto):
                    new_sample = pickle.loads(pickle.dumps(sample, proto))
                    self.assertEqual(sample.x, new_sample.x)
                    self.assertEqual(sample.y, new_sample.y)
                    self.assertIsNot(sample, new_sample)
                    # Mutate the copy, then pickle it again to make sure the
                    # changed state (not the original) is what gets saved.
                    new_sample.x = 42
                    another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
                    self.assertEqual(new_sample.x, another_new_sample.x)
                    self.assertEqual(sample.y, another_new_sample.y)
1936 self.assertEqual(sample.y, another_new_sample.y)
1937
Batuhan Taskayac7437e22020-10-21 16:49:22 +03001938 def test_dataclasses_qualnames(self):
1939 @dataclass(order=True, unsafe_hash=True, frozen=True)
1940 class A:
1941 x: int
1942 y: int
1943
1944 self.assertEqual(A.__init__.__name__, "__init__")
1945 for function in (
1946 '__eq__',
1947 '__lt__',
1948 '__le__',
1949 '__gt__',
1950 '__ge__',
1951 '__hash__',
1952 '__init__',
1953 '__repr__',
1954 '__setattr__',
1955 '__delattr__',
1956 ):
1957 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
1958
1959 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
1960 A()
1961
Eric V. Smithea8fc522018-01-27 19:07:40 -05001962
class TestFieldNoAnnotation(unittest.TestCase):
    # A field() assigned in a class body without an annotation in that same
    # class body is an error, whatever the base classes annotate.

    def _check_rejected(self, *bases):
        # Declare `f = field()` with no annotation on top of the given bases
        # and require dataclass() to refuse it.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(*bases):
                f = field()

    def test_field_without_annotation(self):
        self._check_rejected()

    def test_field_without_annotation_but_annotation_in_base(self):
        # The annotation on a dataclass base must not be picked up.
        @dataclass
        class B:
            f: int

        self._check_rejected(B)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same check, but the base is a plain (non-dataclass) class.
        class B:
            f: int

        self._check_rejected(B)
1995 f = field()
1996
1997
class TestDocString(unittest.TestCase):
    # Checks the __doc__ that @dataclass synthesizes from the __init__
    # signature when the class has no docstring of its own.

    def assertDocStrEqual(self, a, b):
        # inspect.signature output differs between 3.6 and 3.7 (bpo-32108),
        # so compare the two strings with every space removed.
        def squash(text):
            return ''.join(text.split(' '))
        self.assertEqual(squash(a), squash(b))

    def test_existing_docstring_not_overridden(self):
        # A user-written docstring is kept as-is.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        # Union[X, None] is rendered in its Optional[X] spelling.
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # Factory defaults show up as the <factory> placeholder.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        # Non-builtin classes are shown with their module prefix.
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2084 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2085
2086
class TestInit(unittest.TestCase):
    """Tests for how @dataclass generates, or declines to generate, __init__."""

    def test_base_has_init(self):
        # B is a plain (non-dataclass) base with its own __init__.
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        # The issue is that we can't override __init__ in our class,
        # but it should be okay to add __init__ to us if our base has
        # an __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not chain to B.__init__.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must actually be applied.  Written as a bare
        # `dataclass(init=False)` statement (without `@`), the call's
        # result is discarded and C is just an ordinary class, so
        # init=False would never be exercised.
        @dataclass(init=False)
        class C:
            i: int = 0
        # No __init__ is generated, so the class-level default is seen.
        self.assertEqual(C().i, 0)

        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        # The user-defined __init__ is used.
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)


class TestRepr(unittest.TestCase):
    """Tests for the generated __repr__ and for user-defined overrides of it.

    The expected strings embed each class's __qualname__ (for example
    ``TestRepr.test_repr.<locals>.C``), so the names of these test methods
    and of the local classes are part of what is being asserted.
    """

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first in the repr.
        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # D redeclares x with a default; x still appears first.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses use their full qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        # The default object repr is expected instead.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')


class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and for user-defined overrides of it."""

    def test_no_eq(self):
        # Test a class with no __eq__ and eq=False.
        # Equality falls back to identity: distinct equal-valued
        # instances compare unequal, and an instance equals itself.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        c = C(3)
        self.assertEqual(c, c)

        # Test a class with an __eq__ and eq=False.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # If the class has __eq__, use it no matter the value of
        # eq=.

        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)


class TestOrdering(unittest.TestCase):
    """Tests for order= and its interaction with user-defined comparisons."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        # Each comparison below is derived by total_ordering from the
        # inverted __lt__ above, so the "surprising" results are expected.
        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that __lt__ is still called
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace a user-defined comparison method.
        # The error message suggests functools.total_ordering instead.
        # Adjacent string literals below form a single regex.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass


class TestHash(unittest.TestCase):
    """Tests for how unsafe_hash=, eq= and frozen= decide __hash__."""

    def test_unsafe_hash(self):
        # The generated hash matches the hash of the tuple of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            # None stays None; truthy -> (3,); falsy -> 0.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a dataclass with the given flags, optionally with a
            # user-defined __hash__, and check the expected outcome:
            #   'fn'        -> a generated __hash__ is in the class dict
            #   ''          -> nothing added (only checked when the
            #                  class defines no __hash__ of its own)
            #   'none'      -> __hash__ is explicitly set to None
            #   'exception' -> class creation raises TypeError
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    #  This only happens with with_hash==True.
                    assert(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash,  eq,    frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False,        False, False,  '',                  ''),
                (False,        False, True,   '',                  ''),
                (False,        True,  False,  'none',              ''),
                (False,        True,  True,   'fn',                ''),
                (True,         False, False,  'fn',                'exception'),
                (True,         False, True,   'fn',                'exception'),
                (True,         True,  False,  'fn',                'exception'),
                (True,         True,  True,   'fn',                'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        #  and set to None.  This is normal Python behavior, not
        #  related to dataclasses.  Make sure we don't interfere with
        #  that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        #  unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        #  specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # With no fields, the hash equals the hash of the empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With one field, the hash equals the hash of a 1-tuple.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        #  make sure that if the @dataclass parameter name is changed
        #  or the non-default hashing behavior changes, the default
        #  hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        #  specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash is identity-based, so there is no
                    #  fixed value to compare hash() against.  Instead,
                    #  just check the function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'


class TestFrozen(unittest.TestCase):
    """Tests for frozen=True: immutability, inheritance rules and hashing."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        # The failed assignment left the value untouched.
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        # Both the inherited and the new field are read-only.
        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    def test_inherit_nonfrozen_from_empty_frozen(self):
        # A frozen base rejects non-frozen dataclass subclasses even
        # when the base has no fields.
        @dataclass(frozen=True)
        class C:
            pass

        with self.assertRaisesRegex(TypeError,
                                    'cannot inherit non-frozen dataclass from a frozen one'):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertEqual(d.j, 3)
        self.assertIsInstance(d, C)

    # Test both ways: with an intermediate normal (non-dataclass)
    #  class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from an ordinary class, directly
        # or through an intermediate ordinary class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

            d = D(10)
            with self.assertRaises(FrozenInstanceError):
                d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        # S is a plain (non-dataclass) subclass of a frozen dataclass.
        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        # Attributes that are not fields of D may still be set.
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False, a user-defined __setattr__ is kept and is
        # used by the generated __init__.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))


class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__ by hand."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed to
        #  also have a default value (of type
        #  types.MemberDescriptorType).
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        # Derived declares no __slots__ of its own.
        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.
        d.z = 10

class TestDescriptors(unittest.TestCase):
    """Tests for __set_name__ handling of field default values."""

    def test_set_name(self):
        # See bpo-33141.

        # Create a descriptor.  __set_name__ records the attribute name
        # with an 'x' suffix so the test can see that it was called.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is not None:
                    return 1
                return self

        # This is the case of just normal descriptor behavior, no
        #  dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Now test with a default value and init=False, which is the
        #  only time this is really meaningful.  If not using
        #  init=False, then the descriptor will be overwritten, anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # Create a class that has __set_name__ but no __get__, so it
        # is not a descriptor.

        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        d = D()
        # Create an attribute on the instance, not type.
        d.__set_name__ = Mock()

        # Make sure d.__set_name__ is not called.
        @dataclass
        class C:
            i: int=field(default=d, init=False)

        self.assertEqual(d.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        D.__set_name__ = Mock()

        # Make sure D.__set_name__ is called.
        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        # Called exactly once.
        self.assertEqual(D.__set_name__.call_count, 1)


Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002855class TestStringAnnotations(unittest.TestCase):
2856 def test_classvar(self):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002857 # Some expressions recognized as ClassVar really aren't. But
2858 # if you're using string annotations, it's not an exact
2859 # science.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002860 # These tests assume that both "import typing" and "from
2861 # typing import *" have been run in this file.
2862 for typestr in ('ClassVar[int]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002863 'ClassVar [int]',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002864 ' ClassVar [int]',
2865 'ClassVar',
2866 ' ClassVar ',
2867 'typing.ClassVar[int]',
2868 'typing.ClassVar[str]',
2869 ' typing.ClassVar[str]',
2870 'typing .ClassVar[str]',
2871 'typing. ClassVar[str]',
2872 'typing.ClassVar [str]',
2873 'typing.ClassVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01002874
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002875 # Not syntactically valid, but these will
Pablo Galindob0544ba2021-04-21 12:41:19 +01002876 # be treated as ClassVars.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002877 'typing.ClassVar.[int]',
2878 'typing.ClassVar+',
2879 ):
2880 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002881 @dataclass
2882 class C:
2883 x: typestr
2884
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002885 # x is a ClassVar, so C() takes no args.
2886 C()
2887
2888 # And it won't appear in the class's dict because it doesn't
2889 # have a default.
2890 self.assertNotIn('x', C.__dict__)
2891
2892 def test_isnt_classvar(self):
2893 for typestr in ('CV',
2894 't.ClassVar',
2895 't.ClassVar[int]',
2896 'typing..ClassVar[int]',
2897 'Classvar',
2898 'Classvar[int]',
2899 'typing.ClassVarx[int]',
2900 'typong.ClassVar[int]',
2901 'dataclasses.ClassVar[int]',
2902 'typingxClassVar[str]',
2903 ):
2904 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002905 @dataclass
2906 class C:
2907 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002908
2909 # x is not a ClassVar, so C() takes one arg.
2910 self.assertEqual(C(10).x, 10)
2911
2912 def test_initvar(self):
2913 # These tests assume that both "import dataclasses" and "from
2914 # dataclasses import *" have been run in this file.
2915 for typestr in ('InitVar[int]',
2916 'InitVar [int]'
2917 ' InitVar [int]',
2918 'InitVar',
2919 ' InitVar ',
2920 'dataclasses.InitVar[int]',
2921 'dataclasses.InitVar[str]',
2922 ' dataclasses.InitVar[str]',
2923 'dataclasses .InitVar[str]',
2924 'dataclasses. InitVar[str]',
2925 'dataclasses.InitVar [str]',
2926 'dataclasses.InitVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01002927
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002928 # Not syntactically valid, but these will
2929 # be treated as InitVars.
2930 'dataclasses.InitVar.[int]',
2931 'dataclasses.InitVar+',
2932 ):
2933 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002934 @dataclass
2935 class C:
2936 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002937
2938 # x is an InitVar, so doesn't create a member.
2939 with self.assertRaisesRegex(AttributeError,
2940 "object has no attribute 'x'"):
2941 C(1).x
2942
2943 def test_isnt_initvar(self):
2944 for typestr in ('IV',
2945 'dc.InitVar',
2946 'xdataclasses.xInitVar',
2947 'typing.xInitVar[int]',
2948 ):
2949 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002950 @dataclass
2951 class C:
2952 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002953
2954 # x is not an InitVar, so there will be a member x.
2955 self.assertEqual(C(10).x, 10)
2956
2957 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002958 from test import dataclass_module_1
Pablo Galindob0544ba2021-04-21 12:41:19 +01002959 from test import dataclass_module_1_str
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002960 from test import dataclass_module_2
Pablo Galindob0544ba2021-04-21 12:41:19 +01002961 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002962
Pablo Galindob0544ba2021-04-21 12:41:19 +01002963 for m in (dataclass_module_1, dataclass_module_1_str,
2964 dataclass_module_2, dataclass_module_2_str,
2965 ):
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002966 with self.subTest(m=m):
2967 # There's a difference in how the ClassVars are
2968 # interpreted when using string annotations or
2969 # not. See the imported modules for details.
Pablo Galindob0544ba2021-04-21 12:41:19 +01002970 if m.USING_STRINGS:
2971 c = m.CV(10)
2972 else:
2973 c = m.CV()
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002974 self.assertEqual(c.cv0, 20)
2975
2976
2977 # There's a difference in how the InitVars are
2978 # interpreted when using string annotations or
2979 # not. See the imported modules for details.
2980 c = m.IV(0, 1, 2, 3, 4)
2981
2982 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2983 with self.subTest(field_name=field_name):
2984 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2985 # Since field_name is an InitVar, it's
2986 # not an instance field.
2987 getattr(c, field_name)
2988
Pablo Galindob0544ba2021-04-21 12:41:19 +01002989 if m.USING_STRINGS:
2990 # iv4 is interpreted as a normal field.
2991 self.assertIn('not_iv4', c.__dict__)
2992 self.assertEqual(c.not_iv4, 4)
2993 else:
2994 # iv4 is interpreted as an InitVar, so it
2995 # won't exist on the instance.
2996 self.assertNotIn('not_iv4', c.__dict__)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002997
Yury Selivanovd219cc42019-12-09 09:54:20 -05002998 def test_text_annotations(self):
2999 from test import dataclass_textanno
3000
3001 self.assertEqual(
3002 get_type_hints(dataclass_textanno.Bar),
3003 {'foo': dataclass_textanno.Foo})
3004 self.assertEqual(
3005 get_type_hints(dataclass_textanno.Bar.__init__),
3006 {'foo': dataclass_textanno.Foo,
3007 'return': type(None)})
3008
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003009
class TestMakeDataclass(unittest.TestCase):
    """Tests for dataclasses.make_dataclass()."""

    def test_simple(self):
        # Entries in ``namespace`` become ordinary methods on the class.
        spec = [('x', int), ('y', int, field(default=5))]
        Cls = make_dataclass('C', spec,
                             namespace={'add_one': lambda self: self.x + 1})
        obj = Cls(10)
        self.assertEqual((obj.x, obj.y), (10, 5))
        self.assertEqual(obj.add_one(), 11)

    def test_no_mutate_namespace(self):
        # The caller's namespace dict must be left untouched.
        namespace = {}
        make_dataclass('C',
                       [('x', int), ('y', int, field(default=5))],
                       namespace=namespace)
        self.assertEqual(namespace, {})

    def test_base(self):
        # Plain (non-dataclass) bases are used as the generated class's bases.
        class Base1:
            pass

        class Base2:
            pass

        Cls = make_dataclass('C', [('x', int)], bases=(Base1, Base2))
        obj = Cls(2)
        for expected in (Cls, Base1, Base2):
            self.assertIsInstance(obj, expected)

    def test_base_dataclass(self):
        # Fields inherited from a dataclass base precede the new fields.
        @dataclass
        class Base1:
            x: int

        class Base2:
            pass

        Cls = make_dataclass('C', [('y', int)], bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            Cls(2)
        obj = Cls(1, 2)
        for expected in (Cls, Base1, Base2):
            self.assertIsInstance(obj, expected)

        self.assertEqual((obj.x, obj.y), (1, 2))

    def test_init_var(self):
        # An InitVar reaches __post_init__ but is not stored as a field.
        def post_init(self, y):
            self.x *= y

        Cls = make_dataclass('C',
                             [('x', int), ('y', InitVar[int])],
                             namespace={'__post_init__': post_init})
        obj = Cls(2, 3)
        self.assertEqual(vars(obj), {'x': 6})
        self.assertEqual(len(fields(obj)), 1)

    def test_class_var(self):
        # ClassVar entries become class attributes, not instance fields.
        Cls = make_dataclass('C',
                             [('x', int),
                              ('y', ClassVar[int], 10),
                              ('z', ClassVar[int], field(default=20))])
        obj = Cls(1)
        self.assertEqual(vars(obj), {'x': 1})
        self.assertEqual(len(fields(obj)), 1)
        self.assertEqual(Cls.y, 10)
        self.assertEqual(Cls.z, 20)

    def test_other_params(self):
        # Extra keyword arguments such as init= are honoured.
        Cls = make_dataclass('C',
                             [('x', int),
                              ('y', ClassVar[int], 10),
                              ('z', ClassVar[int], field(default=20))],
                             init=False)
        # init=False suppresses __init__, while __repr__ is still generated.
        class_dict = vars(Cls)
        self.assertNotIn('__init__', class_dict)
        self.assertIn('__repr__', class_dict)

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C', [], xxinit=False)

    def test_no_types(self):
        # A bare field name is annotated with the string 'typing.Any'.
        any_ann = 'typing.Any'

        Cls = make_dataclass('Point', ['x', 'y', 'z'])
        obj = Cls(1, 2, 3)
        self.assertEqual(vars(obj), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(Cls.__annotations__,
                         {'x': any_ann, 'y': any_ann, 'z': any_ann})

        # Bare names and (name, type) pairs can be mixed.
        Cls = make_dataclass('Point', ['x', ('y', int), 'z'])
        obj = Cls(1, 2, 3)
        self.assertEqual(vars(obj), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(Cls.__annotations__,
                         {'x': any_ann, 'y': int, 'z': any_ann})

    def test_invalid_type_specification(self):
        # Tuple specs of the wrong length are rejected.
        for bad in ((), (1, 2, 3, 4)):
            with self.subTest(bad_field=bad):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad])

        # And so are objects that have no len() at all.
        for bad in (float, lambda x: x):
            with self.subTest(bad_field=bad):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad])

    def test_duplicate_field_names(self):
        for name in ('a', 'ab'):
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ('for', 'async', 'await', 'as'):
            with self.subTest(field=name):
                # The keyword is rejected wherever it appears in the list.
                for layout in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                        make_dataclass('C', layout)

    def test_non_identifier_field_names(self):
        for name in ('()', 'x,y', '*', '2@3', '', 'little johnny tables'):
            with self.subTest(field=name):
                # The bad name is rejected wherever it appears in the list.
                for layout in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                        make_dataclass('C', layout)

    def test_underscore_field_names(self):
        # Unlike namedtuple, dataclass field names may use underscores freely.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # Odd class names are accepted, as types.new_class allows them.
        for classname in ('()', 'x,y', '*', '2@3', ''):
            with self.subTest(classname=classname):
                generated = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(generated.__name__, classname)
3173
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace(), plus recursive-repr checks."""

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can never be passed to replace(), whether
        # alone or together with ordinary init fields.  Each raising call
        # gets its own context manager; nothing follows it inside the block.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # Neither a dataclass *class* nor a non-dataclass value is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                    "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # An InitVar without a default must be passed to replace().
        # The parentheses are escaped so they match literally instead of
        # forming an empty regex group.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_initvar_with_default_value(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int] = None
            z: InitVar[int] = 42

            def __post_init__(self, y, z):
                if y is not None:
                    self.x += y
                if z is not None:
                    self.x += z

        # InitVars with defaults may be omitted; the default is then used.
        c = C(x=1, y=10, z=1)
        self.assertEqual(replace(c), C(x=12))
        self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
        self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3401
Ben Avrahamibef7d292020-10-06 20:40:50 +03003402class TestAbstract(unittest.TestCase):
3403 def test_abc_implementation(self):
3404 class Ordered(abc.ABC):
3405 @abc.abstractmethod
3406 def __lt__(self, other):
3407 pass
3408
3409 @abc.abstractmethod
3410 def __le__(self, other):
3411 pass
3412
3413 @dataclass(order=True)
3414 class Date(Ordered):
3415 year: int
3416 month: 'Month'
3417 day: 'int'
3418
3419 self.assertFalse(inspect.isabstract(Date))
3420 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3421
3422 def test_maintain_abc(self):
3423 class A(abc.ABC):
3424 @abc.abstractmethod
3425 def foo(self):
3426 pass
3427
3428 @dataclass
3429 class Date(A):
3430 year: int
3431 month: 'Month'
3432 day: 'int'
3433
3434 self.assertTrue(inspect.isabstract(Date))
3435 msg = 'class Date with abstract method foo'
3436 self.assertRaisesRegex(TypeError, msg, Date)
3437
Eric V. Smith4e812962018-05-16 11:31:29 -04003438
Brandt Bucher145bf262021-02-26 14:51:55 -08003439class TestMatchArgs(unittest.TestCase):
3440 def test_match_args(self):
3441 @dataclass
3442 class C:
3443 a: int
3444 self.assertEqual(C(42).__match_args__, ('a',))
3445
3446 def test_explicit_match_args(self):
Brandt Bucherf84d5a12021-04-05 19:17:08 -07003447 ma = ()
Brandt Bucher145bf262021-02-26 14:51:55 -08003448 @dataclass
3449 class C:
3450 a: int
3451 __match_args__ = ma
3452 self.assertIs(C(42).__match_args__, ma)
3453
Brandt Bucherd92c59f2021-04-08 12:54:34 -07003454 def test_bpo_43764(self):
3455 @dataclass(repr=False, eq=False, init=False)
3456 class X:
3457 a: int
3458 b: int
3459 c: int
3460 self.assertEqual(X.__match_args__, ("a", "b", "c"))
3461
Eric V. Smith750f4842021-04-10 21:28:42 -04003462 def test_match_args_argument(self):
3463 @dataclass(match_args=False)
3464 class X:
3465 a: int
3466 self.assertNotIn('__match_args__', X.__dict__)
3467
3468 @dataclass(match_args=False)
3469 class Y:
3470 a: int
3471 __match_args__ = ('b',)
3472 self.assertEqual(Y.__match_args__, ('b',))
3473
3474 @dataclass(match_args=False)
3475 class Z(Y):
3476 z: int
3477 self.assertEqual(Z.__match_args__, ('b',))
3478
3479 # Ensure parent dataclass __match_args__ is seen, if child class
3480 # specifies match_args=False.
3481 @dataclass
3482 class A:
3483 a: int
3484 z: int
3485 @dataclass(match_args=False)
3486 class B(A):
3487 b: int
3488 self.assertEqual(B.__match_args__, ('a', 'z'))
3489
3490 def test_make_dataclasses(self):
3491 C = make_dataclass('C', [('x', int), ('y', int)])
3492 self.assertEqual(C.__match_args__, ('x', 'y'))
3493
3494 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
3495 self.assertEqual(C.__match_args__, ('x', 'y'))
3496
3497 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
3498 self.assertNotIn('__match__args__', C.__dict__)
3499
3500 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
3501 self.assertEqual(C.__match_args__, ('z',))
3502
Brandt Bucher145bf262021-02-26 14:51:55 -08003503
# Run the whole test module when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()