blob: bdcb4a2cfd1a07c8d6abf9c33b6b77bd11e25779 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050011import unittest
12from unittest.mock import Mock
Miss Islington (bot)79e9f5a2021-09-02 23:26:53 -070013from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
Yury Selivanovd219cc42019-12-09 09:54:20 -050014from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050016from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050017
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040018import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
19import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
20
Eric V. Smithf0db54a2017-12-04 16:58:55 -050021# Just any custom exception we can catch.
class CustomError(Exception):
    """Exception raised deliberately by tests so it can be caught unambiguously."""
23
24class TestCase(unittest.TestCase):
def test_no_fields(self):
    # A dataclass without annotations still gets a usable __init__,
    # but fields() reports nothing.
    @dataclass
    class C:
        pass

    C()
    self.assertEqual(fields(C), ())
32
def test_no_fields_but_member_variable(self):
    # An un-annotated class attribute is a plain class attribute, not a
    # field, so fields() is still empty.
    @dataclass
    class C:
        i = 0

    C()
    self.assertEqual(fields(C), ())
40
def test_one_field_no_default(self):
    # A lone field without a default is passed positionally to __init__.
    @dataclass
    class C:
        x: int

    self.assertEqual(C(42).x, 42)
48
def test_field_default_default_factory_error(self):
    # field() refuses to be given both a default and a default_factory.
    expected_message = "cannot specify both default and default_factory"
    with self.assertRaisesRegex(ValueError, expected_message):
        @dataclass
        class C:
            x: int = field(default=1, default_factory=int)
55
def test_field_repr(self):
    # repr(Field) lists every Field attribute in a fixed order with no
    # spaces after the commas.
    int_field = field(default=1, init=True, repr=False)
    int_field.name = "id"
    expected = ("Field(name='id',type=None,"
                f"default=1,default_factory={MISSING!r},"
                "init=True,repr=False,hash=None,"
                "compare=True,metadata=mappingproxy({}),"
                f"kw_only={MISSING!r},"
                "_field_type=None)")
    self.assertEqual(repr(int_field), expected)
68
def test_named_init_params(self):
    # Fields may also be supplied to __init__ by keyword.
    @dataclass
    class C:
        x: int

    self.assertEqual(C(x=32).x, 32)
76
def test_two_fields_one_default(self):
    # A default on a trailing field is fine, but a field without a default
    # may never follow one with a default -- including across inheritance.
    @dataclass
    class C:
        x: int
        y: int = 0

    obj = C(3)
    self.assertEqual((obj.x, obj.y), (3, 0))

    ordering_error = "non-default argument 'y' follows default argument"

    # Non-defaults following defaults.
    with self.assertRaisesRegex(TypeError, ordering_error):
        @dataclass
        class C:
            x: int = 0
            y: int

    # A derived class adds a non-default field after a default one.
    with self.assertRaisesRegex(TypeError, ordering_error):
        @dataclass
        class B:
            x: int = 0

        @dataclass
        class C(B):
            y: int

    # Overriding x in the subclass with a default leaves the inherited,
    # default-less y positioned after it.
    with self.assertRaisesRegex(TypeError, ordering_error):
        @dataclass
        class B:
            x: int
            y: int

        @dataclass
        class C(B):
            x: int = 0
120
def test_overwrite_hash(self):
    # A user-provided __hash__ on a frozen dataclass is kept as-is.
    @dataclass(frozen=True)
    class C:
        x: int
        def __hash__(self):
            return 301
    self.assertEqual(hash(C(100)), 301)

    # Defining only __eq__ on a frozen dataclass still yields the
    # generated __hash__, which hashes the tuple of field values.
    @dataclass(frozen=True)
    class C:
        x: int
        def __eq__(self, other):
            return False
    self.assertEqual(hash(C(100)), hash((100,)))

    # With unsafe_hash=True it is an error for the class body to define
    # __hash__ itself.
    with self.assertRaisesRegex(TypeError,
                                'Cannot overwrite attribute __hash__'):
        @dataclass(unsafe_hash=True)
        class C:
            def __hash__(self):
                pass

    # Defining __eq__ makes Python set __hash__ to None in the class body
    # before @dataclass runs; that implicit None may be overwritten.
    # __eq__ uses the proper (self, other) signature so it would be
    # callable if a comparison ever happened.
    @dataclass(unsafe_hash=True)
    class C:
        x: int
        def __eq__(self, other):
            pass
    # The generated hash function works as we'd expect.
    self.assertEqual(hash(C(10)), hash((10,)))

    # An explicit, non-None __hash__ next to __eq__ is still rejected when
    # unsafe_hash=True, because it was not auto-generated from __eq__.
    with self.assertRaisesRegex(TypeError,
                                'Cannot overwrite attribute __hash__'):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            def __eq__(self, other):
                pass
            def __hash__(self):
                pass
173
def test_overwrite_fields_in_derived_class(self):
    # Redefining x in C1 replaces Base's x, but the field keeps its
    # original position from Base.
    @dataclass
    class Base:
        x: Any = 15.0
        y: int = 0

    @dataclass
    class C1(Base):
        z: int = 10
        x: int = 15

    prefix = 'TestCase.test_overwrite_fields_in_derived_class.<locals>.'
    cases = [
        (Base(), 'Base(x=15.0, y=0)'),
        (C1(), 'C1(x=15, y=0, z=10)'),
        (C1(x=5), 'C1(x=5, y=0, z=10)'),
    ]
    for instance, suffix in cases:
        self.assertEqual(repr(instance), prefix + suffix)
195
def test_field_named_self(self):
    # A field called 'self' forces __init__ to use some other name for its
    # instance parameter; otherwise the conventional 'self' is used.
    def first_init_param(cls):
        return next(iter(inspect.signature(cls.__init__).parameters))

    @dataclass
    class C:
        self: str
    c = C('foo')
    self.assertEqual(c.self, 'foo')
    self.assertNotEqual('self', first_init_param(C))

    @dataclass
    class C:
        selfx: str
    self.assertEqual('self', first_init_param(C))
217
def test_field_named_object(self):
    # A field may be named like the builtin 'object'.
    @dataclass
    class C:
        object: str

    self.assertEqual(C('foo').object, 'foo')
224
def test_field_named_object_frozen(self):
    # Same as the non-frozen case, but through the frozen code path.
    @dataclass(frozen=True)
    class C:
        object: str

    self.assertEqual(C('foo').object, 'foo')
231
def test_field_named_like_builtin(self):
    # Generated methods must keep working when field names shadow
    # builtins. None/True/False are keywords, so they are skipped.
    exclusions = {'None', 'True', 'False'}
    builtins_names = sorted(
        name for name in vars(builtins)
        if not name.startswith('__') and name not in exclusions
    )
    C = make_dataclass('C', [(name, str) for name in builtins_names])

    # Each field receives its own name as its value.
    c = C(*builtins_names)

    for name in builtins_names:
        self.assertEqual(getattr(c, name), name)
248
def test_field_named_like_builtin_frozen(self):
    # Frozen variant: field names that shadow builtins must not break the
    # generated (frozen) methods. None/True/False are keywords and skipped.
    exclusions = {'None', 'True', 'False'}
    builtins_names = sorted(
        name for name in vars(builtins)
        if not name.startswith('__') and name not in exclusions
    )
    C = make_dataclass('C', [(name, str) for name in builtins_names],
                       frozen=True)

    # Each field receives its own name as its value.
    c = C(*builtins_names)

    for name in builtins_names:
        self.assertEqual(getattr(c, name), name)
266
def test_0_field_compare(self):
    # order=False is the default: equality works, ordering raises.
    @dataclass
    class C0:
        pass

    @dataclass(order=False)
    class C1:
        pass

    orderings = [lambda a, b: a < b,
                 lambda a, b: a <= b,
                 lambda a, b: a > b,
                 lambda a, b: a >= b]
    for cls in (C0, C1):
        with self.subTest(cls=cls):
            self.assertEqual(cls(), cls())
            unsupported = ("not supported between instances of "
                           f"'{cls.__name__}' and '{cls.__name__}'")
            for idx, fn in enumerate(orderings):
                with self.subTest(idx=idx):
                    with self.assertRaisesRegex(TypeError, unsupported):
                        fn(cls(), cls())

    # With order=True, two field-less instances compare as equal.
    @dataclass(order=True)
    class C:
        pass
    self.assertLessEqual(C(), C())
    self.assertGreaterEqual(C(), C())
294
def test_1_field_compare(self):
    # order=False is the default: equality works, ordering raises.
    @dataclass
    class C0:
        x: int

    @dataclass(order=False)
    class C1:
        x: int

    orderings = [lambda a, b: a < b,
                 lambda a, b: a <= b,
                 lambda a, b: a > b,
                 lambda a, b: a >= b]
    for cls in (C0, C1):
        with self.subTest(cls=cls):
            self.assertEqual(cls(1), cls(1))
            self.assertNotEqual(cls(0), cls(1))
            unsupported = ("not supported between instances of "
                           f"'{cls.__name__}' and '{cls.__name__}'")
            for idx, fn in enumerate(orderings):
                with self.subTest(idx=idx):
                    with self.assertRaisesRegex(TypeError, unsupported):
                        fn(cls(0), cls(0))

    # With order=True, the single field decides the ordering.
    @dataclass(order=True)
    class C:
        x: int
    low, high = C(0), C(1)
    self.assertLess(low, high)
    self.assertLessEqual(low, high)
    self.assertLessEqual(C(1), C(1))
    self.assertGreater(high, low)
    self.assertGreaterEqual(high, low)
    self.assertGreaterEqual(C(1), C(1))
327
def test_simple_compare(self):
    # order=False is the default: equality works, ordering raises.
    @dataclass
    class C0:
        x: int
        y: int

    @dataclass(order=False)
    class C1:
        x: int
        y: int

    orderings = [lambda a, b: a < b,
                 lambda a, b: a <= b,
                 lambda a, b: a > b,
                 lambda a, b: a >= b]
    for cls in (C0, C1):
        with self.subTest(cls=cls):
            self.assertEqual(cls(0, 0), cls(0, 0))
            self.assertEqual(cls(1, 2), cls(1, 2))
            self.assertNotEqual(cls(1, 0), cls(0, 0))
            self.assertNotEqual(cls(1, 0), cls(1, 1))
            unsupported = ("not supported between instances of "
                           f"'{cls.__name__}' and '{cls.__name__}'")
            for idx, fn in enumerate(orderings):
                with self.subTest(idx=idx):
                    with self.assertRaisesRegex(TypeError, unsupported):
                        fn(cls(0, 0), cls(0, 0))

    @dataclass(order=True)
    class C:
        x: int
        y: int

    # Relations that hold between equal instances.
    for idx, fn in enumerate([lambda a, b: a == b,
                              lambda a, b: a <= b,
                              lambda a, b: a >= b]):
        with self.subTest(idx=idx):
            self.assertTrue(fn(C(0, 0), C(0, 0)))

    # Pairs where the left operand sorts first; fields are compared in
    # declaration order (x, then y).
    ascending = [(C(0, 0), C(0, 1)),
                 (C(0, 1), C(1, 0)),
                 (C(1, 0), C(1, 1))]

    for idx, fn in enumerate([lambda a, b: a < b,
                              lambda a, b: a <= b,
                              lambda a, b: a != b]):
        with self.subTest(idx=idx):
            for low, high in ascending:
                self.assertTrue(fn(low, high))

    for idx, fn in enumerate([lambda a, b: a > b,
                              lambda a, b: a >= b,
                              lambda a, b: a != b]):
        with self.subTest(idx=idx):
            for low, high in ascending:
                self.assertTrue(fn(high, low))
