blob: ef5009ab116776bdb0cab909894c4eb05574230f [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Miss Islington (bot)abceb662021-12-05 13:04:29 -080011import types
Eric V. Smithf0db54a2017-12-04 16:58:55 -050012import unittest
13from unittest.mock import Mock
Miss Islington (bot)79e9f5a2021-09-02 23:26:53 -070014from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
Yury Selivanovd219cc42019-12-09 09:54:20 -050015from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050016from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050017from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050018
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040019import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
20import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
21
class CustomError(Exception):
    """Arbitrary exception type that tests raise on purpose and then catch."""
24
25class TestCase(unittest.TestCase):
26 def test_no_fields(self):
27 @dataclass
28 class C:
29 pass
30
31 o = C()
32 self.assertEqual(len(fields(C)), 0)
33
Eric V. Smith56970b82018-03-22 16:28:48 -040034 def test_no_fields_but_member_variable(self):
35 @dataclass
36 class C:
37 i = 0
38
39 o = C()
40 self.assertEqual(len(fields(C)), 0)
41
Eric V. Smithf0db54a2017-12-04 16:58:55 -050042 def test_one_field_no_default(self):
43 @dataclass
44 class C:
45 x: int
46
47 o = C(42)
48 self.assertEqual(o.x, 42)
49
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053050 def test_field_default_default_factory_error(self):
51 msg = "cannot specify both default and default_factory"
52 with self.assertRaisesRegex(ValueError, msg):
53 @dataclass
54 class C:
55 x: int = field(default=1, default_factory=int)
56
57 def test_field_repr(self):
58 int_field = field(default=1, init=True, repr=False)
59 int_field.name = "id"
60 repr_output = repr(int_field)
61 expected_output = "Field(name='id',type=None," \
62 f"default=1,default_factory={MISSING!r}," \
63 "init=True,repr=False,hash=None," \
64 "compare=True,metadata=mappingproxy({})," \
Eric V. Smithc0280532021-04-25 20:42:39 -040065 f"kw_only={MISSING!r}," \
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053066 "_field_type=None)"
67
68 self.assertEqual(repr_output, expected_output)
69
Eric V. Smithf0db54a2017-12-04 16:58:55 -050070 def test_named_init_params(self):
71 @dataclass
72 class C:
73 x: int
74
75 o = C(x=32)
76 self.assertEqual(o.x, 32)
77
78 def test_two_fields_one_default(self):
79 @dataclass
80 class C:
81 x: int
82 y: int = 0
83
84 o = C(3)
85 self.assertEqual((o.x, o.y), (3, 0))
86
87 # Non-defaults following defaults.
88 with self.assertRaisesRegex(TypeError,
89 "non-default argument 'y' follows "
90 "default argument"):
91 @dataclass
92 class C:
93 x: int = 0
94 y: int
95
96 # A derived class adds a non-default field after a default one.
97 with self.assertRaisesRegex(TypeError,
98 "non-default argument 'y' follows "
99 "default argument"):
100 @dataclass
101 class B:
102 x: int = 0
103
104 @dataclass
105 class C(B):
106 y: int
107
108 # Override a base class field and add a default to
109 # a field which didn't use to have a default.
110 with self.assertRaisesRegex(TypeError,
111 "non-default argument 'y' follows "
112 "default argument"):
113 @dataclass
114 class B:
115 x: int
116 y: int
117
118 @dataclass
119 class C(B):
120 x: int = 0
121
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500122 def test_overwrite_hash(self):
123 # Test that declaring this class isn't an error. It should
124 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500125 @dataclass(frozen=True)
126 class C:
127 x: int
128 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500129 return 301
130 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500131
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500132 # Test that declaring this class isn't an error. It should
133 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500134 @dataclass(frozen=True)
135 class C:
136 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500137 def __eq__(self, other):
138 return False
139 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500140
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500141 # But this one should generate an exception, because with
142 # unsafe_hash=True, it's an error to have a __hash__ defined.
143 with self.assertRaisesRegex(TypeError,
144 'Cannot overwrite attribute __hash__'):
145 @dataclass(unsafe_hash=True)
146 class C:
147 def __hash__(self):
148 pass
149
150 # Creating this class should not generate an exception,
151 # because even though __hash__ exists before @dataclass is
152 # called, (due to __eq__ being defined), since it's None
153 # that's okay.
154 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500155 class C:
156 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500157 def __eq__(self):
158 pass
159 # The generated hash function works as we'd expect.
160 self.assertEqual(hash(C(10)), hash((10,)))
161
162 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400163 # __hash__ exists and is not None, which it would be if it
164 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500165 with self.assertRaisesRegex(TypeError,
166 'Cannot overwrite attribute __hash__'):
167 @dataclass(unsafe_hash=True)
168 class C:
169 x: int
170 def __eq__(self):
171 pass
172 def __hash__(self):
173 pass
174
def test_overwrite_fields_in_derived_class(self):
    """A subclass may redefine a field; the base class's ordering is kept."""
    # C1's x replaces Base's x, yet x stays first, where Base declared it.
    @dataclass
    class Base:
        x: Any = 15.0
        y: int = 0

    @dataclass
    class C1(Base):
        z: int = 10
        x: int = 15

    self.assertEqual(repr(Base()), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
    self.assertEqual(repr(C1()), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
    self.assertEqual(repr(C1(x=5)), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
196
197 def test_field_named_self(self):
198 @dataclass
199 class C:
200 self: str
201 c=C('foo')
202 self.assertEqual(c.self, 'foo')
203
204 # Make sure the first parameter is not named 'self'.
205 sig = inspect.signature(C.__init__)
206 first = next(iter(sig.parameters))
207 self.assertNotEqual('self', first)
208
209 # But we do use 'self' if no field named self.
210 @dataclass
211 class C:
212 selfx: str
213
214 # Make sure the first parameter is named 'self'.
215 sig = inspect.signature(C.__init__)
216 first = next(iter(sig.parameters))
217 self.assertEqual('self', first)
218
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300219 def test_field_named_object(self):
220 @dataclass
221 class C:
222 object: str
223 c = C('foo')
224 self.assertEqual(c.object, 'foo')
225
226 def test_field_named_object_frozen(self):
227 @dataclass(frozen=True)
228 class C:
229 object: str
230 c = C('foo')
231 self.assertEqual(c.object, 'foo')
232
233 def test_field_named_like_builtin(self):
234 # Attribute names can shadow built-in names
235 # since code generation is used.
236 # Ensure that this is not happening.
237 exclusions = {'None', 'True', 'False'}
238 builtins_names = sorted(
239 b for b in builtins.__dict__.keys()
240 if not b.startswith('__') and b not in exclusions
241 )
242 attributes = [(name, str) for name in builtins_names]
243 C = make_dataclass('C', attributes)
244
245 c = C(*[name for name in builtins_names])
246
247 for name in builtins_names:
248 self.assertEqual(getattr(c, name), name)
249
250 def test_field_named_like_builtin_frozen(self):
251 # Attribute names can shadow built-in names
252 # since code generation is used.
253 # Ensure that this is not happening
254 # for frozen data classes.
255 exclusions = {'None', 'True', 'False'}
256 builtins_names = sorted(
257 b for b in builtins.__dict__.keys()
258 if not b.startswith('__') and b not in exclusions
259 )
260 attributes = [(name, str) for name in builtins_names]
261 C = make_dataclass('C', attributes, frozen=True)
262
263 c = C(*[name for name in builtins_names])
264
265 for name in builtins_names:
266 self.assertEqual(getattr(c, name), name)
267
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500268 def test_0_field_compare(self):
269 # Ensure that order=False is the default.
270 @dataclass
271 class C0:
272 pass
273
274 @dataclass(order=False)
275 class C1:
276 pass
277
278 for cls in [C0, C1]:
279 with self.subTest(cls=cls):
280 self.assertEqual(cls(), cls())
281 for idx, fn in enumerate([lambda a, b: a < b,
282 lambda a, b: a <= b,
283 lambda a, b: a > b,
284 lambda a, b: a >= b]):
285 with self.subTest(idx=idx):
286 with self.assertRaisesRegex(TypeError,
287 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
288 fn(cls(), cls())
289
290 @dataclass(order=True)
291 class C:
292 pass
293 self.assertLessEqual(C(), C())
294 self.assertGreaterEqual(C(), C())
295
296 def test_1_field_compare(self):
297 # Ensure that order=False is the default.
298 @dataclass
299 class C0:
300 x: int
301
302 @dataclass(order=False)
303 class C1:
304 x: int
305
306 for cls in [C0, C1]:
307 with self.subTest(cls=cls):
308 self.assertEqual(cls(1), cls(1))
309 self.assertNotEqual(cls(0), cls(1))
310 for idx, fn in enumerate([lambda a, b: a < b,
311 lambda a, b: a <= b,
312 lambda a, b: a > b,
313 lambda a, b: a >= b]):
314 with self.subTest(idx=idx):
315 with self.assertRaisesRegex(TypeError,
316 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
317 fn(cls(0), cls(0))
318
319 @dataclass(order=True)
320 class C:
321 x: int
322 self.assertLess(C(0), C(1))
323 self.assertLessEqual(C(0), C(1))
324 self.assertLessEqual(C(1), C(1))
325 self.assertGreater(C(1), C(0))
326 self.assertGreaterEqual(C(1), C(0))
327 self.assertGreaterEqual(C(1), C(1))
328
329 def test_simple_compare(self):
330 # Ensure that order=False is the default.
331 @dataclass
332 class C0:
333 x: int
334 y: int
335
336 @dataclass(order=False)
337 class C1:
338 x: int
339 y: int
340
341 for cls in [C0, C1]:
342 with self.subTest(cls=cls):
343 self.assertEqual(cls(0, 0), cls(0, 0))
344 self.assertEqual(cls(1, 2), cls(1, 2))
345 self.assertNotEqual(cls(1, 0), cls(0, 0))
346 self.assertNotEqual(cls(1, 0), cls(1, 1))
347 for idx, fn in enumerate([lambda a, b: a < b,
348 lambda a, b: a <= b,
349 lambda a, b: a > b,
350 lambda a, b: a >= b]):
351 with self.subTest(idx=idx):
352 with self.assertRaisesRegex(TypeError,
353 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
354 fn(cls(0, 0), cls(0, 0))
355
356 @dataclass(order=True)
357 class C:
358 x: int
359 y: int
360
361 for idx, fn in enumerate([lambda a, b: a == b,
362 lambda a, b: a <= b,
363 lambda a, b: a >= b]):
364 with self.subTest(idx=idx):
365 self.assertTrue(fn(C(0, 0), C(0, 0)))
366
367 for idx, fn in enumerate([lambda a, b: a < b,
368 lambda a, b: a <= b,
369 lambda a, b: a != b]):
370 with self.subTest(idx=idx):
371 self.assertTrue(fn(C(0, 0), C(0, 1)))
372 self.assertTrue(fn(C(0, 1), C(1, 0)))
373 self.assertTrue(fn(C(1, 0), C(1, 1)))
374
375 for idx, fn in enumerate([lambda a, b: a > b,
376 lambda a, b: a >= b,
377 lambda a, b: a != b]):
378 with self.subTest(idx=idx):
379 self.assertTrue(fn(C(0, 1), C(0, 0)))
380 self.assertTrue(fn(C(1, 0), C(0, 1)))
381 self.assertTrue(fn(C(1, 1), C(1, 0)))
382
383 def test_compare_subclasses(self):
384 # Comparisons fail for subclasses, even if no fields
385 # are added.
386 @dataclass
387 class B:
388 i: int
389
390 @dataclass
391 class C(B):
392 pass
393
394 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
395 (lambda a, b: a != b, True)]):
396 with self.subTest(idx=idx):
397 self.assertEqual(fn(B(0), C(0)), expected)
398
399 for idx, fn in enumerate([lambda a, b: a < b,
400 lambda a, b: a <= b,
401 lambda a, b: a > b,
402 lambda a, b: a >= b]):
403 with self.subTest(idx=idx):
404 with self.assertRaisesRegex(TypeError,
405 "not supported between instances of 'B' and 'C'"):
406 fn(B(0), C(0))
407
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500408 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500409 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500410 for (eq, order, result ) in [
411 (False, False, 'neither'),
412 (False, True, 'exception'),
413 (True, False, 'eq_only'),
414 (True, True, 'both'),
415 ]:
416 with self.subTest(eq=eq, order=order):
417 if result == 'exception':
418 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
419 @dataclass(eq=eq, order=order)
420 class C:
421 pass
422 else:
423 @dataclass(eq=eq, order=order)
424 class C:
425 pass
426
427 if result == 'neither':
428 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500429 self.assertNotIn('__lt__', C.__dict__)
430 self.assertNotIn('__le__', C.__dict__)
431 self.assertNotIn('__gt__', C.__dict__)
432 self.assertNotIn('__ge__', C.__dict__)
433 elif result == 'both':
434 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500435 self.assertIn('__lt__', C.__dict__)
436 self.assertIn('__le__', C.__dict__)
437 self.assertIn('__gt__', C.__dict__)
438 self.assertIn('__ge__', C.__dict__)
439 elif result == 'eq_only':
440 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500441 self.assertNotIn('__lt__', C.__dict__)
442 self.assertNotIn('__le__', C.__dict__)
443 self.assertNotIn('__gt__', C.__dict__)
444 self.assertNotIn('__ge__', C.__dict__)
445 else:
446 assert False, f'unknown result {result!r}'
447
448 def test_field_no_default(self):
449 @dataclass
450 class C:
451 x: int = field()
452
453 self.assertEqual(C(5).x, 5)
454
455 with self.assertRaisesRegex(TypeError,
456 r"__init__\(\) missing 1 required "
457 "positional argument: 'x'"):
458 C()
459
460 def test_field_default(self):
461 default = object()
462 @dataclass
463 class C:
464 x: object = field(default=default)
465
466 self.assertIs(C.x, default)
467 c = C(10)
468 self.assertEqual(c.x, 10)
469
470 # If we delete the instance attribute, we should then see the
471 # class attribute.
472 del c.x
473 self.assertIs(c.x, default)
474
475 self.assertIs(C().x, default)
476
def test_not_in_repr(self):
    """Fields declared with repr=False are omitted from the generated repr."""
    @dataclass
    class C:
        x: int = field(repr=False)
    # repr=False does not make the field optional.
    with self.assertRaises(TypeError):
        C()
    self.assertEqual(repr(C(10)), 'TestCase.test_not_in_repr.<locals>.C()')

    @dataclass
    class C:
        x: int = field(repr=False)
        y: int
    self.assertEqual(repr(C(10, 20)), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
492
493 def test_not_in_compare(self):
494 @dataclass
495 class C:
496 x: int = 0
497 y: int = field(compare=False, default=4)
498
499 self.assertEqual(C(), C(0, 20))
500 self.assertEqual(C(1, 10), C(1, 20))
501 self.assertNotEqual(C(3), C(4, 10))
502 self.assertNotEqual(C(3, 10), C(4, 10))
503
504 def test_hash_field_rules(self):
505 # Test all 6 cases of:
506 # hash=True/False/None
507 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500508 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500509 (True, False, 'field' ),
510 (True, True, 'field' ),
511 (False, False, 'absent'),
512 (False, True, 'absent'),
513 (None, False, 'absent'),
514 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500515 ]:
516 with self.subTest(hash=hash_, compare=compare):
517 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500518 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500519 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500520
521 if result == 'field':
522 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500523 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500524 elif result == 'absent':
525 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500526 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500527 else:
528 assert False, f'unknown result {result!r}'
529
530 def test_init_false_no_default(self):
531 # If init=False and no default value, then the field won't be
532 # present in the instance.
533 @dataclass
534 class C:
535 x: int = field(init=False)
536
537 self.assertNotIn('x', C().__dict__)
538
539 @dataclass
540 class C:
541 x: int
542 y: int = 0
543 z: int = field(init=False)
544 t: int = 10
545
546 self.assertNotIn('z', C(0).__dict__)
547 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
548
549 def test_class_marker(self):
550 @dataclass
551 class C:
552 x: int
553 y: str = field(init=False, default=None)
554 z: str = field(repr=False)
555
556 the_fields = fields(C)
557 # the_fields is a tuple of 3 items, each value
558 # is in __annotations__.
559 self.assertIsInstance(the_fields, tuple)
560 for f in the_fields:
561 self.assertIs(type(f), Field)
562 self.assertIn(f.name, C.__annotations__)
563
564 self.assertEqual(len(the_fields), 3)
565
566 self.assertEqual(the_fields[0].name, 'x')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100567 self.assertEqual(the_fields[0].type, int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500568 self.assertFalse(hasattr(C, 'x'))
569 self.assertTrue (the_fields[0].init)
570 self.assertTrue (the_fields[0].repr)
571 self.assertEqual(the_fields[1].name, 'y')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100572 self.assertEqual(the_fields[1].type, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500573 self.assertIsNone(getattr(C, 'y'))
574 self.assertFalse(the_fields[1].init)
575 self.assertTrue (the_fields[1].repr)
576 self.assertEqual(the_fields[2].name, 'z')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100577 self.assertEqual(the_fields[2].type, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500578 self.assertFalse(hasattr(C, 'z'))
579 self.assertTrue (the_fields[2].init)
580 self.assertFalse(the_fields[2].repr)
581
582 def test_field_order(self):
583 @dataclass
584 class B:
585 a: str = 'B:a'
586 b: str = 'B:b'
587 c: str = 'B:c'
588
589 @dataclass
590 class C(B):
591 b: str = 'C:b'
592
593 self.assertEqual([(f.name, f.default) for f in fields(C)],
594 [('a', 'B:a'),
595 ('b', 'C:b'),
596 ('c', 'B:c')])
597
598 @dataclass
599 class D(B):
600 c: str = 'D:c'
601
602 self.assertEqual([(f.name, f.default) for f in fields(D)],
603 [('a', 'B:a'),
604 ('b', 'B:b'),
605 ('c', 'D:c')])
606
607 @dataclass
608 class E(D):
609 a: str = 'E:a'
610 d: str = 'E:d'
611
612 self.assertEqual([(f.name, f.default) for f in fields(E)],
613 [('a', 'E:a'),
614 ('b', 'B:b'),
615 ('c', 'D:c'),
616 ('d', 'E:d')])
617
618 def test_class_attrs(self):
619 # We only have a class attribute if a default value is
620 # specified, either directly or via a field with a default.
621 default = object()
622 @dataclass
623 class C:
624 x: int
625 y: int = field(repr=False)
626 z: object = default
627 t: int = field(default=100)
628
629 self.assertFalse(hasattr(C, 'x'))
630 self.assertFalse(hasattr(C, 'y'))
631 self.assertIs (C.z, default)
632 self.assertEqual(C.t, 100)
633
634 def test_disallowed_mutable_defaults(self):
635 # For the known types, don't allow mutable default values.
636 for typ, empty, non_empty in [(list, [], [1]),
637 (dict, {}, {0:1}),
638 (set, set(), set([1])),
639 ]:
640 with self.subTest(typ=typ):
641 # Can't use a zero-length value.
642 with self.assertRaisesRegex(ValueError,
643 f'mutable default {typ} for field '
644 'x is not allowed'):
645 @dataclass
646 class Point:
647 x: typ = empty
648
649
650 # Nor a non-zero-length value
651 with self.assertRaisesRegex(ValueError,
652 f'mutable default {typ} for field '
653 'y is not allowed'):
654 @dataclass
655 class Point:
656 y: typ = non_empty
657
658 # Check subtypes also fail.
659 class Subclass(typ): pass
660
661 with self.assertRaisesRegex(ValueError,
662 f"mutable default .*Subclass'>"
663 ' for field z is not allowed'
664 ):
665 @dataclass
666 class Point:
667 z: typ = Subclass()
668
669 # Because this is a ClassVar, it can be mutable.
670 @dataclass
671 class C:
672 z: ClassVar[typ] = typ()
673
674 # Because this is a ClassVar, it can be mutable.
675 @dataclass
676 class C:
677 x: ClassVar[typ] = Subclass()
678
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500679 def test_deliberately_mutable_defaults(self):
680 # If a mutable default isn't in the known list of
681 # (list, dict, set), then it's okay.
682 class Mutable:
683 def __init__(self):
684 self.l = []
685
686 @dataclass
687 class C:
688 x: Mutable
689
690 # These 2 instances will share this value of x.
691 lst = Mutable()
692 o1 = C(lst)
693 o2 = C(lst)
694 self.assertEqual(o1, o2)
695 o1.x.l.extend([1, 2])
696 self.assertEqual(o1, o2)
697 self.assertEqual(o1.x.l, [1, 2])
698 self.assertIs(o1.x, o2.x)
699
700 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400701 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500702 @dataclass()
703 class C:
704 x: int
705
706 self.assertEqual(C(42).x, 42)
707
708 def test_not_tuple(self):
709 # Make sure we can't be compared to a tuple.
710 @dataclass
711 class Point:
712 x: int
713 y: int
714 self.assertNotEqual(Point(1, 2), (1, 2))
715
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400716 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500717 @dataclass
718 class C:
719 x: int
720 y: int
721 self.assertNotEqual(Point(1, 3), C(1, 3))
722
Windson yangbe372d72019-04-23 02:45:34 +0800723 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500724 # Test that some of the problems with namedtuple don't happen
725 # here.
726 @dataclass
727 class Point3D:
728 x: int
729 y: int
730 z: int
731
732 @dataclass
733 class Date:
734 year: int
735 month: int
736 day: int
737
738 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
739 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
740
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400741 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200742 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500743 x, y, z = Point3D(4, 5, 6)
744
Eric V. Smith7c99e932018-01-28 19:18:55 -0500745 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500746 # equal.
747 @dataclass
748 class Point3Dv1:
749 x: int = 0
750 y: int = 0
751 z: int = 0
752 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
753
754 def test_function_annotations(self):
755 # Some dummy class and instance to use as a default.
756 class F:
757 pass
758 f = F()
759
760 def validate_class(cls):
761 # First, check __annotations__, even though they're not
762 # function annotations.
Pablo Galindob0544ba2021-04-21 12:41:19 +0100763 self.assertEqual(cls.__annotations__['i'], int)
764 self.assertEqual(cls.__annotations__['j'], str)
765 self.assertEqual(cls.__annotations__['k'], F)
766 self.assertEqual(cls.__annotations__['l'], float)
767 self.assertEqual(cls.__annotations__['z'], complex)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500768
769 # Verify __init__.
770
771 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400772 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500773 self.assertIs(signature.return_annotation, None)
774
775 # Check each parameter.
776 params = iter(signature.parameters.values())
777 param = next(params)
778 # This is testing an internal name, and probably shouldn't be tested.
779 self.assertEqual(param.name, 'self')
780 param = next(params)
781 self.assertEqual(param.name, 'i')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100782 self.assertIs (param.annotation, int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500783 self.assertEqual(param.default, inspect.Parameter.empty)
784 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
785 param = next(params)
786 self.assertEqual(param.name, 'j')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100787 self.assertIs (param.annotation, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500788 self.assertEqual(param.default, inspect.Parameter.empty)
789 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
790 param = next(params)
791 self.assertEqual(param.name, 'k')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100792 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400793 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500794 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
795 param = next(params)
796 self.assertEqual(param.name, 'l')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100797 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400798 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500799 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
800 self.assertRaises(StopIteration, next, params)
801
802
803 @dataclass
804 class C:
805 i: int
806 j: str
807 k: F = f
808 l: float=field(default=None)
809 z: complex=field(default=3+4j, init=False)
810
811 validate_class(C)
812
813 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500814 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500815 class C:
816 i: int
817 j: str
818 k: F = f
819 l: float=field(default=None)
820 z: complex=field(default=3+4j, init=False)
821
822 validate_class(C)
823
Eric V. Smith03220fd2017-12-29 13:59:58 -0500824 def test_missing_default(self):
825 # Test that MISSING works the same as a default not being
826 # specified.
827 @dataclass
828 class C:
829 x: int=field(default=MISSING)
830 with self.assertRaisesRegex(TypeError,
831 r'__init__\(\) missing 1 required '
832 'positional argument'):
833 C()
834 self.assertNotIn('x', C.__dict__)
835
836 @dataclass
837 class D:
838 x: int
839 with self.assertRaisesRegex(TypeError,
840 r'__init__\(\) missing 1 required '
841 'positional argument'):
842 D()
843 self.assertNotIn('x', D.__dict__)
844
845 def test_missing_default_factory(self):
846 # Test that MISSING works the same as a default factory not
847 # being specified (which is really the same as a default not
848 # being specified, too).
849 @dataclass
850 class C:
851 x: int=field(default_factory=MISSING)
852 with self.assertRaisesRegex(TypeError,
853 r'__init__\(\) missing 1 required '
854 'positional argument'):
855 C()
856 self.assertNotIn('x', C.__dict__)
857
858 @dataclass
859 class D:
860 x: int=field(default=MISSING, default_factory=MISSING)
861 with self.assertRaisesRegex(TypeError,
862 r'__init__\(\) missing 1 required '
863 'positional argument'):
864 D()
865 self.assertNotIn('x', D.__dict__)
866
867 def test_missing_repr(self):
868 self.assertIn('MISSING_TYPE object', repr(MISSING))
869
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500870 def test_dont_include_other_annotations(self):
871 @dataclass
872 class C:
873 i: int
874 def foo(self) -> int:
875 return 4
876 @property
877 def bar(self) -> int:
878 return 5
879 self.assertEqual(list(C.__annotations__), ['i'])
880 self.assertEqual(C(10).foo(), 4)
881 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400882 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500883
def test_post_init(self):
    """__post_init__ runs at the end of the generated __init__, and only then."""
    # It gets called.
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    # It sees the already-initialized field through self.
    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # Called again, but i != 10 so nothing is raised.
    C(5)

    # With init=False there is no generated __init__ to call __post_init__.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    C()

    # It may update field values.
    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # On a frozen dataclass even __post_init__ cannot assign attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
930
def test_post_init_super(self):
    """A base class's __post_init__ runs only when reached through lookup or super()."""
    class B:
        def __post_init__(self):
            raise CustomError()

    # Overriding __post_init__ without calling super() skips B's version.
    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(C().x, 5)

    # Explicitly delegating to super() reaches B's version.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # An inherited __post_init__ is called even if C does not define one.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
961
962 def test_post_init_staticmethod(self):
963 flag = False
964 @dataclass
965 class C:
966 x: int
967 y: int
968 @staticmethod
969 def __post_init__():
970 nonlocal flag
971 flag = True
972
973 self.assertFalse(flag)
974 c = C(3, 4)
975 self.assertEqual((c.x, c.y), (3, 4))
976 self.assertTrue(flag)
977
978 def test_post_init_classmethod(self):
979 @dataclass
980 class C:
981 flag = False
982 x: int
983 y: int
984 @classmethod
985 def __post_init__(cls):
986 cls.flag = True
987
988 self.assertFalse(C.flag)
989 c = C(3, 4)
990 self.assertEqual((c.x, c.y), (3, 4))
991 self.assertTrue(C.flag)
992
def test_class_var(self):
    """ClassVar annotations are excluded from __init__, __repr__ and fields()."""
    @dataclass
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000
        s: ClassVar = 4000

    def check_class_vars(obj, z):
        self.assertEqual(obj.z, z)
        self.assertEqual(obj.w, 2000)
        self.assertEqual(obj.t, 3000)
        self.assertEqual(obj.s, 4000)

    c = C(5)
    self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)           # 2 real fields...
    self.assertEqual(len(C.__annotations__), 6)   # ...plus 4 ClassVars.
    check_class_vars(c, 1000)

    # Class-level changes are seen by existing and new instances alike.
    C.z += 1
    self.assertEqual(c.z, 1001)
    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    check_class_vars(c, 1001)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001020
1021 def test_class_var_no_default(self):
1022 # If a ClassVar has no default value, it should not be set on the class.
1023 @dataclass
1024 class C:
1025 x: ClassVar[int]
1026
1027 self.assertNotIn('x', C.__dict__)
1028
1029 def test_class_var_default_factory(self):
1030 # It makes no sense for a ClassVar to have a default factory. When
1031 # would it be called? Call it yourself, since it's class-wide.
1032 with self.assertRaisesRegex(TypeError,
1033 'cannot have a default factory'):
1034 @dataclass
1035 class C:
1036 x: ClassVar[int] = field(default_factory=int)
1037
1038 self.assertNotIn('x', C.__dict__)
1039
1040 def test_class_var_with_default(self):
1041 # If a ClassVar has a default value, it should be set on the class.
1042 @dataclass
1043 class C:
1044 x: ClassVar[int] = 10
1045 self.assertEqual(C.x, 10)
1046
1047 @dataclass
1048 class C:
1049 x: ClassVar[int] = field(default=10)
1050 self.assertEqual(C.x, 10)
1051
1052 def test_class_var_frozen(self):
1053 # Make sure ClassVars work even if we're frozen.
1054 @dataclass(frozen=True)
1055 class C:
1056 x: int
1057 y: int = 10
1058 z: ClassVar[int] = 1000
1059 w: ClassVar[int] = 2000
1060 t: ClassVar[int] = 3000
1061
1062 c = C(5)
1063 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1064 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1065 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1066 self.assertEqual(c.z, 1000)
1067 self.assertEqual(c.w, 2000)
1068 self.assertEqual(c.t, 3000)
1069 # We can still modify the ClassVar, it's only instances that are
1070 # frozen.
1071 C.z += 1
1072 self.assertEqual(c.z, 1001)
1073 c = C(20)
1074 self.assertEqual((c.x, c.y), (20, 10))
1075 self.assertEqual(c.z, 1001)
1076 self.assertEqual(c.w, 2000)
1077 self.assertEqual(c.t, 3000)
1078
1079 def test_init_var_no_default(self):
1080 # If an InitVar has no default value, it should not be set on the class.
1081 @dataclass
1082 class C:
1083 x: InitVar[int]
1084
1085 self.assertNotIn('x', C.__dict__)
1086
1087 def test_init_var_default_factory(self):
1088 # It makes no sense for an InitVar to have a default factory. When
1089 # would it be called? Call it yourself, since it's class-wide.
1090 with self.assertRaisesRegex(TypeError,
1091 'cannot have a default factory'):
1092 @dataclass
1093 class C:
1094 x: InitVar[int] = field(default_factory=int)
1095
1096 self.assertNotIn('x', C.__dict__)
1097
1098 def test_init_var_with_default(self):
1099 # If an InitVar has a default value, it should be set on the class.
1100 @dataclass
1101 class C:
1102 x: InitVar[int] = 10
1103 self.assertEqual(C.x, 10)
1104
1105 @dataclass
1106 class C:
1107 x: InitVar[int] = field(default=10)
1108 self.assertEqual(C.x, 10)
1109
1110 def test_init_var(self):
1111 @dataclass
1112 class C:
1113 x: int = None
1114 init_param: InitVar[int] = None
1115
1116 def __post_init__(self, init_param):
1117 if self.x is None:
1118 self.x = init_param*2
1119
1120 c = C(init_param=10)
1121 self.assertEqual(c.x, 20)
1122
Augusto Hack01ee12b2019-06-02 23:14:48 -03001123 def test_init_var_preserve_type(self):
1124 self.assertEqual(InitVar[int].type, int)
1125
1126 # Make sure the repr is correct.
1127 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001128 self.assertEqual(repr(InitVar[List[int]]),
1129 'dataclasses.InitVar[typing.List[int]]')
Miss Islington (bot)f1dd5ed2021-12-05 13:02:47 -08001130 self.assertEqual(repr(InitVar[list[int]]),
1131 'dataclasses.InitVar[list[int]]')
1132 self.assertEqual(repr(InitVar[int|str]),
1133 'dataclasses.InitVar[int | str]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001134
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001135 def test_init_var_inheritance(self):
1136 # Note that this deliberately tests that a dataclass need not
1137 # have a __post_init__ function if it has an InitVar field.
1138 # It could just be used in a derived class, as shown here.
1139 @dataclass
1140 class Base:
1141 x: int
1142 init_base: InitVar[int]
1143
1144 # We can instantiate by passing the InitVar, even though
1145 # it's not used.
1146 b = Base(0, 10)
1147 self.assertEqual(vars(b), {'x': 0})
1148
1149 @dataclass
1150 class C(Base):
1151 y: int
1152 init_derived: InitVar[int]
1153
1154 def __post_init__(self, init_base, init_derived):
1155 self.x = self.x + init_base
1156 self.y = self.y + init_derived
1157
1158 c = C(10, 11, 50, 51)
1159 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1160
1161 def test_default_factory(self):
1162 # Test a factory that returns a new list.
1163 @dataclass
1164 class C:
1165 x: int
1166 y: list = field(default_factory=list)
1167
1168 c0 = C(3)
1169 c1 = C(3)
1170 self.assertEqual(c0.x, 3)
1171 self.assertEqual(c0.y, [])
1172 self.assertEqual(c0, c1)
1173 self.assertIsNot(c0.y, c1.y)
1174 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1175
1176 # Test a factory that returns a shared list.
1177 l = []
1178 @dataclass
1179 class C:
1180 x: int
1181 y: list = field(default_factory=lambda: l)
1182
1183 c0 = C(3)
1184 c1 = C(3)
1185 self.assertEqual(c0.x, 3)
1186 self.assertEqual(c0.y, [])
1187 self.assertEqual(c0, c1)
1188 self.assertIs(c0.y, c1.y)
1189 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1190
1191 # Test various other field flags.
1192 # repr
1193 @dataclass
1194 class C:
1195 x: list = field(default_factory=list, repr=False)
1196 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1197 self.assertEqual(C().x, [])
1198
1199 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001200 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001201 class C:
1202 x: list = field(default_factory=list, hash=False)
1203 self.assertEqual(astuple(C()), ([],))
1204 self.assertEqual(hash(C()), hash(()))
1205
1206 # init (see also test_default_factory_with_no_init)
1207 @dataclass
1208 class C:
1209 x: list = field(default_factory=list, init=False)
1210 self.assertEqual(astuple(C()), ([],))
1211
1212 # compare
1213 @dataclass
1214 class C:
1215 x: list = field(default_factory=list, compare=False)
1216 self.assertEqual(C(), C([1]))
1217
1218 def test_default_factory_with_no_init(self):
1219 # We need a factory with a side effect.
1220 factory = Mock()
1221
1222 @dataclass
1223 class C:
1224 x: list = field(default_factory=factory, init=False)
1225
1226 # Make sure the default factory is called for each new instance.
1227 C().x
1228 self.assertEqual(factory.call_count, 1)
1229 C().x
1230 self.assertEqual(factory.call_count, 2)
1231
1232 def test_default_factory_not_called_if_value_given(self):
1233 # We need a factory that we can test if it's been called.
1234 factory = Mock()
1235
1236 @dataclass
1237 class C:
1238 x: int = field(default_factory=factory)
1239
1240 # Make sure that if a field has a default factory function,
1241 # it's not called if a value is specified.
1242 C().x
1243 self.assertEqual(factory.call_count, 1)
1244 self.assertEqual(C(10).x, 10)
1245 self.assertEqual(factory.call_count, 1)
1246 C().x
1247 self.assertEqual(factory.call_count, 2)
1248
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001249 def test_default_factory_derived(self):
1250 # See bpo-32896.
1251 @dataclass
1252 class Foo:
1253 x: dict = field(default_factory=dict)
1254
1255 @dataclass
1256 class Bar(Foo):
1257 y: int = 1
1258
1259 self.assertEqual(Foo().x, {})
1260 self.assertEqual(Bar().x, {})
1261 self.assertEqual(Bar().y, 1)
1262
1263 @dataclass
1264 class Baz(Foo):
1265 pass
1266 self.assertEqual(Baz().x, {})
1267
1268 def test_intermediate_non_dataclass(self):
1269 # Test that an intermediate class that defines
1270 # annotations does not define fields.
1271
1272 @dataclass
1273 class A:
1274 x: int
1275
1276 class B(A):
1277 y: int
1278
1279 @dataclass
1280 class C(B):
1281 z: int
1282
1283 c = C(1, 3)
1284 self.assertEqual((c.x, c.z), (1, 3))
1285
1286 # .y was not initialized.
1287 with self.assertRaisesRegex(AttributeError,
1288 'object has no attribute'):
1289 c.y
1290
1291 # And if we again derive a non-dataclass, no fields are added.
1292 class D(C):
1293 t: int
1294 d = D(4, 5)
1295 self.assertEqual((d.x, d.z), (4, 5))
1296
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001297 def test_classvar_default_factory(self):
1298 # It's an error for a ClassVar to have a factory function.
1299 with self.assertRaisesRegex(TypeError,
1300 'cannot have a default factory'):
1301 @dataclass
1302 class C:
1303 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001304
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001305 def test_is_dataclass(self):
1306 class NotDataClass:
1307 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001308
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001309 self.assertFalse(is_dataclass(0))
1310 self.assertFalse(is_dataclass(int))
1311 self.assertFalse(is_dataclass(NotDataClass))
1312 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001313
1314 @dataclass
1315 class C:
1316 x: int
1317
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001318 @dataclass
1319 class D:
1320 d: C
1321 e: int
1322
1323 c = C(10)
1324 d = D(c, 4)
1325
1326 self.assertTrue(is_dataclass(C))
1327 self.assertTrue(is_dataclass(c))
1328 self.assertFalse(is_dataclass(c.x))
1329 self.assertTrue(is_dataclass(d.d))
1330 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001331
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001332 def test_is_dataclass_when_getattr_always_returns(self):
1333 # See bpo-37868.
1334 class A:
1335 def __getattr__(self, key):
1336 return 0
1337 self.assertFalse(is_dataclass(A))
1338 a = A()
1339
1340 # Also test for an instance attribute.
1341 class B:
1342 pass
1343 b = B()
1344 b.__dataclass_fields__ = []
1345
1346 for obj in a, b:
1347 with self.subTest(obj=obj):
1348 self.assertFalse(is_dataclass(obj))
1349
1350 # Indirect tests for _is_dataclass_instance().
1351 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1352 asdict(obj)
1353 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1354 astuple(obj)
1355 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1356 replace(obj, x=0)
1357
Miss Islington (bot)abceb662021-12-05 13:04:29 -08001358 def test_is_dataclass_genericalias(self):
1359 @dataclass
1360 class A(types.GenericAlias):
1361 origin: type
1362 args: type
1363 self.assertTrue(is_dataclass(A))
1364 a = A(list, int)
1365 self.assertTrue(is_dataclass(type(a)))
1366 self.assertTrue(is_dataclass(a))
1367
1368
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001369 def test_helper_fields_with_class_instance(self):
1370 # Check that we can call fields() on either a class or instance,
1371 # and get back the same thing.
1372 @dataclass
1373 class C:
1374 x: int
1375 y: float
1376
1377 self.assertEqual(fields(C), fields(C(0, 0.0)))
1378
1379 def test_helper_fields_exception(self):
1380 # Check that TypeError is raised if not passed a dataclass or
1381 # instance.
1382 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1383 fields(0)
1384
1385 class C: pass
1386 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1387 fields(C)
1388 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1389 fields(C())
1390
1391 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001392 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001393 @dataclass
1394 class C:
1395 x: int
1396 y: int
1397 c = C(1, 2)
1398
1399 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1400 self.assertEqual(asdict(c), asdict(c))
1401 self.assertIsNot(asdict(c), asdict(c))
1402 c.x = 42
1403 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1404 self.assertIs(type(asdict(c)), dict)
1405
1406 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001407 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001408 @dataclass
1409 class C:
1410 x: int
1411 y: int
1412 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1413 asdict(C)
1414 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1415 asdict(int)
1416
1417 def test_helper_asdict_copy_values(self):
1418 @dataclass
1419 class C:
1420 x: int
1421 y: List[int] = field(default_factory=list)
1422 initial = []
1423 c = C(1, initial)
1424 d = asdict(c)
1425 self.assertEqual(d['y'], initial)
1426 self.assertIsNot(d['y'], initial)
1427 c = C(1)
1428 d = asdict(c)
1429 d['y'].append(1)
1430 self.assertEqual(c.y, [])
1431
1432 def test_helper_asdict_nested(self):
1433 @dataclass
1434 class UserId:
1435 token: int
1436 group: int
1437 @dataclass
1438 class User:
1439 name: str
1440 id: UserId
1441 u = User('Joe', UserId(123, 1))
1442 d = asdict(u)
1443 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1444 self.assertIsNot(asdict(u), asdict(u))
1445 u.id.group = 2
1446 self.assertEqual(asdict(u), {'name': 'Joe',
1447 'id': {'token': 123, 'group': 2}})
1448
1449 def test_helper_asdict_builtin_containers(self):
1450 @dataclass
1451 class User:
1452 name: str
1453 id: int
1454 @dataclass
1455 class GroupList:
1456 id: int
1457 users: List[User]
1458 @dataclass
1459 class GroupTuple:
1460 id: int
1461 users: Tuple[User, ...]
1462 @dataclass
1463 class GroupDict:
1464 id: int
1465 users: Dict[str, User]
1466 a = User('Alice', 1)
1467 b = User('Bob', 2)
1468 gl = GroupList(0, [a, b])
1469 gt = GroupTuple(0, (a, b))
1470 gd = GroupDict(0, {'first': a, 'second': b})
1471 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1472 {'name': 'Bob', 'id': 2}]})
1473 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1474 {'name': 'Bob', 'id': 2})})
1475 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1476 'second': {'name': 'Bob', 'id': 2}}})
1477
Windson yangbe372d72019-04-23 02:45:34 +08001478 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001479 @dataclass
1480 class Child:
1481 d: object
1482
1483 @dataclass
1484 class Parent:
1485 child: Child
1486
1487 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1488 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1489
1490 def test_helper_asdict_factory(self):
1491 @dataclass
1492 class C:
1493 x: int
1494 y: int
1495 c = C(1, 2)
1496 d = asdict(c, dict_factory=OrderedDict)
1497 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1498 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1499 c.x = 42
1500 d = asdict(c, dict_factory=OrderedDict)
1501 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1502 self.assertIs(type(d), OrderedDict)
1503
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001504 def test_helper_asdict_namedtuple(self):
1505 T = namedtuple('T', 'a b c')
1506 @dataclass
1507 class C:
1508 x: str
1509 y: T
1510 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1511
1512 d = asdict(c)
1513 self.assertEqual(d, {'x': 'outer',
1514 'y': T(1,
1515 {'x': 'inner',
1516 'y': T(11, 12, 13)},
1517 2),
1518 }
1519 )
1520
1521 # Now with a dict_factory. OrderedDict is convenient, but
1522 # since it compares to dicts, we also need to have separate
1523 # assertIs tests.
1524 d = asdict(c, dict_factory=OrderedDict)
1525 self.assertEqual(d, {'x': 'outer',
1526 'y': T(1,
1527 {'x': 'inner',
1528 'y': T(11, 12, 13)},
1529 2),
1530 }
1531 )
1532
penguindustin96466302019-05-06 14:57:17 -04001533 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001534 self.assertIs(type(d), OrderedDict)
1535 self.assertIs(type(d['y'][1]), OrderedDict)
1536
1537 def test_helper_asdict_namedtuple_key(self):
1538 # Ensure that a field that contains a dict which has a
1539 # namedtuple as a key works with asdict().
1540
1541 @dataclass
1542 class C:
1543 f: dict
1544 T = namedtuple('T', 'a')
1545
1546 c = C({T('an a'): 0})
1547
1548 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1549
1550 def test_helper_asdict_namedtuple_derived(self):
1551 class T(namedtuple('Tbase', 'a')):
1552 def my_a(self):
1553 return self.a
1554
1555 @dataclass
1556 class C:
1557 f: T
1558
1559 t = T(6)
1560 c = C(t)
1561
1562 d = asdict(c)
1563 self.assertEqual(d, {'f': T(a=6)})
1564 # Make sure that t has been copied, not used directly.
1565 self.assertIsNot(d['f'], t)
1566 self.assertEqual(d['f'].my_a(), 6)
1567
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001568 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001569 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001570 @dataclass
1571 class C:
1572 x: int
1573 y: int = 0
1574 c = C(1)
1575
1576 self.assertEqual(astuple(c), (1, 0))
1577 self.assertEqual(astuple(c), astuple(c))
1578 self.assertIsNot(astuple(c), astuple(c))
1579 c.y = 42
1580 self.assertEqual(astuple(c), (1, 42))
1581 self.assertIs(type(astuple(c)), tuple)
1582
1583 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001584 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001585 @dataclass
1586 class C:
1587 x: int
1588 y: int
1589 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1590 astuple(C)
1591 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1592 astuple(int)
1593
1594 def test_helper_astuple_copy_values(self):
1595 @dataclass
1596 class C:
1597 x: int
1598 y: List[int] = field(default_factory=list)
1599 initial = []
1600 c = C(1, initial)
1601 t = astuple(c)
1602 self.assertEqual(t[1], initial)
1603 self.assertIsNot(t[1], initial)
1604 c = C(1)
1605 t = astuple(c)
1606 t[1].append(1)
1607 self.assertEqual(c.y, [])
1608
1609 def test_helper_astuple_nested(self):
1610 @dataclass
1611 class UserId:
1612 token: int
1613 group: int
1614 @dataclass
1615 class User:
1616 name: str
1617 id: UserId
1618 u = User('Joe', UserId(123, 1))
1619 t = astuple(u)
1620 self.assertEqual(t, ('Joe', (123, 1)))
1621 self.assertIsNot(astuple(u), astuple(u))
1622 u.id.group = 2
1623 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1624
1625 def test_helper_astuple_builtin_containers(self):
1626 @dataclass
1627 class User:
1628 name: str
1629 id: int
1630 @dataclass
1631 class GroupList:
1632 id: int
1633 users: List[User]
1634 @dataclass
1635 class GroupTuple:
1636 id: int
1637 users: Tuple[User, ...]
1638 @dataclass
1639 class GroupDict:
1640 id: int
1641 users: Dict[str, User]
1642 a = User('Alice', 1)
1643 b = User('Bob', 2)
1644 gl = GroupList(0, [a, b])
1645 gt = GroupTuple(0, (a, b))
1646 gd = GroupDict(0, {'first': a, 'second': b})
1647 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1648 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1649 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1650
Windson yangbe372d72019-04-23 02:45:34 +08001651 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001652 @dataclass
1653 class Child:
1654 d: object
1655
1656 @dataclass
1657 class Parent:
1658 child: Child
1659
1660 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1661 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1662
1663 def test_helper_astuple_factory(self):
1664 @dataclass
1665 class C:
1666 x: int
1667 y: int
1668 NT = namedtuple('NT', 'x y')
1669 def nt(lst):
1670 return NT(*lst)
1671 c = C(1, 2)
1672 t = astuple(c, tuple_factory=nt)
1673 self.assertEqual(t, NT(1, 2))
1674 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1675 c.x = 42
1676 t = astuple(c, tuple_factory=nt)
1677 self.assertEqual(t, NT(42, 2))
1678 self.assertIs(type(t), NT)
1679
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001680 def test_helper_astuple_namedtuple(self):
1681 T = namedtuple('T', 'a b c')
1682 @dataclass
1683 class C:
1684 x: str
1685 y: T
1686 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1687
1688 t = astuple(c)
1689 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1690
1691 # Now, using a tuple_factory. list is convenient here.
1692 t = astuple(c, tuple_factory=list)
1693 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1694
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001695 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001696 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001697 }
1698
1699 # Create the class.
1700 cls = type('C', (), cls_dict)
1701
1702 # Make it a dataclass.
1703 cls1 = dataclass(cls)
1704
1705 self.assertEqual(cls1, cls)
1706 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1707
1708 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001709 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001710 'y': field(default=5),
1711 }
1712
1713 # Create the class.
1714 cls = type('C', (), cls_dict)
1715
1716 # Make it a dataclass.
1717 cls1 = dataclass(cls)
1718
1719 self.assertEqual(cls1, cls)
1720 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1721
1722 def test_init_in_order(self):
1723 @dataclass
1724 class C:
1725 a: int
1726 b: int = field()
1727 c: list = field(default_factory=list, init=False)
1728 d: list = field(default_factory=list)
1729 e: int = field(default=4, init=False)
1730 f: int = 4
1731
1732 calls = []
1733 def setattr(self, name, value):
1734 calls.append((name, value))
1735
1736 C.__setattr__ = setattr
1737 c = C(0, 1)
1738 self.assertEqual(('a', 0), calls[0])
1739 self.assertEqual(('b', 1), calls[1])
1740 self.assertEqual(('c', []), calls[2])
1741 self.assertEqual(('d', []), calls[3])
1742 self.assertNotIn(('e', 4), calls)
1743 self.assertEqual(('f', 4), calls[4])
1744
1745 def test_items_in_dicts(self):
1746 @dataclass
1747 class C:
1748 a: int
1749 b: list = field(default_factory=list, init=False)
1750 c: list = field(default_factory=list)
1751 d: int = field(default=4, init=False)
1752 e: int = 0
1753
1754 c = C(0)
1755 # Class dict
1756 self.assertNotIn('a', C.__dict__)
1757 self.assertNotIn('b', C.__dict__)
1758 self.assertNotIn('c', C.__dict__)
1759 self.assertIn('d', C.__dict__)
1760 self.assertEqual(C.d, 4)
1761 self.assertIn('e', C.__dict__)
1762 self.assertEqual(C.e, 0)
1763 # Instance dict
1764 self.assertIn('a', c.__dict__)
1765 self.assertEqual(c.a, 0)
1766 self.assertIn('b', c.__dict__)
1767 self.assertEqual(c.b, [])
1768 self.assertIn('c', c.__dict__)
1769 self.assertEqual(c.c, [])
1770 self.assertNotIn('d', c.__dict__)
1771 self.assertIn('e', c.__dict__)
1772 self.assertEqual(c.e, 0)
1773
1774 def test_alternate_classmethod_constructor(self):
1775 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001776 # alternate constructor. This is mostly an example to show
1777 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001778 @dataclass
1779 class C:
1780 x: int
1781 @classmethod
1782 def from_file(cls, filename):
1783 # In a real example, create a new instance
1784 # and populate 'x' from contents of a file.
1785 value_in_file = 20
1786 return cls(value_in_file)
1787
1788 self.assertEqual(C.from_file('filename').x, 20)
1789
1790 def test_field_metadata_default(self):
1791 # Make sure the default metadata is read-only and of
1792 # zero length.
1793 @dataclass
1794 class C:
1795 i: int
1796
1797 self.assertFalse(fields(C)[0].metadata)
1798 self.assertEqual(len(fields(C)[0].metadata), 0)
1799 with self.assertRaisesRegex(TypeError,
1800 'does not support item assignment'):
1801 fields(C)[0].metadata['test'] = 3
1802
1803 def test_field_metadata_mapping(self):
1804 # Make sure only a mapping can be passed as metadata
1805 # zero length.
1806 with self.assertRaises(TypeError):
1807 @dataclass
1808 class C:
1809 i: int = field(metadata=0)
1810
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001811 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001812 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001813 @dataclass
1814 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001815 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001816 self.assertFalse(fields(C)[0].metadata)
1817 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001818 # Update should work (see bpo-35960).
1819 d['foo'] = 1
1820 self.assertEqual(len(fields(C)[0].metadata), 1)
1821 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001822 with self.assertRaisesRegex(TypeError,
1823 'does not support item assignment'):
1824 fields(C)[0].metadata['test'] = 3
1825
1826 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001827 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001828 @dataclass
1829 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001830 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001831 self.assertEqual(len(fields(C)[0].metadata), 3)
1832 self.assertEqual(fields(C)[0].metadata['test'], 10)
1833 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1834 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001835 # Update should work.
1836 d['foo'] = 1
1837 self.assertEqual(len(fields(C)[0].metadata), 4)
1838 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001839 with self.assertRaises(KeyError):
1840 # Non-existent key.
1841 fields(C)[0].metadata['baz']
1842 with self.assertRaisesRegex(TypeError,
1843 'does not support item assignment'):
1844 fields(C)[0].metadata['test'] = 3
1845
1846 def test_field_metadata_custom_mapping(self):
1847 # Try a custom mapping.
1848 class SimpleNameSpace:
1849 def __init__(self, **kw):
1850 self.__dict__.update(kw)
1851
1852 def __getitem__(self, item):
1853 if item == 'xyzzy':
1854 return 'plugh'
1855 return getattr(self, item)
1856
1857 def __len__(self):
1858 return self.__dict__.__len__()
1859
1860 @dataclass
1861 class C:
1862 i: int = field(metadata=SimpleNameSpace(a=10))
1863
1864 self.assertEqual(len(fields(C)[0].metadata), 1)
1865 self.assertEqual(fields(C)[0].metadata['a'], 10)
1866 with self.assertRaises(AttributeError):
1867 fields(C)[0].metadata['b']
1868 # Make sure we're still talking to our custom mapping.
1869 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1870
1871 def test_generic_dataclasses(self):
1872 T = TypeVar('T')
1873
1874 @dataclass
1875 class LabeledBox(Generic[T]):
1876 content: T
1877 label: str = '<unknown>'
1878
1879 box = LabeledBox(42)
1880 self.assertEqual(box.content, 42)
1881 self.assertEqual(box.label, '<unknown>')
1882
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001883 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001884 Alias = List[LabeledBox[int]]
1885
1886 def test_generic_extending(self):
1887 S = TypeVar('S')
1888 T = TypeVar('T')
1889
1890 @dataclass
1891 class Base(Generic[T, S]):
1892 x: T
1893 y: S
1894
1895 @dataclass
1896 class DataDerived(Base[int, T]):
1897 new_field: str
1898 Alias = DataDerived[str]
1899 c = Alias(0, 'test1', 'test2')
1900 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1901
1902 class NonDataDerived(Base[int, T]):
1903 def new_method(self):
1904 return self.y
1905 Alias = NonDataDerived[float]
1906 c = Alias(10, 1.0)
1907 self.assertEqual(c.new_method(), 1.0)
1908
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001909 def test_generic_dynamic(self):
1910 T = TypeVar('T')
1911
1912 @dataclass
1913 class Parent(Generic[T]):
1914 x: T
1915 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1916 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1917 self.assertIs(Child[int](1, 2).z, None)
1918 self.assertEqual(Child[int](1, 2, 3).z, 3)
1919 self.assertEqual(Child[int](1, 2, 3).other, 42)
1920 # Check that type aliases work correctly.
1921 Alias = Child[T]
1922 self.assertEqual(Alias[int](1, 2).x, 1)
1923 # Check MRO resolution.
1924 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1925
Miss Islington (bot)e086bfe2021-10-09 12:50:45 -07001926 def test_dataclasses_pickleable(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001927 global P, Q, R
1928 @dataclass
1929 class P:
1930 x: int
1931 y: int = 0
1932 @dataclass
1933 class Q:
1934 x: int
1935 y: int = field(default=0, init=False)
1936 @dataclass
1937 class R:
1938 x: int
1939 y: List[int] = field(default_factory=list)
1940 q = Q(1)
1941 q.y = 2
1942 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1943 for sample in samples:
1944 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1945 with self.subTest(sample=sample, proto=proto):
1946 new_sample = pickle.loads(pickle.dumps(sample, proto))
1947 self.assertEqual(sample.x, new_sample.x)
1948 self.assertEqual(sample.y, new_sample.y)
1949 self.assertIsNot(sample, new_sample)
1950 new_sample.x = 42
1951 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1952 self.assertEqual(new_sample.x, another_new_sample.x)
1953 self.assertEqual(sample.y, another_new_sample.y)
1954
Batuhan Taskayac7437e22020-10-21 16:49:22 +03001955 def test_dataclasses_qualnames(self):
1956 @dataclass(order=True, unsafe_hash=True, frozen=True)
1957 class A:
1958 x: int
1959 y: int
1960
1961 self.assertEqual(A.__init__.__name__, "__init__")
1962 for function in (
1963 '__eq__',
1964 '__lt__',
1965 '__le__',
1966 '__gt__',
1967 '__ge__',
1968 '__hash__',
1969 '__init__',
1970 '__repr__',
1971 '__setattr__',
1972 '__delattr__',
1973 ):
1974 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
1975
1976 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
1977 A()
1978
Eric V. Smithea8fc522018-01-27 19:07:40 -05001979
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() to a name that is not annotated in the class being
    # decorated is an error, whatever the base classes annotate.
    _NO_ANNOTATION = "'f' is a field but has no type annotation"

    def test_field_without_annotation(self):
        with self.assertRaisesRegex(TypeError, self._NO_ANNOTATION):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: the annotation in the base class must not be
        # picked up.
        with self.assertRaisesRegex(TypeError, self._NO_ANNOTATION):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, but the base class is not a dataclass.
        class B:
            f: int

        with self.assertRaisesRegex(TypeError, self._NO_ANNOTATION):
            @dataclass
            class C(B):
                f = field()
2013
2014
class TestDocString(unittest.TestCase):
    # A dataclass without its own docstring gets one built from its
    # signature; these tests pin that generated text.

    def assertDocStrEqual(self, a, b):
        # inspect.signature() formatting differs between 3.6 and 3.7 (see
        # bpo #32108), so for now compare with every space removed.
        no_spaces = str.maketrans('', '', ' ')
        self.assertEqual(a.translate(no_spaces), b.translate(no_spaces))

    def test_existing_docstring_not_overridden(self):
        # A user-written docstring is kept verbatim.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        # Plain defaults appear with their repr.
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # Factory defaults are shown as the <factory> placeholder.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2102
2103
Eric V. Smithea8fc522018-01-27 19:07:40 -05002104class TestInit(unittest.TestCase):
2105 def test_base_has_init(self):
2106 class B:
2107 def __init__(self):
2108 self.z = 100
2109 pass
2110
2111 # Make sure that declaring this class doesn't raise an error.
2112 # The issue is that we can't override __init__ in our class,
2113 # but it should be okay to add __init__ to us if our base has
2114 # an __init__.
2115 @dataclass
2116 class C(B):
2117 x: int = 0
2118 c = C(10)
2119 self.assertEqual(c.x, 10)
2120 self.assertNotIn('z', vars(c))
2121
2122 # Make sure that if we don't add an init, the base __init__
2123 # gets called.
2124 @dataclass(init=False)
2125 class C(B):
2126 x: int = 10
2127 c = C()
2128 self.assertEqual(c.x, 10)
2129 self.assertEqual(c.z, 100)
2130
2131 def test_no_init(self):
2132 dataclass(init=False)
2133 class C:
2134 i: int = 0
2135 self.assertEqual(C().i, 0)
2136
2137 dataclass(init=False)
2138 class C:
2139 i: int = 2
2140 def __init__(self):
2141 self.i = 3
2142 self.assertEqual(C().i, 3)
2143
2144 def test_overwriting_init(self):
2145 # If the class has __init__, use it no matter the value of
2146 # init=.
2147
2148 @dataclass
2149 class C:
2150 x: int
2151 def __init__(self, x):
2152 self.x = 2 * x
2153 self.assertEqual(C(3).x, 6)
2154
2155 @dataclass(init=True)
2156 class C:
2157 x: int
2158 def __init__(self, x):
2159 self.x = 2 * x
2160 self.assertEqual(C(4).x, 8)
2161
2162 @dataclass(init=False)
2163 class C:
2164 x: int
2165 def __init__(self, x):
2166 self.x = 2 * x
2167 self.assertEqual(C(5).x, 10)
2168
Miss Islington (bot)79e9f5a2021-09-02 23:26:53 -07002169 def test_inherit_from_protocol(self):
2170 # Dataclasses inheriting from protocol should preserve their own `__init__`.
2171 # See bpo-45081.
2172
2173 class P(Protocol):
2174 a: int
2175
2176 @dataclass
2177 class C(P):
2178 a: int
2179
2180 self.assertEqual(C(5).a, 5)
2181
2182 @dataclass
2183 class D(P):
2184 def __init__(self, a):
2185 self.a = a * 2
2186
2187 self.assertEqual(D(5).a, 10)
2188
Eric V. Smithea8fc522018-01-27 19:07:40 -05002189
class TestRepr(unittest.TestCase):
    """The generated __repr__ and the repr= decorator flag."""

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        prefix = 'TestRepr.test_repr.<locals>.'
        self.assertEqual(repr(C(4)), prefix + 'C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), prefix + 'D(x=20, y=10)')

        # Nested dataclasses are shown by their qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), prefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), prefix + 'C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ of its own: object's repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # repr=False and a hand-written __repr__: that one is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A class's own __repr__ wins whatever repr= says (or if it is
        # not given at all).
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            @decorate
            class C:
                x: int
                def __repr__(self):
                    return 'x'
            self.assertEqual(repr(C(0)), 'x')
2259
2260
class TestEq(unittest.TestCase):
    """The generated __eq__ and the eq= decorator flag."""

    def test_no_eq(self):
        # eq=False and no __eq__ of its own: identity comparison.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False leaves a hand-written __eq__ in place.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class's own __eq__ wins whatever eq= says (or if it is not
        # given at all).

        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)
2306
2307
class TestOrdering(unittest.TestCase):
    """Ordering methods: order=, user overrides and total_ordering."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ is kept and is the one that gets called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIn('__lt__', C.__dict__)
        self.assertFalse(C(0) < C(1))
        # Make sure other methods aren't added.
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any hand-written comparison and
        # suggests functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2383
class TestHash(unittest.TestCase):
    """Rules deciding whether, and how, a dataclass gets a __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # with_hash is part of the subtest parameters so a failure
            # reports which of the two table columns was being checked.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen, with_hash=with_hash):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo=32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # I'm not sure what test to use here.  object's
                    # hash isn't based on id(), so calling hash()
                    # won't tell us much.  So, just check the
                    # function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2602
Eric V. Smithea8fc522018-01-27 19:07:40 -05002603
class TestFrozen(unittest.TestCase):
    """frozen=True: immutability, inheritance rules and hashing."""

    @staticmethod
    def _parent_of(cls, intermediate_class):
        # Either cls itself or an ordinary (non-dataclass) subclass of
        # it, so inheritance checks run both directly and through a
        # plain class in between.
        if not intermediate_class:
            return cls

        class I(cls):
            pass
        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # The rejected assignment left the value untouched.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        # Neither the inherited field nor the new one can be reassigned.
        for name, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, name, value)
        self.assertEqual((d.i, d.j), (0, 10))

    def test_inherit_nonfrozen_from_empty_frozen(self):
        @dataclass(frozen=True)
        class C:
            pass

        with self.assertRaisesRegex(TypeError,
                                    'cannot inherit non-frozen dataclass from a frozen one'):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertEqual(d.j, 3)
        self.assertIsInstance(d, C)

    def test_inherit_nonfrozen_from_frozen(self):
        for via_plain_class in (True, False):
            with self.subTest(intermediate_class=via_plain_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._parent_of(C, via_plain_class)

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for via_plain_class in (True, False):
            with self.subTest(intermediate_class=via_plain_class):
                @dataclass
                class C:
                    i: int

                parent = self._parent_of(C, via_plain_class)

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        for via_plain_class in (True, False):
            with self.subTest(intermediate_class=via_plain_class):
                class C:
                    pass

                parent = self._parent_of(C, via_plain_class)

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        # A plain subclass may grow new attributes ...
        s.cached = True

        # ... but the frozen fields it inherited stay frozen.
        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, name, 5)
        self.assertEqual((s.x, s.y), (3, 10))
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True generates __setattr__ and __delattr__, so a class
        # that defines either one itself is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False a hand-written __setattr__ is used, as the
        # doubled value after construction shows.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # An immutable field value hashes without raising ...
        hash(C(3))

        # ... whereas a mutable one makes hashing fail.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2776
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002777
Eric V. Smith7389fd92018-03-19 21:07:51 -04002778class TestSlots(unittest.TestCase):
2779 def test_simple(self):
2780 @dataclass
2781 class C:
2782 __slots__ = ('x',)
2783 x: Any
2784
Eric V. Smith2b75fc22018-03-25 20:37:33 -04002785 # There was a bug where a variable in a slot was assumed to
2786 # also have a default value (of type
2787 # types.MemberDescriptorType).
Eric V. Smith7389fd92018-03-19 21:07:51 -04002788 with self.assertRaisesRegex(TypeError,
Eric V. Smithc42e7aa2018-03-24 23:02:21 -04002789 r"__init__\(\) missing 1 required positional argument: 'x'"):
Eric V. Smith7389fd92018-03-19 21:07:51 -04002790 C()
2791
2792 # We can create an instance, and assign to x.
2793 c = C(10)
2794 self.assertEqual(c.x, 10)
2795 c.x = 5
2796 self.assertEqual(c.x, 5)
2797
2798 # We can't assign to anything else.
2799 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
2800 c.y = 5
2801
2802 def test_derived_added_field(self):
2803 # See bpo-33100.
2804 @dataclass
2805 class Base:
2806 __slots__ = ('x',)
2807 x: Any
2808
2809 @dataclass
2810 class Derived(Base):
2811 x: int
2812 y: int
2813
2814 d = Derived(1, 2)
2815 self.assertEqual((d.x, d.y), (1, 2))
2816
2817 # We can add a new field to the derived instance.
2818 d.z = 10
2819
Yurii Karabasc2419912021-05-01 05:14:30 +03002820 def test_generated_slots(self):
2821 @dataclass(slots=True)
2822 class C:
2823 x: int
2824 y: int
2825
2826 c = C(1, 2)
2827 self.assertEqual((c.x, c.y), (1, 2))
2828
2829 c.x = 3
2830 c.y = 4
2831 self.assertEqual((c.x, c.y), (3, 4))
2832
2833 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
2834 c.z = 5
2835
2836 def test_add_slots_when_slots_exists(self):
2837 with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
2838 @dataclass(slots=True)
2839 class C:
2840 __slots__ = ('x',)
2841 x: int
2842
2843 def test_generated_slots_value(self):
2844 @dataclass(slots=True)
2845 class Base:
2846 x: int
2847
2848 self.assertEqual(Base.__slots__, ('x',))
2849
2850 @dataclass(slots=True)
2851 class Delivered(Base):
2852 y: int
2853
2854 self.assertEqual(Delivered.__slots__, ('x', 'y'))
2855
2856 @dataclass
2857 class AnotherDelivered(Base):
2858 z: int
2859
2860 self.assertTrue('__slots__' not in AnotherDelivered.__dict__)
2861
2862 def test_returns_new_class(self):
2863 class A:
2864 x: int
2865
2866 B = dataclass(A, slots=True)
2867 self.assertIsNot(A, B)
2868
2869 self.assertFalse(hasattr(A, "__slots__"))
2870 self.assertTrue(hasattr(B, "__slots__"))
2871
Eric V. Smith823fbf42021-05-01 13:27:30 -04002872 # Can't be local to test_frozen_pickle.
2873 @dataclass(frozen=True, slots=True)
2874 class FrozenSlotsClass:
2875 foo: str
2876 bar: int
2877
Miss Islington (bot)36971fd2021-10-24 06:29:37 -07002878 @dataclass(frozen=True)
2879 class FrozenWithoutSlotsClass:
2880 foo: str
2881 bar: int
2882
Eric V. Smith823fbf42021-05-01 13:27:30 -04002883 def test_frozen_pickle(self):
2884 # bpo-43999
2885
Miss Islington (bot)36971fd2021-10-24 06:29:37 -07002886 self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
2887 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2888 with self.subTest(proto=proto):
2889 obj = self.FrozenSlotsClass("a", 1)
2890 p = pickle.loads(pickle.dumps(obj, protocol=proto))
2891 self.assertIsNot(obj, p)
2892 self.assertEqual(obj, p)
Eric V. Smith823fbf42021-05-01 13:27:30 -04002893
Miss Islington (bot)36971fd2021-10-24 06:29:37 -07002894 obj = self.FrozenWithoutSlotsClass("a", 1)
2895 p = pickle.loads(pickle.dumps(obj, protocol=proto))
2896 self.assertIsNot(obj, p)
2897 self.assertEqual(obj, p)
Yurii Karabasc2419912021-05-01 05:14:30 +03002898
Miss Islington (bot)10343bd2021-11-22 05:47:41 -08002899 def test_slots_with_default_no_init(self):
2900 # Originally reported in bpo-44649.
2901 @dataclass(slots=True)
2902 class A:
2903 a: str
2904 b: str = field(default='b', init=False)
2905
2906 obj = A("a")
2907 self.assertEqual(obj.a, 'a')
2908 self.assertEqual(obj.b, 'b')
2909
2910 def test_slots_with_default_factory_no_init(self):
2911 # Originally reported in bpo-44649.
2912 @dataclass(slots=True)
2913 class A:
2914 a: str
2915 b: str = field(default_factory=lambda:'b', init=False)
2916
2917 obj = A("a")
2918 self.assertEqual(obj.a, 'a')
2919 self.assertEqual(obj.b, 'b')
2920
class TestDescriptors(unittest.TestCase):
    """__set_name__ on field defaults, with and without field()."""

    def test_set_name(self):
        # See bpo-33141.

        class D:
            # Remembers the name it was bound under (with an 'x'
            # suffix) and reads as 1 when accessed through an instance.
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Ordinary descriptor behaviour: no dataclass code is needed to
        # get __set_name__ called for a plain class-level default.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Wrapped in field() with init=False, the only case where this
        # really matters: with init=True the instance attribute set by
        # __init__ would hide the descriptor anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors, so
        # this class deliberately has no __get__.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        default_value = D()
        # The hook lives on the instance only, not on its type.
        default_value.__set_name__ = Mock()

        # An instance-level __set_name__ must not be called.
        @dataclass
        class C:
            i: int=field(default=default_value, init=False)

        self.assertEqual(default_value.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        D.__set_name__ = Mock()

        # Defined on the type, __set_name__ is called exactly once.
        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002991
Eric V. Smith7389fd92018-03-19 21:07:51 -04002992
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002993class TestStringAnnotations(unittest.TestCase):
2994 def test_classvar(self):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002995 # Some expressions recognized as ClassVar really aren't. But
2996 # if you're using string annotations, it's not an exact
2997 # science.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002998 # These tests assume that both "import typing" and "from
2999 # typing import *" have been run in this file.
3000 for typestr in ('ClassVar[int]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03003001 'ClassVar [int]',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003002 ' ClassVar [int]',
3003 'ClassVar',
3004 ' ClassVar ',
3005 'typing.ClassVar[int]',
3006 'typing.ClassVar[str]',
3007 ' typing.ClassVar[str]',
3008 'typing .ClassVar[str]',
3009 'typing. ClassVar[str]',
3010 'typing.ClassVar [str]',
3011 'typing.ClassVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01003012
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003013 # Not syntactically valid, but these will
Pablo Galindob0544ba2021-04-21 12:41:19 +01003014 # be treated as ClassVars.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003015 'typing.ClassVar.[int]',
3016 'typing.ClassVar+',
3017 ):
3018 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003019 @dataclass
3020 class C:
3021 x: typestr
3022
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003023 # x is a ClassVar, so C() takes no args.
3024 C()
3025
3026 # And it won't appear in the class's dict because it doesn't
3027 # have a default.
3028 self.assertNotIn('x', C.__dict__)
3029
3030 def test_isnt_classvar(self):
3031 for typestr in ('CV',
3032 't.ClassVar',
3033 't.ClassVar[int]',
3034 'typing..ClassVar[int]',
3035 'Classvar',
3036 'Classvar[int]',
3037 'typing.ClassVarx[int]',
3038 'typong.ClassVar[int]',
3039 'dataclasses.ClassVar[int]',
3040 'typingxClassVar[str]',
3041 ):
3042 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003043 @dataclass
3044 class C:
3045 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003046
3047 # x is not a ClassVar, so C() takes one arg.
3048 self.assertEqual(C(10).x, 10)
3049
3050 def test_initvar(self):
3051 # These tests assume that both "import dataclasses" and "from
3052 # dataclasses import *" have been run in this file.
3053 for typestr in ('InitVar[int]',
3054 'InitVar [int]'
3055 ' InitVar [int]',
3056 'InitVar',
3057 ' InitVar ',
3058 'dataclasses.InitVar[int]',
3059 'dataclasses.InitVar[str]',
3060 ' dataclasses.InitVar[str]',
3061 'dataclasses .InitVar[str]',
3062 'dataclasses. InitVar[str]',
3063 'dataclasses.InitVar [str]',
3064 'dataclasses.InitVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01003065
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003066 # Not syntactically valid, but these will
3067 # be treated as InitVars.
3068 'dataclasses.InitVar.[int]',
3069 'dataclasses.InitVar+',
3070 ):
3071 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003072 @dataclass
3073 class C:
3074 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003075
3076 # x is an InitVar, so doesn't create a member.
3077 with self.assertRaisesRegex(AttributeError,
3078 "object has no attribute 'x'"):
3079 C(1).x
3080
3081 def test_isnt_initvar(self):
3082 for typestr in ('IV',
3083 'dc.InitVar',
3084 'xdataclasses.xInitVar',
3085 'typing.xInitVar[int]',
3086 ):
3087 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003088 @dataclass
3089 class C:
3090 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003091
3092 # x is not an InitVar, so there will be a member x.
3093 self.assertEqual(C(10).x, 10)
3094
3095 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03003096 from test import dataclass_module_1
Pablo Galindob0544ba2021-04-21 12:41:19 +01003097 from test import dataclass_module_1_str
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03003098 from test import dataclass_module_2
Pablo Galindob0544ba2021-04-21 12:41:19 +01003099 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003100
Pablo Galindob0544ba2021-04-21 12:41:19 +01003101 for m in (dataclass_module_1, dataclass_module_1_str,
3102 dataclass_module_2, dataclass_module_2_str,
3103 ):
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003104 with self.subTest(m=m):
3105 # There's a difference in how the ClassVars are
3106 # interpreted when using string annotations or
3107 # not. See the imported modules for details.
Pablo Galindob0544ba2021-04-21 12:41:19 +01003108 if m.USING_STRINGS:
3109 c = m.CV(10)
3110 else:
3111 c = m.CV()
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003112 self.assertEqual(c.cv0, 20)
3113
3114
3115 # There's a difference in how the InitVars are
3116 # interpreted when using string annotations or
3117 # not. See the imported modules for details.
3118 c = m.IV(0, 1, 2, 3, 4)
3119
3120 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
3121 with self.subTest(field_name=field_name):
3122 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
3123 # Since field_name is an InitVar, it's
3124 # not an instance field.
3125 getattr(c, field_name)
3126
Pablo Galindob0544ba2021-04-21 12:41:19 +01003127 if m.USING_STRINGS:
3128 # iv4 is interpreted as a normal field.
3129 self.assertIn('not_iv4', c.__dict__)
3130 self.assertEqual(c.not_iv4, 4)
3131 else:
3132 # iv4 is interpreted as an InitVar, so it
3133 # won't exist on the instance.
3134 self.assertNotIn('not_iv4', c.__dict__)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003135
Yury Selivanovd219cc42019-12-09 09:54:20 -05003136 def test_text_annotations(self):
3137 from test import dataclass_textanno
3138
3139 self.assertEqual(
3140 get_type_hints(dataclass_textanno.Bar),
3141 {'foo': dataclass_textanno.Foo})
3142 self.assertEqual(
3143 get_type_hints(dataclass_textanno.Bar.__init__),
3144 {'foo': dataclass_textanno.Foo,
3145 'return': type(None)})
3146
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003147
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), the functional dataclass constructor."""

    def test_simple(self):
        namespace = {'add_one': lambda self: self.x + 1}
        C = make_dataclass('C',
                           [('x', int), ('y', int, field(default=5))],
                           namespace=namespace)
        instance = C(10)
        self.assertEqual((instance.x, instance.y), (10, 5))
        self.assertEqual(instance.add_one(), 11)

    def test_no_mutate_namespace(self):
        # The caller's namespace dict must come back untouched.
        supplied = {}
        make_dataclass('C',
                       [('x', int), ('y', int, field(default=5))],
                       namespace=supplied)
        self.assertEqual(supplied, {})

    def test_base(self):
        class Base1:
            pass

        class Base2:
            pass

        C = make_dataclass('C', [('x', int)], bases=(Base1, Base2))
        instance = C(2)
        for expected in (C, Base1, Base2):
            self.assertIsInstance(instance, expected)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int

        class Base2:
            pass

        C = make_dataclass('C', [('y', int)], bases=(Base1, Base2))
        # Base1 contributes field x, so both x and y are required.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        instance = C(1, 2)
        for expected in (C, Base1, Base2):
            self.assertIsInstance(instance, expected)
        self.assertEqual((instance.x, instance.y), (1, 2))

    def test_init_var(self):
        def scale_x(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int), ('y', InitVar[int])],
                           namespace={'__post_init__': scale_x})
        instance = C(2, 3)
        # y only feeds __post_init__: it is neither stored nor a field.
        self.assertEqual(vars(instance), {'x': 6})
        self.assertEqual(len(fields(instance)), 1)

    def test_class_var(self):
        spec = [('x', int),
                ('y', ClassVar[int], 10),
                ('z', ClassVar[int], field(default=20))]
        C = make_dataclass('C', spec)
        instance = C(1)
        # Only x is an instance field; the ClassVars live on the class.
        self.assertEqual(vars(instance), {'x': 1})
        self.assertEqual(len(fields(instance)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        spec = [('x', int),
                ('y', ClassVar[int], 10),
                ('z', ClassVar[int], field(default=20))]
        C = make_dataclass('C', spec, init=False)
        # init=False suppresses __init__ but a __repr__ is still generated.
        class_dict = vars(C)
        self.assertNotIn('__init__', class_dict)
        self.assertIn('__repr__', class_dict)

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C', [], xxinit=False)

    def test_no_types(self):
        # Fields given as bare names are annotated with the string 'typing.Any'.
        any_annotation = 'typing.Any'

        C = make_dataclass('Point', ['x', 'y', 'z'])
        self.assertEqual(vars(C(1, 2, 3)), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, dict.fromkeys('xyz', any_annotation))

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        self.assertEqual(vars(C(1, 2, 3)), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': any_annotation,
                                             'y': int,
                                             'z': any_annotation})

    def test_invalid_type_specification(self):
        cases = [
            # Tuples of the wrong length.
            ((), r'Invalid field: '),
            ((1, 2, 3, 4), r'Invalid field: '),
            # Things with no len() at all.
            (float, r'has no len\(\)'),
            (lambda x: x, r'has no len\(\)'),
        ]
        for bad_field, pattern in cases:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, pattern):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                # The keyword is rejected wherever it appears in the list.
                for spec in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                        make_dataclass('C', spec)

    def test_non_identifier_field_names(self):
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                for spec in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                        make_dataclass('C', spec)

    def test_underscore_field_names(self):
        # Unlike namedtuple, dataclass field names may use underscores freely.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # types.new_class accepts arbitrary class names, so make_dataclass
        # does not restrict them either.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                generated = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(generated.__name__, classname)
3311
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The recursive-repr tests also live here. Their expected strings embed
    this class's name, via the __qualname__ of the locally defined classes.
    """

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # Fields declared with init=False can't be passed to replace(),
        # whether or not other fields are replaced in the same call. Each
        # call gets its own context manager: a second statement inside one
        # block would never run once the first statement raised.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist. Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            c1 = replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            c1 = replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # Neither the class itself nor a non-dataclass object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # An InitVar without a default must be passed explicitly. Escape
        # the parentheses: a bare "()" in a regex is an empty group and
        # would not check them at all.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_initvar_with_default_value(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int] = None
            z: InitVar[int] = 42

            def __post_init__(self, y, z):
                if y is not None:
                    self.x += y
                if z is not None:
                    self.x += z

        c = C(x=1, y=10, z=1)
        self.assertEqual(replace(c), C(x=12))
        self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
        self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3539