381
def test_compare_subclasses(self):
    # Instances of a subclass are neither equal to nor orderable against
    # instances of the base, even when no fields are added.
    @dataclass
    class B:
        i: int

    @dataclass
    class C(B):
        pass

    base, derived = B(0), C(0)

    equality_checks = [(lambda a, b: a == b, False),
                       (lambda a, b: a != b, True)]
    for idx, (fn, expected) in enumerate(equality_checks):
        with self.subTest(idx=idx):
            self.assertEqual(fn(base, derived), expected)

    orderings = [lambda a, b: a < b,
                 lambda a, b: a <= b,
                 lambda a, b: a > b,
                 lambda a, b: a >= b]
    for idx, fn in enumerate(orderings):
        with self.subTest(idx=idx):
            with self.assertRaisesRegex(TypeError,
                                        "not supported between instances of 'B' and 'C'"):
                fn(base, derived)
406
def test_eq_order(self):
    # Check which comparison methods are generated for each eq/order combo.
    order_names = ('__lt__', '__le__', '__gt__', '__ge__')
    combos = [
        (False, False, 'neither'),
        (False, True,  'exception'),
        (True,  False, 'eq_only'),
        (True,  True,  'both'),
    ]
    for eq, order, result in combos:
        with self.subTest(eq=eq, order=order):
            if result == 'exception':
                # order=True without eq=True is rejected.
                with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
                    @dataclass(eq=eq, order=order)
                    class C:
                        pass
                continue

            @dataclass(eq=eq, order=order)
            class C:
                pass

            if result == 'neither':
                expect_eq, expect_order = False, False
            elif result == 'both':
                expect_eq, expect_order = True, True
            elif result == 'eq_only':
                expect_eq, expect_order = True, False
            else:
                assert False, f'unknown result {result!r}'

            check_eq = self.assertIn if expect_eq else self.assertNotIn
            check_eq('__eq__', C.__dict__)
            check_order = self.assertIn if expect_order else self.assertNotIn
            for name in order_names:
                check_order(name, C.__dict__)
446
def test_field_no_default(self):
    # A bare field() supplies no default, so the argument stays required.
    @dataclass
    class C:
        x: int = field()

    self.assertEqual(C(5).x, 5)

    missing_x = (r"__init__\(\) missing 1 required "
                 "positional argument: 'x'")
    with self.assertRaisesRegex(TypeError, missing_x):
        C()
458
def test_field_default(self):
    # field(default=...) also stores the default as a class attribute.
    sentinel = object()

    @dataclass
    class C:
        x: object = field(default=sentinel)

    self.assertIs(C.x, sentinel)
    instance = C(10)
    self.assertEqual(instance.x, 10)

    # Once the instance attribute is deleted, lookup falls back to the
    # class attribute.
    del instance.x
    self.assertIs(instance.x, sentinel)

    self.assertIs(C().x, sentinel)
475
def test_not_in_repr(self):
    # Fields declared with repr=False are left out of the generated repr,
    # but they are still required __init__ arguments.
    prefix = 'TestCase.test_not_in_repr.<locals>.'

    @dataclass
    class C:
        x: int = field(repr=False)
    with self.assertRaises(TypeError):
        C()
    self.assertEqual(repr(C(10)), prefix + 'C()')

    @dataclass
    class C:
        x: int = field(repr=False)
        y: int
    self.assertEqual(repr(C(10, 20)), prefix + 'C(y=20)')
491
def test_not_in_compare(self):
    # compare=False drops y from __eq__, so only x decides equality.
    @dataclass
    class C:
        x: int = 0
        y: int = field(compare=False, default=4)

    for left, right in [(C(), C(0, 20)), (C(1, 10), C(1, 20))]:
        self.assertEqual(left, right)
    for left, right in [(C(3), C(4, 10)), (C(3, 10), C(4, 10))]:
        self.assertNotEqual(left, right)
502
def test_hash_field_rules(self):
    # For all 6 combinations of field(hash=True/False/None) and
    # compare=True/False, check whether the field feeds the generated
    # __hash__. hash=None defers to compare.
    cases = [
        (True,  False, 'field' ),
        (True,  True,  'field' ),
        (False, False, 'absent'),
        (False, True,  'absent'),
        (None,  False, 'absent'),
        (None,  True,  'field' ),
    ]
    for hash_, compare, result in cases:
        with self.subTest(hash=hash_, compare=compare):
            @dataclass(unsafe_hash=True)
            class C:
                x: int = field(compare=compare, hash=hash_, default=5)

            if result == 'field':
                # __hash__ contains the field.
                expected = hash((5,))
            elif result == 'absent':
                # The field is not present in the hash.
                expected = hash(())
            else:
                assert False, f'unknown result {result!r}'
            self.assertEqual(hash(C(5)), expected)
528
def test_init_false_no_default(self):
    # A field with init=False and no default is never assigned, so it is
    # missing from the instance __dict__.
    @dataclass
    class C:
        x: int = field(init=False)

    self.assertNotIn('x', vars(C()))

    @dataclass
    class C:
        x: int
        y: int = 0
        z: int = field(init=False)
        t: int = 10

    self.assertNotIn('z', vars(C(0)))
    self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
547
def test_class_marker(self):
    # fields() returns a tuple of Field objects mirroring the annotations,
    # in declaration order, with each field's options preserved.
    @dataclass
    class C:
        x: int
        y: str = field(init=False, default=None)
        z: str = field(repr=False)

    the_fields = fields(C)
    self.assertIsInstance(the_fields, tuple)
    for f in the_fields:
        self.assertIs(type(f), Field)
        self.assertIn(f.name, C.__annotations__)

    self.assertEqual(len(the_fields), 3)
    fx, fy, fz = the_fields

    # x: required, no class attribute.
    self.assertEqual(fx.name, 'x')
    self.assertEqual(fx.type, int)
    self.assertFalse(hasattr(C, 'x'))
    self.assertTrue(fx.init)
    self.assertTrue(fx.repr)

    # y: default None becomes a class attribute; excluded from __init__.
    self.assertEqual(fy.name, 'y')
    self.assertEqual(fy.type, str)
    self.assertIsNone(getattr(C, 'y'))
    self.assertFalse(fy.init)
    self.assertTrue(fy.repr)

    # z: no default, hidden from repr.
    self.assertEqual(fz.name, 'z')
    self.assertEqual(fz.type, str)
    self.assertFalse(hasattr(C, 'z'))
    self.assertTrue(fz.init)
    self.assertFalse(fz.repr)
580
def test_field_order(self):
    # Overriding a field in a subclass changes its default but keeps its
    # position; new fields are appended at the end.
    def name_defaults(cls):
        return [(f.name, f.default) for f in fields(cls)]

    @dataclass
    class B:
        a: str = 'B:a'
        b: str = 'B:b'
        c: str = 'B:c'

    @dataclass
    class C(B):
        b: str = 'C:b'

    self.assertEqual(name_defaults(C),
                     [('a', 'B:a'), ('b', 'C:b'), ('c', 'B:c')])

    @dataclass
    class D(B):
        c: str = 'D:c'

    self.assertEqual(name_defaults(D),
                     [('a', 'B:a'), ('b', 'B:b'), ('c', 'D:c')])

    @dataclass
    class E(D):
        a: str = 'E:a'
        d: str = 'E:d'

    self.assertEqual(name_defaults(E),
                     [('a', 'E:a'), ('b', 'B:b'), ('c', 'D:c'), ('d', 'E:d')])
616
def test_class_attrs(self):
    # Only fields with a default, given directly or via
    # field(default=...), become class attributes.
    sentinel = object()

    @dataclass
    class C:
        x: int
        y: int = field(repr=False)
        z: object = sentinel
        t: int = field(default=100)

    for name in ('x', 'y'):
        self.assertFalse(hasattr(C, name))
    self.assertIs(C.z, sentinel)
    self.assertEqual(C.t, 100)
632
def test_disallowed_mutable_defaults(self):
    # list, dict and set defaults (including subclasses) are rejected,
    # because every instance would share the same object.
    cases = [(list, [], [1]),
             (dict, {}, {0: 1}),
             (set, set(), {1}),
             ]
    for typ, empty, non_empty in cases:
        with self.subTest(typ=typ):
            # Even an empty value is refused.
            with self.assertRaisesRegex(ValueError,
                                        f'mutable default {typ} for field '
                                        'x is not allowed'):
                @dataclass
                class Point:
                    x: typ = empty

            # So is a non-empty one.
            with self.assertRaisesRegex(ValueError,
                                        f'mutable default {typ} for field '
                                        'y is not allowed'):
                @dataclass
                class Point:
                    y: typ = non_empty

            # Subclasses of the known mutable types are caught too.
            class Subclass(typ): pass

            with self.assertRaisesRegex(ValueError,
                                        "mutable default .*Subclass'>"
                                        ' for field z is not allowed'):
                @dataclass
                class Point:
                    z: typ = Subclass()

            # ClassVars are shared by design, so mutable values are fine.
            @dataclass
            class C:
                z: ClassVar[typ] = typ()

            @dataclass
            class C:
                x: ClassVar[typ] = Subclass()
677
def test_deliberately_mutable_defaults(self):
    # Mutable objects of types outside (list, dict, set) are not policed;
    # sharing them between instances is the caller's choice.
    class Mutable:
        def __init__(self):
            self.l = []

    @dataclass
    class C:
        x: Mutable

    # Both instances hold the very same Mutable object.
    shared = Mutable()
    first, second = C(shared), C(shared)
    self.assertEqual(first, second)
    first.x.l.extend([1, 2])
    self.assertEqual(first, second)
    self.assertEqual(first.x.l, [1, 2])
    self.assertIs(first.x, second.x)
698
def test_no_options(self):
    # The decorator may also be applied with empty parentheses.
    @dataclass()
    class C:
        x: int

    self.assertEqual(C(42).x, 42)
706
def test_not_tuple(self):
    # An instance is equal neither to a tuple of its values nor to an
    # instance of an unrelated dataclass with identical fields.
    @dataclass
    class Point:
        x: int
        y: int

    @dataclass
    class C:
        x: int
        y: int

    self.assertNotEqual(Point(1, 2), (1, 2))
    self.assertNotEqual(Point(1, 3), C(1, 3))
721
def test_not_other_dataclass(self):
    # Unlike namedtuple, dataclass instances don't compare equal to
    # structurally similar objects and can't be unpacked like tuples.
    @dataclass
    class Point3D:
        x: int
        y: int
        z: int

    @dataclass
    class Date:
        year: int
        month: int
        day: int

    self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
    self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))

    # Instances are not iterable, so unpacking fails.
    with self.assertRaisesRegex(TypeError, 'unpack'):
        a, b, c = Point3D(4, 5, 6)

    # A distinct class with the same field names is not equal either.
    @dataclass
    class Point3Dv1:
        x: int = 0
        y: int = 0
        z: int = 0
    self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
752
def test_function_annotations(self):
    # The generated __init__ carries the field annotations, and the class
    # __annotations__ are left intact, with or without frozen/unsafe_hash.
    class F:
        pass
    f = F()  # Arbitrary instance used as a default value.

    def validate_class(cls):
        # Class-level annotations (not function annotations) first.
        for name, tp in [('i', int), ('j', str), ('k', F),
                         ('l', float), ('z', complex)]:
            self.assertEqual(cls.__annotations__[name], tp)

        signature = inspect.signature(cls.__init__)
        # __init__ is annotated as returning None.
        self.assertIs(signature.return_annotation, None)

        params = iter(signature.parameters.values())
        # This is testing an internal name, and probably shouldn't be tested.
        self.assertEqual(next(params).name, 'self')
        # (name, annotation, check that there is no default). For k and l
        # the default isn't tested, since it's set to MISSING.
        expected_params = [('i', int, True),
                           ('j', str, True),
                           ('k', F, False),
                           ('l', float, False)]
        for name, tp, check_default in expected_params:
            param = next(params)
            self.assertEqual(param.name, name)
            self.assertIs(param.annotation, tp)
            if check_default:
                self.assertEqual(param.default, inspect.Parameter.empty)
            self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        # z has init=False, so there are no further parameters.
        self.assertRaises(StopIteration, next, params)

    @dataclass
    class C:
        i: int
        j: str
        k: F = f
        l: float=field(default=None)
        z: complex=field(default=3+4j, init=False)

    validate_class(C)

    # Now repeat with __hash__.
    @dataclass(frozen=True, unsafe_hash=True)
    class C:
        i: int
        j: str
        k: F = f
        l: float=field(default=None)
        z: complex=field(default=3+4j, init=False)

    validate_class(C)
822
def test_missing_default(self):
    # Passing MISSING explicitly as the default behaves exactly like not
    # giving a default at all.
    @dataclass
    class C:
        x: int=field(default=MISSING)

    @dataclass
    class D:
        x: int

    missing_arg = (r'__init__\(\) missing 1 required '
                   'positional argument')
    for cls in (C, D):
        with self.assertRaisesRegex(TypeError, missing_arg):
            cls()
        self.assertNotIn('x', cls.__dict__)
843
def test_missing_default_factory(self):
    # MISSING as default_factory (alone or alongside default=MISSING) is
    # the same as specifying no default of any kind.
    @dataclass
    class C:
        x: int=field(default_factory=MISSING)

    @dataclass
    class D:
        x: int=field(default=MISSING, default_factory=MISSING)

    missing_arg = (r'__init__\(\) missing 1 required '
                   'positional argument')
    for cls in (C, D):
        with self.assertRaisesRegex(TypeError, missing_arg):
            cls()
        self.assertNotIn('x', cls.__dict__)
865
def test_missing_repr(self):
    # The MISSING sentinel's repr names its type.
    text = repr(MISSING)
    self.assertIn('MISSING_TYPE object', text)
868
def test_dont_include_other_annotations(self):
    # Return annotations on methods and properties must not become
    # fields; only the class-level annotation on i does.
    @dataclass
    class C:
        i: int
        def foo(self) -> int:
            return 4
        @property
        def bar(self) -> int:
            return 5

    self.assertEqual(list(C.__annotations__), ['i'])
    instance = C(10)
    self.assertEqual(instance.foo(), 4)
    self.assertEqual(instance.bar, 5)
    self.assertEqual(instance.i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500882
def test_post_init(self):
    # The generated __init__ calls __post_init__ after assigning fields.
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # __post_init__ runs here too but takes the non-raising path, which
    # checks that self is passed correctly.
    C(5)

    # Without a generated __init__, __post_init__ is never called, so
    # constructing the class must not raise.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    C()

    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # Frozen instances reject attribute assignment, even from __post_init__.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
929
def test_post_init_super(self):
    # The generated __init__ calls only the most-derived __post_init__;
    # a base-class version runs only if explicitly chained or inherited.
    class B:
        def __post_init__(self):
            raise CustomError()

    # Overriding without calling super(): B's version never runs.
    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(C().x, 5)

    # Chaining to super() reaches B's raising implementation.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # Not overriding at all: the inherited __post_init__ is still called.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
960
def test_post_init_staticmethod(self):
    # A staticmethod __post_init__ taking no parameters is still invoked.
    called = False

    @dataclass
    class C:
        x: int
        y: int
        @staticmethod
        def __post_init__():
            nonlocal called
            called = True

    self.assertFalse(called)
    instance = C(3, 4)
    self.assertEqual((instance.x, instance.y), (3, 4))
    self.assertTrue(called)
976
def test_post_init_classmethod(self):
    # A classmethod __post_init__ is invoked and receives the class.
    @dataclass
    class C:
        flag = False
        x: int
        y: int
        @classmethod
        def __post_init__(cls):
            cls.flag = True

    self.assertFalse(C.flag)
    instance = C(3, 4)
    self.assertEqual((instance.x, instance.y), (3, 4))
    self.assertTrue(C.flag)
991
def test_class_var(self):
    # ClassVar annotations are excluded from fields, __init__ and __repr__,
    # but their values remain ordinary class attributes.
    @dataclass
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000
        s: ClassVar = 4000

    def check_class_vars(obj, expected_z):
        self.assertEqual(obj.z, expected_z)
        self.assertEqual(obj.w, 2000)
        self.assertEqual(obj.t, 3000)
        self.assertEqual(obj.s, 4000)

    c = C(5)
    self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)          # Two real fields...
    self.assertEqual(len(C.__annotations__), 6)  # ...plus four ClassVars.
    check_class_vars(c, 1000)

    # Rebinding the class attribute is visible through existing instances.
    C.z += 1
    self.assertEqual(c.z, 1001)

    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    check_class_vars(c, 1001)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001019
def test_class_var_no_default(self):
    # A ClassVar annotation without a value creates no class attribute.
    @dataclass
    class C:
        x: ClassVar[int]

    self.assertNotIn('x', vars(C))
1027
1028 def test_class_var_default_factory(self):
1029 # It makes no sense for a ClassVar to have a default factory. When
1030 # would it be called? Call it yourself, since it's class-wide.
1031 with self.assertRaisesRegex(TypeError,
1032 'cannot have a default factory'):
1033 @dataclass
1034 class C:
1035 x: ClassVar[int] = field(default_factory=int)
1036
1037 self.assertNotIn('x', C.__dict__)
1038
1039 def test_class_var_with_default(self):
1040 # If a ClassVar has a default value, it should be set on the class.
1041 @dataclass
1042 class C:
1043 x: ClassVar[int] = 10
1044 self.assertEqual(C.x, 10)
1045
1046 @dataclass
1047 class C:
1048 x: ClassVar[int] = field(default=10)
1049 self.assertEqual(C.x, 10)
1050
def test_class_var_frozen(self):
    # ClassVars keep working on a frozen dataclass: only instances are
    # frozen, the class itself remains mutable.
    @dataclass(frozen=True)
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    first = C(5)
    self.assertEqual(repr(first), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)          # x and y are the fields...
    self.assertEqual(len(C.__annotations__), 5)  # ...plus 3 ClassVars.
    self.assertEqual((first.z, first.w, first.t), (1000, 2000, 3000))

    C.z += 1
    self.assertEqual(first.z, 1001)
    second = C(20)
    self.assertEqual((second.x, second.y), (20, 10))
    self.assertEqual((second.z, second.w, second.t), (1001, 2000, 3000))