Ben Avrahamibef7d292020-10-06 20:40:50 +03003540class TestAbstract(unittest.TestCase):
3541 def test_abc_implementation(self):
3542 class Ordered(abc.ABC):
3543 @abc.abstractmethod
3544 def __lt__(self, other):
3545 pass
3546
3547 @abc.abstractmethod
3548 def __le__(self, other):
3549 pass
3550
3551 @dataclass(order=True)
3552 class Date(Ordered):
3553 year: int
3554 month: 'Month'
3555 day: 'int'
3556
3557 self.assertFalse(inspect.isabstract(Date))
3558 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3559
3560 def test_maintain_abc(self):
3561 class A(abc.ABC):
3562 @abc.abstractmethod
3563 def foo(self):
3564 pass
3565
3566 @dataclass
3567 class Date(A):
3568 year: int
3569 month: 'Month'
3570 day: 'int'
3571
3572 self.assertTrue(inspect.isabstract(Date))
3573 msg = 'class Date with abstract method foo'
3574 self.assertRaisesRegex(TypeError, msg, Date)
3575
Eric V. Smith4e812962018-05-16 11:31:29 -04003576
Brandt Bucher145bf262021-02-26 14:51:55 -08003577class TestMatchArgs(unittest.TestCase):
3578 def test_match_args(self):
3579 @dataclass
3580 class C:
3581 a: int
3582 self.assertEqual(C(42).__match_args__, ('a',))
3583
3584 def test_explicit_match_args(self):
Brandt Bucherf84d5a12021-04-05 19:17:08 -07003585 ma = ()
Brandt Bucher145bf262021-02-26 14:51:55 -08003586 @dataclass
3587 class C:
3588 a: int
3589 __match_args__ = ma
3590 self.assertIs(C(42).__match_args__, ma)
3591
Brandt Bucherd92c59f2021-04-08 12:54:34 -07003592 def test_bpo_43764(self):
3593 @dataclass(repr=False, eq=False, init=False)
3594 class X:
3595 a: int
3596 b: int
3597 c: int
3598 self.assertEqual(X.__match_args__, ("a", "b", "c"))
3599
Eric V. Smith750f4842021-04-10 21:28:42 -04003600 def test_match_args_argument(self):
3601 @dataclass(match_args=False)
3602 class X:
3603 a: int
3604 self.assertNotIn('__match_args__', X.__dict__)
3605
3606 @dataclass(match_args=False)
3607 class Y:
3608 a: int
3609 __match_args__ = ('b',)
3610 self.assertEqual(Y.__match_args__, ('b',))
3611
3612 @dataclass(match_args=False)
3613 class Z(Y):
3614 z: int
3615 self.assertEqual(Z.__match_args__, ('b',))
3616
3617 # Ensure parent dataclass __match_args__ is seen, if child class
3618 # specifies match_args=False.
3619 @dataclass
3620 class A:
3621 a: int
3622 z: int
3623 @dataclass(match_args=False)
3624 class B(A):
3625 b: int
3626 self.assertEqual(B.__match_args__, ('a', 'z'))
3627
3628 def test_make_dataclasses(self):
3629 C = make_dataclass('C', [('x', int), ('y', int)])
3630 self.assertEqual(C.__match_args__, ('x', 'y'))
3631
3632 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
3633 self.assertEqual(C.__match_args__, ('x', 'y'))
3634
3635 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
3636 self.assertNotIn('__match__args__', C.__dict__)
3637
3638 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
3639 self.assertEqual(C.__match_args__, ('z',))
3640
Brandt Bucher145bf262021-02-26 14:51:55 -08003641
class TestKeywordArgs(unittest.TestCase):
    """Tests for keyword-only fields: kw_only=, field(kw_only=) and KW_ONLY."""

    def test_no_classvar_kwarg(self):
        # Specifying kw_only on a ClassVar field is an error, whether it
        # is True or False.
        msg = 'field a is a ClassVar but specifies kw_only'
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=True)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=False)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass(kw_only=True)
            class A:
                a: ClassVar[int] = field(kw_only=False)

    def test_field_marked_as_kwonly(self):
        # An explicit field(kw_only=...) decides the field's kw_only flag.
        # Otherwise it follows dataclass(kw_only=...), which is False when
        # not given.

        #######################
        # Using dataclass(kw_only=True)
        @dataclass(kw_only=True)
        class A:
            a: int
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Using dataclass(kw_only=False)
        @dataclass(kw_only=False)
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Not specifying dataclass(kw_only)
        @dataclass
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

    def test_match_args(self):
        # kw fields don't show up in __match_args__.
        @dataclass(kw_only=True)
        class C:
            a: int
        self.assertEqual(C(a=42).__match_args__, ())

        @dataclass
        class C:
            a: int
            b: int = field(kw_only=True)
        self.assertEqual(C(42, b=10).__match_args__, ('a',))

    def test_KW_ONLY(self):
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

        @dataclass(kw_only=True)
        class B:
            a: int
            _: KW_ONLY
            b: int
            c: int
        B(a=3, b=4, c=5)
        msg = "takes 1 positional argument but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            B(3, 4, 5)

        # Explicitly make a field that follows KW_ONLY be non-keyword-only.
        @dataclass
        class C:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=False)

        # c is positional (it follows a in __init__) and b is keyword-only.
        # Each distinct way of passing the arguments must give the same
        # instance state.
        for c in (C(1, 2, b=3),
                  C(1, b=3, c=2),
                  C(c=2, b=3, a=1)):
            with self.subTest(c=c):
                self.assertEqual(c.a, 1)
                self.assertEqual(c.b, 3)
                self.assertEqual(c.c, 2)

    def test_KW_ONLY_as_string(self):
        @dataclass
        class A:
            a: int
            _: 'dataclasses.KW_ONLY'
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

    def test_KW_ONLY_twice(self):
        msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                Y: KW_ONLY
                b: int
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                Y: KW_ONLY
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                c: int
                Y: KW_ONLY

        # But this usage is okay, since it's not using KW_ONLY.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=True)

        # And if inheriting, it's okay.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        @dataclass
        class B(A):
            _: KW_ONLY
            d: int

        # Make sure the error is raised in a derived class.
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                _: KW_ONLY
                b: int
                c: int
            @dataclass
            class B(A):
                X: KW_ONLY
                d: int
                Y: KW_ONLY

    def test_post_init(self):
        # Keyword-only InitVars are passed to __post_init__.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                raise CustomError(f'{b=} {d=}')
        with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
            A(1, c=2, b=3, d=4)

        @dataclass
        class B:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                self.a = b
                self.c = d
        b = B(1, c=2, b=3, d=4)
        self.assertEqual(asdict(b), {'a': 3, 'c': 4})

    def test_defaults(self):
        # For kwargs, make sure we can have defaults after non-defaults.
        @dataclass
        class A:
            a: int = 0
            _: KW_ONLY
            b: int
            c: int = 1
            d: int

        a = A(d=4, b=3)
        self.assertEqual(a.a, 0)
        self.assertEqual(a.b, 3)
        self.assertEqual(a.c, 1)
        self.assertEqual(a.d, 4)

        # Make sure we still check for non-kwarg non-defaults not following
        # defaults.
        err_regex = "non-default argument 'z' follows default argument"
        with self.assertRaisesRegex(TypeError, err_regex):
            @dataclass
            class A:
                a: int = 0
                z: int
                _: KW_ONLY
                b: int
                c: int = 1
                d: int

    def test_make_dataclass(self):
        A = make_dataclass("A", ['a'], kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        B = make_dataclass("B",
                           ['a', ('b', int, field(kw_only=False))],
                           kw_only=True)
        self.assertTrue(fields(B)[0].kw_only)
        self.assertFalse(fields(B)[1].kw_only)
3914
3915
if __name__ == "__main__":
    # Run this module's tests when the file is executed as a script.
    unittest.main()