1077
1078 def test_init_var_no_default(self):
1079 # If an InitVar has no default value, it should not be set on the class.
1080 @dataclass
1081 class C:
1082 x: InitVar[int]
1083
1084 self.assertNotIn('x', C.__dict__)
1085
1086 def test_init_var_default_factory(self):
1087 # It makes no sense for an InitVar to have a default factory. When
1088 # would it be called? Call it yourself, since it's class-wide.
1089 with self.assertRaisesRegex(TypeError,
1090 'cannot have a default factory'):
1091 @dataclass
1092 class C:
1093 x: InitVar[int] = field(default_factory=int)
1094
1095 self.assertNotIn('x', C.__dict__)
1096
1097 def test_init_var_with_default(self):
1098 # If an InitVar has a default value, it should be set on the class.
1099 @dataclass
1100 class C:
1101 x: InitVar[int] = 10
1102 self.assertEqual(C.x, 10)
1103
1104 @dataclass
1105 class C:
1106 x: InitVar[int] = field(default=10)
1107 self.assertEqual(C.x, 10)
1108
1109 def test_init_var(self):
1110 @dataclass
1111 class C:
1112 x: int = None
1113 init_param: InitVar[int] = None
1114
1115 def __post_init__(self, init_param):
1116 if self.x is None:
1117 self.x = init_param*2
1118
1119 c = C(init_param=10)
1120 self.assertEqual(c.x, 20)
1121
Augusto Hack01ee12b2019-06-02 23:14:48 -03001122 def test_init_var_preserve_type(self):
1123 self.assertEqual(InitVar[int].type, int)
1124
1125 # Make sure the repr is correct.
1126 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001127 self.assertEqual(repr(InitVar[List[int]]),
1128 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001129
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001130 def test_init_var_inheritance(self):
1131 # Note that this deliberately tests that a dataclass need not
1132 # have a __post_init__ function if it has an InitVar field.
1133 # It could just be used in a derived class, as shown here.
1134 @dataclass
1135 class Base:
1136 x: int
1137 init_base: InitVar[int]
1138
1139 # We can instantiate by passing the InitVar, even though
1140 # it's not used.
1141 b = Base(0, 10)
1142 self.assertEqual(vars(b), {'x': 0})
1143
1144 @dataclass
1145 class C(Base):
1146 y: int
1147 init_derived: InitVar[int]
1148
1149 def __post_init__(self, init_base, init_derived):
1150 self.x = self.x + init_base
1151 self.y = self.y + init_derived
1152
1153 c = C(10, 11, 50, 51)
1154 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1155
def test_default_factory(self):
    # A factory returning a fresh list: each instance owns its own object.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIsNot(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # A factory returning one shared list: every instance aliases it.
    shared = []

    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: shared)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIs(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # The remaining field() flags combined with default_factory.
    # repr=False: hidden from repr, but still populated.
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash=False: excluded from the generated hash.
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init=False (see also test_default_factory_with_no_init).
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare=False: ignored by __eq__.
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1212
1213 def test_default_factory_with_no_init(self):
1214 # We need a factory with a side effect.
1215 factory = Mock()
1216
1217 @dataclass
1218 class C:
1219 x: list = field(default_factory=factory, init=False)
1220
1221 # Make sure the default factory is called for each new instance.
1222 C().x
1223 self.assertEqual(factory.call_count, 1)
1224 C().x
1225 self.assertEqual(factory.call_count, 2)
1226
1227 def test_default_factory_not_called_if_value_given(self):
1228 # We need a factory that we can test if it's been called.
1229 factory = Mock()
1230
1231 @dataclass
1232 class C:
1233 x: int = field(default_factory=factory)
1234
1235 # Make sure that if a field has a default factory function,
1236 # it's not called if a value is specified.
1237 C().x
1238 self.assertEqual(factory.call_count, 1)
1239 self.assertEqual(C(10).x, 10)
1240 self.assertEqual(factory.call_count, 1)
1241 C().x
1242 self.assertEqual(factory.call_count, 2)
1243
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001244 def test_default_factory_derived(self):
1245 # See bpo-32896.
1246 @dataclass
1247 class Foo:
1248 x: dict = field(default_factory=dict)
1249
1250 @dataclass
1251 class Bar(Foo):
1252 y: int = 1
1253
1254 self.assertEqual(Foo().x, {})
1255 self.assertEqual(Bar().x, {})
1256 self.assertEqual(Bar().y, 1)
1257
1258 @dataclass
1259 class Baz(Foo):
1260 pass
1261 self.assertEqual(Baz().x, {})
1262
1263 def test_intermediate_non_dataclass(self):
1264 # Test that an intermediate class that defines
1265 # annotations does not define fields.
1266
1267 @dataclass
1268 class A:
1269 x: int
1270
1271 class B(A):
1272 y: int
1273
1274 @dataclass
1275 class C(B):
1276 z: int
1277
1278 c = C(1, 3)
1279 self.assertEqual((c.x, c.z), (1, 3))
1280
1281 # .y was not initialized.
1282 with self.assertRaisesRegex(AttributeError,
1283 'object has no attribute'):
1284 c.y
1285
1286 # And if we again derive a non-dataclass, no fields are added.
1287 class D(C):
1288 t: int
1289 d = D(4, 5)
1290 self.assertEqual((d.x, d.z), (4, 5))
1291
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001292 def test_classvar_default_factory(self):
1293 # It's an error for a ClassVar to have a factory function.
1294 with self.assertRaisesRegex(TypeError,
1295 'cannot have a default factory'):
1296 @dataclass
1297 class C:
1298 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001299
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001300 def test_is_dataclass(self):
1301 class NotDataClass:
1302 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001303
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001304 self.assertFalse(is_dataclass(0))
1305 self.assertFalse(is_dataclass(int))
1306 self.assertFalse(is_dataclass(NotDataClass))
1307 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001308
1309 @dataclass
1310 class C:
1311 x: int
1312
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001313 @dataclass
1314 class D:
1315 d: C
1316 e: int
1317
1318 c = C(10)
1319 d = D(c, 4)
1320
1321 self.assertTrue(is_dataclass(C))
1322 self.assertTrue(is_dataclass(c))
1323 self.assertFalse(is_dataclass(c.x))
1324 self.assertTrue(is_dataclass(d.d))
1325 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001326
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001327 def test_is_dataclass_when_getattr_always_returns(self):
1328 # See bpo-37868.
1329 class A:
1330 def __getattr__(self, key):
1331 return 0
1332 self.assertFalse(is_dataclass(A))
1333 a = A()
1334
1335 # Also test for an instance attribute.
1336 class B:
1337 pass
1338 b = B()
1339 b.__dataclass_fields__ = []
1340
1341 for obj in a, b:
1342 with self.subTest(obj=obj):
1343 self.assertFalse(is_dataclass(obj))
1344
1345 # Indirect tests for _is_dataclass_instance().
1346 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1347 asdict(obj)
1348 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1349 astuple(obj)
1350 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1351 replace(obj, x=0)
1352
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001353 def test_helper_fields_with_class_instance(self):
1354 # Check that we can call fields() on either a class or instance,
1355 # and get back the same thing.
1356 @dataclass
1357 class C:
1358 x: int
1359 y: float
1360
1361 self.assertEqual(fields(C), fields(C(0, 0.0)))
1362
1363 def test_helper_fields_exception(self):
1364 # Check that TypeError is raised if not passed a dataclass or
1365 # instance.
1366 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1367 fields(0)
1368
1369 class C: pass
1370 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1371 fields(C)
1372 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1373 fields(C())
1374
1375 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001376 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001377 @dataclass
1378 class C:
1379 x: int
1380 y: int
1381 c = C(1, 2)
1382
1383 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1384 self.assertEqual(asdict(c), asdict(c))
1385 self.assertIsNot(asdict(c), asdict(c))
1386 c.x = 42
1387 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1388 self.assertIs(type(asdict(c)), dict)
1389
1390 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001391 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001392 @dataclass
1393 class C:
1394 x: int
1395 y: int
1396 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1397 asdict(C)
1398 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1399 asdict(int)
1400
1401 def test_helper_asdict_copy_values(self):
1402 @dataclass
1403 class C:
1404 x: int
1405 y: List[int] = field(default_factory=list)
1406 initial = []
1407 c = C(1, initial)
1408 d = asdict(c)
1409 self.assertEqual(d['y'], initial)
1410 self.assertIsNot(d['y'], initial)
1411 c = C(1)
1412 d = asdict(c)
1413 d['y'].append(1)
1414 self.assertEqual(c.y, [])
1415
1416 def test_helper_asdict_nested(self):
1417 @dataclass
1418 class UserId:
1419 token: int
1420 group: int
1421 @dataclass
1422 class User:
1423 name: str
1424 id: UserId
1425 u = User('Joe', UserId(123, 1))
1426 d = asdict(u)
1427 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1428 self.assertIsNot(asdict(u), asdict(u))
1429 u.id.group = 2
1430 self.assertEqual(asdict(u), {'name': 'Joe',
1431 'id': {'token': 123, 'group': 2}})
1432
1433 def test_helper_asdict_builtin_containers(self):
1434 @dataclass
1435 class User:
1436 name: str
1437 id: int
1438 @dataclass
1439 class GroupList:
1440 id: int
1441 users: List[User]
1442 @dataclass
1443 class GroupTuple:
1444 id: int
1445 users: Tuple[User, ...]
1446 @dataclass
1447 class GroupDict:
1448 id: int
1449 users: Dict[str, User]
1450 a = User('Alice', 1)
1451 b = User('Bob', 2)
1452 gl = GroupList(0, [a, b])
1453 gt = GroupTuple(0, (a, b))
1454 gd = GroupDict(0, {'first': a, 'second': b})
1455 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1456 {'name': 'Bob', 'id': 2}]})
1457 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1458 {'name': 'Bob', 'id': 2})})
1459 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1460 'second': {'name': 'Bob', 'id': 2}}})
1461
Windson yangbe372d72019-04-23 02:45:34 +08001462 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001463 @dataclass
1464 class Child:
1465 d: object
1466
1467 @dataclass
1468 class Parent:
1469 child: Child
1470
1471 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1472 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1473
1474 def test_helper_asdict_factory(self):
1475 @dataclass
1476 class C:
1477 x: int
1478 y: int
1479 c = C(1, 2)
1480 d = asdict(c, dict_factory=OrderedDict)
1481 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1482 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1483 c.x = 42
1484 d = asdict(c, dict_factory=OrderedDict)
1485 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1486 self.assertIs(type(d), OrderedDict)
1487
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001488 def test_helper_asdict_namedtuple(self):
1489 T = namedtuple('T', 'a b c')
1490 @dataclass
1491 class C:
1492 x: str
1493 y: T
1494 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1495
1496 d = asdict(c)
1497 self.assertEqual(d, {'x': 'outer',
1498 'y': T(1,
1499 {'x': 'inner',
1500 'y': T(11, 12, 13)},
1501 2),
1502 }
1503 )
1504
1505 # Now with a dict_factory. OrderedDict is convenient, but
1506 # since it compares to dicts, we also need to have separate
1507 # assertIs tests.
1508 d = asdict(c, dict_factory=OrderedDict)
1509 self.assertEqual(d, {'x': 'outer',
1510 'y': T(1,
1511 {'x': 'inner',
1512 'y': T(11, 12, 13)},
1513 2),
1514 }
1515 )
1516
penguindustin96466302019-05-06 14:57:17 -04001517 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001518 self.assertIs(type(d), OrderedDict)
1519 self.assertIs(type(d['y'][1]), OrderedDict)
1520
1521 def test_helper_asdict_namedtuple_key(self):
1522 # Ensure that a field that contains a dict which has a
1523 # namedtuple as a key works with asdict().
1524
1525 @dataclass
1526 class C:
1527 f: dict
1528 T = namedtuple('T', 'a')
1529
1530 c = C({T('an a'): 0})
1531
1532 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1533
1534 def test_helper_asdict_namedtuple_derived(self):
1535 class T(namedtuple('Tbase', 'a')):
1536 def my_a(self):
1537 return self.a
1538
1539 @dataclass
1540 class C:
1541 f: T
1542
1543 t = T(6)
1544 c = C(t)
1545
1546 d = asdict(c)
1547 self.assertEqual(d, {'f': T(a=6)})
1548 # Make sure that t has been copied, not used directly.
1549 self.assertIsNot(d['f'], t)
1550 self.assertEqual(d['f'].my_a(), 6)
1551
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001552 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001553 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001554 @dataclass
1555 class C:
1556 x: int
1557 y: int = 0
1558 c = C(1)
1559
1560 self.assertEqual(astuple(c), (1, 0))
1561 self.assertEqual(astuple(c), astuple(c))
1562 self.assertIsNot(astuple(c), astuple(c))
1563 c.y = 42
1564 self.assertEqual(astuple(c), (1, 42))
1565 self.assertIs(type(astuple(c)), tuple)
1566
1567 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001568 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001569 @dataclass
1570 class C:
1571 x: int
1572 y: int
1573 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1574 astuple(C)
1575 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1576 astuple(int)
1577
1578 def test_helper_astuple_copy_values(self):
1579 @dataclass
1580 class C:
1581 x: int
1582 y: List[int] = field(default_factory=list)
1583 initial = []
1584 c = C(1, initial)
1585 t = astuple(c)
1586 self.assertEqual(t[1], initial)
1587 self.assertIsNot(t[1], initial)
1588 c = C(1)
1589 t = astuple(c)
1590 t[1].append(1)
1591 self.assertEqual(c.y, [])
1592
1593 def test_helper_astuple_nested(self):
1594 @dataclass
1595 class UserId:
1596 token: int
1597 group: int
1598 @dataclass
1599 class User:
1600 name: str
1601 id: UserId
1602 u = User('Joe', UserId(123, 1))
1603 t = astuple(u)
1604 self.assertEqual(t, ('Joe', (123, 1)))
1605 self.assertIsNot(astuple(u), astuple(u))
1606 u.id.group = 2
1607 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1608
1609 def test_helper_astuple_builtin_containers(self):
1610 @dataclass
1611 class User:
1612 name: str
1613 id: int
1614 @dataclass
1615 class GroupList:
1616 id: int
1617 users: List[User]
1618 @dataclass
1619 class GroupTuple:
1620 id: int
1621 users: Tuple[User, ...]
1622 @dataclass
1623 class GroupDict:
1624 id: int
1625 users: Dict[str, User]
1626 a = User('Alice', 1)
1627 b = User('Bob', 2)
1628 gl = GroupList(0, [a, b])
1629 gt = GroupTuple(0, (a, b))
1630 gd = GroupDict(0, {'first': a, 'second': b})
1631 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1632 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1633 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1634
Windson yangbe372d72019-04-23 02:45:34 +08001635 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001636 @dataclass
1637 class Child:
1638 d: object
1639
1640 @dataclass
1641 class Parent:
1642 child: Child
1643
1644 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1645 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1646
1647 def test_helper_astuple_factory(self):
1648 @dataclass
1649 class C:
1650 x: int
1651 y: int
1652 NT = namedtuple('NT', 'x y')
1653 def nt(lst):
1654 return NT(*lst)
1655 c = C(1, 2)
1656 t = astuple(c, tuple_factory=nt)
1657 self.assertEqual(t, NT(1, 2))
1658 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1659 c.x = 42
1660 t = astuple(c, tuple_factory=nt)
1661 self.assertEqual(t, NT(42, 2))
1662 self.assertIs(type(t), NT)
1663
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001664 def test_helper_astuple_namedtuple(self):
1665 T = namedtuple('T', 'a b c')
1666 @dataclass
1667 class C:
1668 x: str
1669 y: T
1670 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1671
1672 t = astuple(c)
1673 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1674
1675 # Now, using a tuple_factory. list is convenient here.
1676 t = astuple(c, tuple_factory=list)
1677 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1678
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001679 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001680 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001681 }
1682
1683 # Create the class.
1684 cls = type('C', (), cls_dict)
1685
1686 # Make it a dataclass.
1687 cls1 = dataclass(cls)
1688
1689 self.assertEqual(cls1, cls)
1690 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1691
1692 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001693 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001694 'y': field(default=5),
1695 }
1696
1697 # Create the class.
1698 cls = type('C', (), cls_dict)
1699
1700 # Make it a dataclass.
1701 cls1 = dataclass(cls)
1702
1703 self.assertEqual(cls1, cls)
1704 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1705
1706 def test_init_in_order(self):
1707 @dataclass
1708 class C:
1709 a: int
1710 b: int = field()
1711 c: list = field(default_factory=list, init=False)
1712 d: list = field(default_factory=list)
1713 e: int = field(default=4, init=False)
1714 f: int = 4
1715
1716 calls = []
1717 def setattr(self, name, value):
1718 calls.append((name, value))
1719
1720 C.__setattr__ = setattr
1721 c = C(0, 1)
1722 self.assertEqual(('a', 0), calls[0])
1723 self.assertEqual(('b', 1), calls[1])
1724 self.assertEqual(('c', []), calls[2])
1725 self.assertEqual(('d', []), calls[3])
1726 self.assertNotIn(('e', 4), calls)
1727 self.assertEqual(('f', 4), calls[4])
1728
1729 def test_items_in_dicts(self):
1730 @dataclass
1731 class C:
1732 a: int
1733 b: list = field(default_factory=list, init=False)
1734 c: list = field(default_factory=list)
1735 d: int = field(default=4, init=False)
1736 e: int = 0
1737
1738 c = C(0)
1739 # Class dict
1740 self.assertNotIn('a', C.__dict__)
1741 self.assertNotIn('b', C.__dict__)
1742 self.assertNotIn('c', C.__dict__)
1743 self.assertIn('d', C.__dict__)
1744 self.assertEqual(C.d, 4)
1745 self.assertIn('e', C.__dict__)
1746 self.assertEqual(C.e, 0)
1747 # Instance dict
1748 self.assertIn('a', c.__dict__)
1749 self.assertEqual(c.a, 0)
1750 self.assertIn('b', c.__dict__)
1751 self.assertEqual(c.b, [])
1752 self.assertIn('c', c.__dict__)
1753 self.assertEqual(c.c, [])
1754 self.assertNotIn('d', c.__dict__)
1755 self.assertIn('e', c.__dict__)
1756 self.assertEqual(c.e, 0)
1757
1758 def test_alternate_classmethod_constructor(self):
1759 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001760 # alternate constructor. This is mostly an example to show
1761 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001762 @dataclass
1763 class C:
1764 x: int
1765 @classmethod
1766 def from_file(cls, filename):
1767 # In a real example, create a new instance
1768 # and populate 'x' from contents of a file.
1769 value_in_file = 20
1770 return cls(value_in_file)
1771
1772 self.assertEqual(C.from_file('filename').x, 20)
1773
1774 def test_field_metadata_default(self):
1775 # Make sure the default metadata is read-only and of
1776 # zero length.
1777 @dataclass
1778 class C:
1779 i: int
1780
1781 self.assertFalse(fields(C)[0].metadata)
1782 self.assertEqual(len(fields(C)[0].metadata), 0)
1783 with self.assertRaisesRegex(TypeError,
1784 'does not support item assignment'):
1785 fields(C)[0].metadata['test'] = 3
1786
1787 def test_field_metadata_mapping(self):
1788 # Make sure only a mapping can be passed as metadata
1789 # zero length.
1790 with self.assertRaises(TypeError):
1791 @dataclass
1792 class C:
1793 i: int = field(metadata=0)
1794
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001795 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001796 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001797 @dataclass
1798 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001799 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001800 self.assertFalse(fields(C)[0].metadata)
1801 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001802 # Update should work (see bpo-35960).
1803 d['foo'] = 1
1804 self.assertEqual(len(fields(C)[0].metadata), 1)
1805 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001806 with self.assertRaisesRegex(TypeError,
1807 'does not support item assignment'):
1808 fields(C)[0].metadata['test'] = 3
1809
1810 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001811 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001812 @dataclass
1813 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001814 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001815 self.assertEqual(len(fields(C)[0].metadata), 3)
1816 self.assertEqual(fields(C)[0].metadata['test'], 10)
1817 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1818 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001819 # Update should work.
1820 d['foo'] = 1
1821 self.assertEqual(len(fields(C)[0].metadata), 4)
1822 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001823 with self.assertRaises(KeyError):
1824 # Non-existent key.
1825 fields(C)[0].metadata['baz']
1826 with self.assertRaisesRegex(TypeError,
1827 'does not support item assignment'):
1828 fields(C)[0].metadata['test'] = 3
1829
1830 def test_field_metadata_custom_mapping(self):
1831 # Try a custom mapping.
1832 class SimpleNameSpace:
1833 def __init__(self, **kw):
1834 self.__dict__.update(kw)
1835
1836 def __getitem__(self, item):
1837 if item == 'xyzzy':
1838 return 'plugh'
1839 return getattr(self, item)
1840
1841 def __len__(self):
1842 return self.__dict__.__len__()
1843
1844 @dataclass
1845 class C:
1846 i: int = field(metadata=SimpleNameSpace(a=10))
1847
1848 self.assertEqual(len(fields(C)[0].metadata), 1)
1849 self.assertEqual(fields(C)[0].metadata['a'], 10)
1850 with self.assertRaises(AttributeError):
1851 fields(C)[0].metadata['b']
1852 # Make sure we're still talking to our custom mapping.
1853 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1854
1855 def test_generic_dataclasses(self):
1856 T = TypeVar('T')
1857
1858 @dataclass
1859 class LabeledBox(Generic[T]):
1860 content: T
1861 label: str = '<unknown>'
1862
1863 box = LabeledBox(42)
1864 self.assertEqual(box.content, 42)
1865 self.assertEqual(box.label, '<unknown>')
1866
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001867 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001868 Alias = List[LabeledBox[int]]
1869
1870 def test_generic_extending(self):
1871 S = TypeVar('S')
1872 T = TypeVar('T')
1873
1874 @dataclass
1875 class Base(Generic[T, S]):
1876 x: T
1877 y: S
1878
1879 @dataclass
1880 class DataDerived(Base[int, T]):
1881 new_field: str
1882 Alias = DataDerived[str]
1883 c = Alias(0, 'test1', 'test2')
1884 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1885
1886 class NonDataDerived(Base[int, T]):
1887 def new_method(self):
1888 return self.y
1889 Alias = NonDataDerived[float]
1890 c = Alias(10, 1.0)
1891 self.assertEqual(c.new_method(), 1.0)
1892
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001893 def test_generic_dynamic(self):
1894 T = TypeVar('T')
1895
1896 @dataclass
1897 class Parent(Generic[T]):
1898 x: T
1899 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1900 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1901 self.assertIs(Child[int](1, 2).z, None)
1902 self.assertEqual(Child[int](1, 2, 3).z, 3)
1903 self.assertEqual(Child[int](1, 2, 3).other, 42)
1904 # Check that type aliases work correctly.
1905 Alias = Child[T]
1906 self.assertEqual(Alias[int](1, 2).x, 1)
1907 # Check MRO resolution.
1908 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1909
Miss Islington (bot)e086bfe2021-10-09 12:50:45 -07001910 def test_dataclasses_pickleable(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001911 global P, Q, R
1912 @dataclass
1913 class P:
1914 x: int
1915 y: int = 0
1916 @dataclass
1917 class Q:
1918 x: int
1919 y: int = field(default=0, init=False)
1920 @dataclass
1921 class R:
1922 x: int
1923 y: List[int] = field(default_factory=list)
1924 q = Q(1)
1925 q.y = 2
1926 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1927 for sample in samples:
1928 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1929 with self.subTest(sample=sample, proto=proto):
1930 new_sample = pickle.loads(pickle.dumps(sample, proto))
1931 self.assertEqual(sample.x, new_sample.x)
1932 self.assertEqual(sample.y, new_sample.y)
1933 self.assertIsNot(sample, new_sample)
1934 new_sample.x = 42
1935 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1936 self.assertEqual(new_sample.x, another_new_sample.x)
1937 self.assertEqual(sample.y, another_new_sample.y)
1938
def test_dataclasses_qualnames(self):
    # Generated methods carry proper __name__/__qualname__ values, which
    # also appear in error messages.
    @dataclass(order=True, unsafe_hash=True, frozen=True)
    class A:
        x: int
        y: int

    self.assertEqual(A.__init__.__name__, "__init__")

    prefix = "TestCase.test_dataclasses_qualnames.<locals>.A."
    generated = ('__eq__', '__lt__', '__le__', '__gt__', '__ge__',
                 '__hash__', '__init__', '__repr__', '__setattr__',
                 '__delattr__')
    for name in generated:
        self.assertEqual(getattr(A, name).__qualname__, prefix + name)

    with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
        A()
1962
Eric V. Smithea8fc522018-01-27 19:07:40 -05001963
class TestFieldNoAnnotation(unittest.TestCase):
    # A field() default without a type annotation is an error, even when a
    # base class annotates the same name.

    def test_field_without_annotation(self):
        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        expected = "'f' is a field but has no type annotation"
        # Still an error: the annotation on the base class must not be
        # picked up.
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same check, but the annotated base is an ordinary class.
        class B:
            f: int

        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(B):
                f = field()
1997
1998
class TestDocString(unittest.TestCase):
    # When a dataclass has no docstring, dataclass() synthesizes one from
    # the class name plus its fields' names, types and defaults.

    def assertDocStrEqual(self, a, b):
        # inspect.signature renders differently in 3.6 and 3.7 (see
        # bpo-32108), so for now compare with every space removed.
        self.assertEqual(''.join(a.split(' ')), ''.join(b.split(' ')))

    def test_existing_docstring_not_overridden(self):
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        # A user-written docstring is left untouched.
        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        expected = "C()"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        expected = "C(x:int)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        expected = "C(x:int, y:int)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        expected = "C(x:int, y:int, z:str)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        expected = "C(x:int=3)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        expected = "C(x:Optional[int]=None)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        expected = "C(x:List[int])"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_list_field_with_default_factory(self):
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        # A default factory is rendered as the <factory> placeholder.
        expected = "C(x:List[int]=<factory>)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        expected = "C(x:collections.deque)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        expected = "C(x:collections.deque=<factory>)"
        self.assertDocStrEqual(C.__doc__, expected)
2086
2087
Eric V. Smithea8fc522018-01-27 19:07:40 -05002088class TestInit(unittest.TestCase):
2089 def test_base_has_init(self):
2090 class B:
2091 def __init__(self):
2092 self.z = 100
2093 pass
2094
2095 # Make sure that declaring this class doesn't raise an error.
2096 # The issue is that we can't override __init__ in our class,
2097 # but it should be okay to add __init__ to us if our base has
2098 # an __init__.
2099 @dataclass
2100 class C(B):
2101 x: int = 0
2102 c = C(10)
2103 self.assertEqual(c.x, 10)
2104 self.assertNotIn('z', vars(c))
2105
2106 # Make sure that if we don't add an init, the base __init__
2107 # gets called.
2108 @dataclass(init=False)
2109 class C(B):
2110 x: int = 10
2111 c = C()
2112 self.assertEqual(c.x, 10)
2113 self.assertEqual(c.z, 100)
2114
2115 def test_no_init(self):
2116 dataclass(init=False)
2117 class C:
2118 i: int = 0
2119 self.assertEqual(C().i, 0)
2120
2121 dataclass(init=False)
2122 class C:
2123 i: int = 2
2124 def __init__(self):
2125 self.i = 3
2126 self.assertEqual(C().i, 3)
2127
2128 def test_overwriting_init(self):
2129 # If the class has __init__, use it no matter the value of
2130 # init=.
2131
2132 @dataclass
2133 class C:
2134 x: int
2135 def __init__(self, x):
2136 self.x = 2 * x
2137 self.assertEqual(C(3).x, 6)
2138
2139 @dataclass(init=True)
2140 class C:
2141 x: int
2142 def __init__(self, x):
2143 self.x = 2 * x
2144 self.assertEqual(C(4).x, 8)
2145
2146 @dataclass(init=False)
2147 class C:
2148 x: int
2149 def __init__(self, x):
2150 self.x = 2 * x
2151 self.assertEqual(C(5).x, 10)
2152
Miss Islington (bot)79e9f5a2021-09-02 23:26:53 -07002153 def test_inherit_from_protocol(self):
2154 # Dataclasses inheriting from protocol should preserve their own `__init__`.
2155 # See bpo-45081.
2156
2157 class P(Protocol):
2158 a: int
2159
2160 @dataclass
2161 class C(P):
2162 a: int
2163
2164 self.assertEqual(C(5).a, 5)
2165
2166 @dataclass
2167 class D(P):
2168 def __init__(self, a):
2169 self.a = a * 2
2170
2171 self.assertEqual(D(5).a, 10)
2172
Eric V. Smithea8fc522018-01-27 19:07:40 -05002173
class TestRepr(unittest.TestCase):
    """Tests for the generated __repr__ and the repr= parameter."""

    def test_repr(self):
        prefix = 'TestRepr.test_repr.<locals>.'

        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first in the repr.
        self.assertEqual(repr(C(4)), prefix + 'C(x=4, y=10)')

        # Redeclaring an inherited field keeps it in its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), prefix + 'D(x=20, y=10)')

        # Nested dataclasses are shown with their full qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), prefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), prefix + 'C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ in the class: object's repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr_start = f'{__name__}.TestRepr.test_no_repr.<locals>.C object at'
        self.assertIn(default_repr_start, repr(C(3)))

        # repr=False with a __repr__ in the class: that __repr__ is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ defined in the class wins, whatever repr= says.
        for decorator in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorator
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2243
2244
class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and the eq= parameter."""

    def test_no_eq(self):
        # eq=False and no __eq__ in the class: identity comparison applies.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with an __eq__ in the class: that __eq__ is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ defined in the class wins, whatever eq= says. Each
        # variant compares equal to a different sentinel value.
        variants = ((dataclass, 3),
                    (dataclass(eq=True), 4),
                    (dataclass(eq=False), 5))
        for decorator, target in variants:
            @decorator
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2290
2291
class TestOrdering(unittest.TestCase):
    """Tests for the order= parameter and user-defined comparison methods."""

    def test_functools_total_ordering(self):
        # functools.total_ordering can be stacked on top of @dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately "backward": the assertions below can only
                # pass if this method is the one being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # order=False: no ordering methods are added.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A class-defined __lt__ does not cause the others to be added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace a comparison method the class defines.
        def overwrite_error(name):
            return (f'Cannot overwrite attribute {name}'
                    '.*using functools.total_ordering')

        with self.assertRaisesRegex(TypeError, overwrite_error('__lt__')):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError, overwrite_error('__le__')):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError, overwrite_error('__gt__')):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError, overwrite_error('__ge__')):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2367
class TestHash(unittest.TestCase):
    """Tests for whether and how @dataclass provides __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        # The generated hash equals the hash of the field-value tuple.
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # None stays None; anything else becomes a non-bool object with
            # the same truthiness.
            if value is None:
                return None
            return (3,) if value else 0

        def check(case, unsafe_hash, eq, frozen, with_hash, result):
            options = dict(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
            with self.subTest(case=case, **options):
                if result == 'exception':
                    # Creating the class must fail. That can only happen
                    # when the class defines its own __hash__.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError,
                                                'Cannot overwrite attribute __hash__'):
                        @dataclass(**options)
                        class C:
                            def __hash__(self):
                                return 0
                    return

                if with_hash:
                    @dataclass(**options)
                    class C:
                        def __hash__(self):
                            return 0
                else:
                    @dataclass(**options)
                    class C:
                        pass

                if result == 'fn':
                    # __hash__ holds the function that dataclass generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                elif result == '':
                    # dataclass added no __hash__ of its own.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)
                elif result == 'none':
                    # __hash__ is explicitly set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                else:
                    assert False, f'unknown result {result!r}'

        # Expected outcome for each of the 8 combinations of unsafe_hash,
        # eq and frozen: first without, then with a __hash__ in the class.
        #  unsafe_hash  eq     frozen  no __hash__  own __hash__
        table = [
            (False,     False, False,  '',          ''),
            (False,     False, True,   '',          ''),
            (False,     True,  False,  'none',      ''),
            (False,     True,  True,   'fn',        ''),
            (True,      False, False,  'fn',        'exception'),
            (True,      False, True,   'fn',        'exception'),
            (True,      True,  False,  'fn',        'exception'),
            (True,      True,  True,   'fn',        'exception'),
        ]
        for case, row in enumerate(table, 1):
            unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash = row
            expectations = ((False, res_no_defined_hash),
                            (True, res_defined_hash))
            for with_hash, result in expectations:
                check(case, unsafe_hash, eq, frozen, with_hash, result)
            # Repeat with non-bool truth values, to make sure the
            # decorator's data-driven table handles them as well.
            for with_hash, result in expectations:
                check(case, non_bool(unsafe_hash), non_bool(eq),
                      non_bool(frozen), with_hash, result)

    def test_eq_only(self):
        # Defining __eq__ in a class makes Python set __hash__ to None.
        # That is ordinary Python behavior, unrelated to dataclasses; make
        # sure we don't interfere with it (see bpo-32546).
        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # With unsafe_hash=True the class still works and is hashable.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # The class's own __eq__ is used even though eq=True is passed.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # Without fields, the hash is that of the empty tuple.
        for decorator in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorator
            class C:
                pass
            self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With a single field, the hash is that of a 1-tuple.
        for decorator in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorator
            class C:
                x: int
            for value in (4, 42):
                self.assertEqual(hash(C(value)), hash((value,)))

    def test_hash_no_args(self):
        # Dataclasses created without any hash= argument. This guards
        # against the @dataclass parameter being renamed, or the default
        # hashing behavior changing, without anyone noticing.

        class Base:
            def __hash__(self):
                return 301

        # A None for frozen or eq means "leave that argument out of the
        # decorator call", so its default value applies.
        for frozen, eq, base, expected in [
                (None, None, object, 'unhashable'),
                (None, None, Base, 'unhashable'),
                (None, False, object, 'object'),
                (None, False, Base, 'base'),
                (None, True, object, 'unhashable'),
                (None, True, Base, 'unhashable'),
                (False, None, object, 'unhashable'),
                (False, None, Base, 'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base, 'base'),
                (False, True, object, 'unhashable'),
                (False, True, Base, 'unhashable'),
                (True, None, object, 'tuple'),
                (True, None, Base, 'tuple'),
                (True, False, object, 'object'),
                (True, False, Base, 'base'),
                (True, True, object, 'tuple'),
                (True, True, Base, 'tuple'),
        ]:
            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                kwargs = {name: value
                          for name, value in (('frozen', frozen), ('eq', eq))
                          if value is not None}

                @dataclass(**kwargs)
                class C(base):
                    i: int

                # Check that instances hash the way we expect.
                if expected == 'unhashable':
                    instance = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(instance)
                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)
                elif expected == 'object':
                    # The value of object's hash tells us little, so check
                    # that object's hash function itself is the one in use.
                    self.assertIs(C.__hash__, object.__hash__)
                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))
                else:
                    assert False, f'unknown value for expected={expected!r}'
2586
Eric V. Smithea8fc522018-01-27 19:07:40 -05002587
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True and for mixing frozen and non-frozen classes."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        # The rejected assignment left the value unchanged.
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        # Neither the inherited field nor the new one can be assigned.
        for name, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, name, value)
        self.assertEqual((d.i, d.j), (0, 10))

    def test_inherit_nonfrozen_from_empty_frozen(self):
        @dataclass(frozen=True)
        class C:
            pass

        # Even a frozen base with no fields rejects a non-frozen subclass.
        with self.assertRaisesRegex(TypeError,
                                    'cannot inherit non-frozen dataclass from a frozen one'):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertEqual(d.j, 3)
        self.assertIsInstance(d, C)

    # The next two tests run both with an intermediate ordinary
    # (non-dataclass) class between the two dataclasses, and without one.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C):
                        pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C):
                        pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive, directly or indirectly, from a
        # plain class.
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C):
                        pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        # An ordinary (non-dataclass) subclass of a frozen dataclass.
        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        # New attributes may be set on the subclass instance ...
        s.cached = True

        # ... but the frozen fields still can't be changed.
        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, name, 5)
        self.assertEqual((s.x, s.y), (3, 10))
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True supplies __setattr__ and __delattr__, so a class that
        # defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False the class's own __setattr__ is kept and used.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # With a hashable field value, hashing succeeds ...
        hash(C(3))

        # ... and with an unhashable one, it fails.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2760
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002761
Eric V. Smith7389fd92018-03-19 21:07:51 -04002762class TestSlots(unittest.TestCase):
2763 def test_simple(self):
2764 @dataclass
2765 class C:
2766 __slots__ = ('x',)
2767 x: Any
2768
Eric V. Smith2b75fc22018-03-25 20:37:33 -04002769 # There was a bug where a variable in a slot was assumed to
2770 # also have a default value (of type
2771 # types.MemberDescriptorType).
Eric V. Smith7389fd92018-03-19 21:07:51 -04002772 with self.assertRaisesRegex(TypeError,
Eric V. Smithc42e7aa2018-03-24 23:02:21 -04002773 r"__init__\(\) missing 1 required positional argument: 'x'"):
Eric V. Smith7389fd92018-03-19 21:07:51 -04002774 C()
2775
2776 # We can create an instance, and assign to x.
2777 c = C(10)
2778 self.assertEqual(c.x, 10)
2779 c.x = 5
2780 self.assertEqual(c.x, 5)
2781
2782 # We can't assign to anything else.
2783 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
2784 c.y = 5
2785
2786 def test_derived_added_field(self):
2787 # See bpo-33100.
2788 @dataclass
2789 class Base:
2790 __slots__ = ('x',)
2791 x: Any
2792
2793 @dataclass
2794 class Derived(Base):
2795 x: int
2796 y: int
2797
2798 d = Derived(1, 2)
2799 self.assertEqual((d.x, d.y), (1, 2))
2800
2801 # We can add a new field to the derived instance.
2802 d.z = 10
2803
Yurii Karabasc2419912021-05-01 05:14:30 +03002804 def test_generated_slots(self):
2805 @dataclass(slots=True)
2806 class C:
2807 x: int
2808 y: int
2809
2810 c = C(1, 2)
2811 self.assertEqual((c.x, c.y), (1, 2))
2812
2813 c.x = 3
2814 c.y = 4
2815 self.assertEqual((c.x, c.y), (3, 4))
2816
2817 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
2818 c.z = 5
2819
2820 def test_add_slots_when_slots_exists(self):
2821 with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
2822 @dataclass(slots=True)
2823 class C:
2824 __slots__ = ('x',)
2825 x: int
2826
2827 def test_generated_slots_value(self):
2828 @dataclass(slots=True)
2829 class Base:
2830 x: int
2831
2832 self.assertEqual(Base.__slots__, ('x',))
2833
2834 @dataclass(slots=True)
2835 class Delivered(Base):
2836 y: int
2837
2838 self.assertEqual(Delivered.__slots__, ('x', 'y'))
2839
2840 @dataclass
2841 class AnotherDelivered(Base):
2842 z: int
2843
2844 self.assertTrue('__slots__' not in AnotherDelivered.__dict__)
2845
2846 def test_returns_new_class(self):
2847 class A:
2848 x: int
2849
2850 B = dataclass(A, slots=True)
2851 self.assertIsNot(A, B)
2852
2853 self.assertFalse(hasattr(A, "__slots__"))
2854 self.assertTrue(hasattr(B, "__slots__"))
2855
Eric V. Smith823fbf42021-05-01 13:27:30 -04002856 # Can't be local to test_frozen_pickle.
2857 @dataclass(frozen=True, slots=True)
2858 class FrozenSlotsClass:
2859 foo: str
2860 bar: int
2861
2862 def test_frozen_pickle(self):
2863 # bpo-43999
2864
2865 assert self.FrozenSlotsClass.__slots__ == ("foo", "bar")
2866 p = pickle.dumps(self.FrozenSlotsClass("a", 1))
2867 assert pickle.loads(p) == self.FrozenSlotsClass("a", 1)
2868
Yurii Karabasc2419912021-05-01 05:14:30 +03002869
class TestDescriptors(unittest.TestCase):
    """Tests for __set_name__ on objects used as field defaults."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that records (with an 'x' suffix) the name it is
        # bound to.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Plain descriptor behavior: no dataclass code is involved in
        # initializing the descriptor.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # A default wrapped in field() with init=False. This is the only
        # case where it really matters: without init=False the descriptor
        # gets overwritten anyway.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors, too.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175. __set_name__ stored on the instance only (not on
        # its type) must not be called.
        class D:
            pass

        default = D()
        default.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=default, init=False)

        self.assertEqual(default.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175. __set_name__ defined on the type must be called.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002940
Eric V. Smith7389fd92018-03-19 21:07:51 -04002941
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002942class TestStringAnnotations(unittest.TestCase):
2943 def test_classvar(self):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002944 # Some expressions recognized as ClassVar really aren't. But
2945 # if you're using string annotations, it's not an exact
2946 # science.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002947 # These tests assume that both "import typing" and "from
2948 # typing import *" have been run in this file.
2949 for typestr in ('ClassVar[int]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002950 'ClassVar [int]',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002951 ' ClassVar [int]',
2952 'ClassVar',
2953 ' ClassVar ',
2954 'typing.ClassVar[int]',
2955 'typing.ClassVar[str]',
2956 ' typing.ClassVar[str]',
2957 'typing .ClassVar[str]',
2958 'typing. ClassVar[str]',
2959 'typing.ClassVar [str]',
2960 'typing.ClassVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01002961
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002962 # Not syntactically valid, but these will
Pablo Galindob0544ba2021-04-21 12:41:19 +01002963 # be treated as ClassVars.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002964 'typing.ClassVar.[int]',
2965 'typing.ClassVar+',
2966 ):
2967 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002968 @dataclass
2969 class C:
2970 x: typestr
2971
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002972 # x is a ClassVar, so C() takes no args.
2973 C()
2974
2975 # And it won't appear in the class's dict because it doesn't
2976 # have a default.
2977 self.assertNotIn('x', C.__dict__)
2978
2979 def test_isnt_classvar(self):
2980 for typestr in ('CV',
2981 't.ClassVar',
2982 't.ClassVar[int]',
2983 'typing..ClassVar[int]',
2984 'Classvar',
2985 'Classvar[int]',
2986 'typing.ClassVarx[int]',
2987 'typong.ClassVar[int]',
2988 'dataclasses.ClassVar[int]',
2989 'typingxClassVar[str]',
2990 ):
2991 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002992 @dataclass
2993 class C:
2994 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002995
2996 # x is not a ClassVar, so C() takes one arg.
2997 self.assertEqual(C(10).x, 10)
2998
2999 def test_initvar(self):
3000 # These tests assume that both "import dataclasses" and "from
3001 # dataclasses import *" have been run in this file.
3002 for typestr in ('InitVar[int]',
3003 'InitVar [int]'
3004 ' InitVar [int]',
3005 'InitVar',
3006 ' InitVar ',
3007 'dataclasses.InitVar[int]',
3008 'dataclasses.InitVar[str]',
3009 ' dataclasses.InitVar[str]',
3010 'dataclasses .InitVar[str]',
3011 'dataclasses. InitVar[str]',
3012 'dataclasses.InitVar [str]',
3013 'dataclasses.InitVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01003014
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003015 # Not syntactically valid, but these will
3016 # be treated as InitVars.
3017 'dataclasses.InitVar.[int]',
3018 'dataclasses.InitVar+',
3019 ):
3020 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003021 @dataclass
3022 class C:
3023 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003024
3025 # x is an InitVar, so doesn't create a member.
3026 with self.assertRaisesRegex(AttributeError,
3027 "object has no attribute 'x'"):
3028 C(1).x
3029
3030 def test_isnt_initvar(self):
3031 for typestr in ('IV',
3032 'dc.InitVar',
3033 'xdataclasses.xInitVar',
3034 'typing.xInitVar[int]',
3035 ):
3036 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003037 @dataclass
3038 class C:
3039 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003040
3041 # x is not an InitVar, so there will be a member x.
3042 self.assertEqual(C(10).x, 10)
3043
def test_classvar_module_level_import(self):
    from test import dataclass_module_1
    from test import dataclass_module_1_str
    from test import dataclass_module_2
    from test import dataclass_module_2_str

    helper_modules = (
        dataclass_module_1,
        dataclass_module_1_str,
        dataclass_module_2,
        dataclass_module_2_str,
    )
    for mod in helper_modules:
        with self.subTest(m=mod):
            # ClassVars are interpreted differently depending on whether
            # the helper module uses string annotations (USING_STRINGS).
            # See the imported modules for details.
            cv_obj = mod.CV(10) if mod.USING_STRINGS else mod.CV()
            self.assertEqual(cv_obj.cv0, 20)

            # InitVars are likewise interpreted differently with string
            # annotations. See the imported modules for details.
            iv_obj = mod.IV(0, 1, 2, 3, 4)

            for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
                with self.subTest(field_name=field_name):
                    # An InitVar is not stored as an instance field.
                    with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
                        getattr(iv_obj, field_name)

            if not mod.USING_STRINGS:
                # iv4 is treated as an InitVar, so it is absent from
                # the instance.
                self.assertNotIn('not_iv4', iv_obj.__dict__)
            else:
                # iv4 is treated as a normal field.
                self.assertIn('not_iv4', iv_obj.__dict__)
                self.assertEqual(iv_obj.not_iv4, 4)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003084
def test_text_annotations(self):
    from test import dataclass_textanno

    foo_cls = dataclass_textanno.Foo
    bar_cls = dataclass_textanno.Bar

    # get_type_hints() must resolve Bar's annotations, and those of its
    # generated __init__, to the real Foo class.
    self.assertEqual(get_type_hints(bar_cls), {'foo': foo_cls})
    expected_init_hints = {'foo': foo_cls, 'return': type(None)}
    self.assertEqual(get_type_hints(bar_cls.__init__), expected_init_hints)
3095
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003096
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), the functional way to build a dataclass."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        # Fields inherited from a dataclass base precede the new ones in
        # __init__, so both x and y are required.
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        # Bare string field specs get the string annotation 'typing.Any'.
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        # The loop variable is named field_name (not field) so that it does
        # not shadow the module-level dataclasses.field import.
        for field_name in ['a', 'ab']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [field_name, 'a', field_name])

    def test_keyword_field_names(self):
        for field_name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name, 'a'])

    def test_non_identifier_field_names(self):
        for field_name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3260
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The recursive-repr tests also live here. Their expected strings embed
    this class's qualified name ("TestReplace...."), so moving or renaming
    them requires updating those strings.
    """

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be passed to replace(), whether alone or
        # together with init fields. Each rejected call gets its own
        # assertRaises block; any statement after the raising call inside
        # the same block would never execute.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist. Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                    "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # replace() needs an instance: neither the class nor a
        # non-dataclass object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                    "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # The parentheses are escaped so the regex really requires the
        # literal "replace()" (an unescaped "()" is just an empty group).
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_initvar_with_default_value(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int] = None
            z: InitVar[int] = 42

            def __post_init__(self, y, z):
                if y is not None:
                    self.x += y
                if z is not None:
                    self.x += z

        c = C(x=1, y=10, z=1)
        self.assertEqual(replace(c), C(x=12))
        self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
        self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3488
Ben Avrahamibef7d292020-10-06 20:40:50 +03003489class TestAbstract(unittest.TestCase):
3490 def test_abc_implementation(self):
3491 class Ordered(abc.ABC):
3492 @abc.abstractmethod
3493 def __lt__(self, other):
3494 pass
3495
3496 @abc.abstractmethod
3497 def __le__(self, other):
3498 pass
3499
3500 @dataclass(order=True)
3501 class Date(Ordered):
3502 year: int
3503 month: 'Month'
3504 day: 'int'
3505
3506 self.assertFalse(inspect.isabstract(Date))
3507 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3508
3509 def test_maintain_abc(self):
3510 class A(abc.ABC):
3511 @abc.abstractmethod
3512 def foo(self):
3513 pass
3514
3515 @dataclass
3516 class Date(A):
3517 year: int
3518 month: 'Month'
3519 day: 'int'
3520
3521 self.assertTrue(inspect.isabstract(Date))
3522 msg = 'class Date with abstract method foo'
3523 self.assertRaisesRegex(TypeError, msg, Date)
3524
Eric V. Smith4e812962018-05-16 11:31:29 -04003525
Brandt Bucher145bf262021-02-26 14:51:55 -08003526class TestMatchArgs(unittest.TestCase):
3527 def test_match_args(self):
3528 @dataclass
3529 class C:
3530 a: int
3531 self.assertEqual(C(42).__match_args__, ('a',))
3532
3533 def test_explicit_match_args(self):
Brandt Bucherf84d5a12021-04-05 19:17:08 -07003534 ma = ()
Brandt Bucher145bf262021-02-26 14:51:55 -08003535 @dataclass
3536 class C:
3537 a: int
3538 __match_args__ = ma
3539 self.assertIs(C(42).__match_args__, ma)
3540
Brandt Bucherd92c59f2021-04-08 12:54:34 -07003541 def test_bpo_43764(self):
3542 @dataclass(repr=False, eq=False, init=False)
3543 class X:
3544 a: int
3545 b: int
3546 c: int
3547 self.assertEqual(X.__match_args__, ("a", "b", "c"))
3548
Eric V. Smith750f4842021-04-10 21:28:42 -04003549 def test_match_args_argument(self):
3550 @dataclass(match_args=False)
3551 class X:
3552 a: int
3553 self.assertNotIn('__match_args__', X.__dict__)
3554
3555 @dataclass(match_args=False)
3556 class Y:
3557 a: int
3558 __match_args__ = ('b',)
3559 self.assertEqual(Y.__match_args__, ('b',))
3560
3561 @dataclass(match_args=False)
3562 class Z(Y):
3563 z: int
3564 self.assertEqual(Z.__match_args__, ('b',))
3565
3566 # Ensure parent dataclass __match_args__ is seen, if child class
3567 # specifies match_args=False.
3568 @dataclass
3569 class A:
3570 a: int
3571 z: int
3572 @dataclass(match_args=False)
3573 class B(A):
3574 b: int
3575 self.assertEqual(B.__match_args__, ('a', 'z'))
3576
3577 def test_make_dataclasses(self):
3578 C = make_dataclass('C', [('x', int), ('y', int)])
3579 self.assertEqual(C.__match_args__, ('x', 'y'))
3580
3581 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
3582 self.assertEqual(C.__match_args__, ('x', 'y'))
3583
3584 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
3585 self.assertNotIn('__match__args__', C.__dict__)
3586
3587 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
3588 self.assertEqual(C.__match_args__, ('z',))
3589
Brandt Bucher145bf262021-02-26 14:51:55 -08003590
class TestKeywordArgs(unittest.TestCase):
    """Tests for keyword-only fields: field(kw_only=...),
    dataclass(kw_only=...) and the KW_ONLY sentinel annotation."""

    def test_no_classvar_kwarg(self):
        # Specifying kw_only on a ClassVar is rejected, whatever its value
        # and whatever the class-level kw_only setting.
        msg = 'field a is a ClassVar but specifies kw_only'
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=True)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=False)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass(kw_only=True)
            class A:
                a: ClassVar[int] = field(kw_only=False)

    def test_field_marked_as_kwonly(self):
        #######################
        # Using dataclass(kw_only=True)
        @dataclass(kw_only=True)
        class A:
            a: int
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Using dataclass(kw_only=False)
        @dataclass(kw_only=False)
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Not specifying dataclass(kw_only)
        @dataclass
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

    def test_match_args(self):
        # kw fields don't show up in __match_args__.
        @dataclass(kw_only=True)
        class C:
            a: int
        self.assertEqual(C(a=42).__match_args__, ())

        @dataclass
        class C:
            a: int
            b: int = field(kw_only=True)
        self.assertEqual(C(42, b=10).__match_args__, ('a',))

    def test_KW_ONLY(self):
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

        @dataclass(kw_only=True)
        class B:
            a: int
            _: KW_ONLY
            b: int
            c: int
        B(a=3, b=4, c=5)
        msg = "takes 1 positional argument but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            B(3, 4, 5)

        # Explicitly make a field that follows KW_ONLY be non-keyword-only.
        @dataclass
        class C:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=False)

        # Every distinct way of passing the arguments must yield the same
        # instance. (Previously one call form was checked twice verbatim;
        # the repeat is replaced with an all-keyword call in declaration
        # order.)
        for c in (C(1, 2, b=3),
                  C(1, b=3, c=2),
                  C(a=1, b=3, c=2),
                  C(c=2, b=3, a=1)):
            self.assertEqual((c.a, c.b, c.c), (1, 3, 2))

    def test_KW_ONLY_as_string(self):
        # KW_ONLY is also recognized as the string annotation
        # 'dataclasses.KW_ONLY'.
        @dataclass
        class A:
            a: int
            _: 'dataclasses.KW_ONLY'
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

    def test_KW_ONLY_twice(self):
        msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                Y: KW_ONLY
                b: int
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                Y: KW_ONLY
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                c: int
                Y: KW_ONLY

        # But this usage is okay, since it's not using KW_ONLY.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=True)

        # And if inheriting, it's okay.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        @dataclass
        class B(A):
            _: KW_ONLY
            d: int

        # Make sure the error is raised in a derived class.
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                _: KW_ONLY
                b: int
                c: int
            @dataclass
            class B(A):
                X: KW_ONLY
                d: int
                Y: KW_ONLY

    def test_post_init(self):
        # InitVars after KW_ONLY are keyword-only and are still passed
        # to __post_init__.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                raise CustomError(f'{b=} {d=}')
        with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
            A(1, c=2, b=3, d=4)

        @dataclass
        class B:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                self.a = b
                self.c = d
        b = B(1, c=2, b=3, d=4)
        self.assertEqual(asdict(b), {'a': 3, 'c': 4})

    def test_defaults(self):
        # For kwargs, make sure we can have defaults after non-defaults.
        @dataclass
        class A:
            a: int = 0
            _: KW_ONLY
            b: int
            c: int = 1
            d: int

        a = A(d=4, b=3)
        self.assertEqual(a.a, 0)
        self.assertEqual(a.b, 3)
        self.assertEqual(a.c, 1)
        self.assertEqual(a.d, 4)

        # Make sure we still check for non-kwarg non-defaults not following
        # defaults.
        err_regex = "non-default argument 'z' follows default argument"
        with self.assertRaisesRegex(TypeError, err_regex):
            @dataclass
            class A:
                a: int = 0
                z: int
                _: KW_ONLY
                b: int
                c: int = 1
                d: int
Eric V. Smithc0280532021-04-25 20:42:39 -04003853
if "__main__" == __name__:
    # Running this file directly executes every TestCase defined in it.
    unittest.main()