blob: 47075df8d59f3b56cf3768348c9ed9243eb50f1a [file] [log] [blame]
# Deliberately use "from dataclasses import *". Every name in __all__
# is tested, so they all must be present. This is a way to catch
# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050011import unittest
12from unittest.mock import Mock
Miss Islington (bot)79e9f5a2021-09-02 23:26:53 -070013from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional, Protocol
Yury Selivanovd219cc42019-12-09 09:54:20 -050014from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050016from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050017
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040018import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
19import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
20
# Just any custom exception we can catch.
class CustomError(Exception):
    """Exception raised on purpose inside test classes (e.g. from
    __post_init__) so a test can assert that the code path ran."""

24class TestCase(unittest.TestCase):
25 def test_no_fields(self):
26 @dataclass
27 class C:
28 pass
29
30 o = C()
31 self.assertEqual(len(fields(C)), 0)
32
Eric V. Smith56970b82018-03-22 16:28:48 -040033 def test_no_fields_but_member_variable(self):
34 @dataclass
35 class C:
36 i = 0
37
38 o = C()
39 self.assertEqual(len(fields(C)), 0)
40
Eric V. Smithf0db54a2017-12-04 16:58:55 -050041 def test_one_field_no_default(self):
42 @dataclass
43 class C:
44 x: int
45
46 o = C(42)
47 self.assertEqual(o.x, 42)
48
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053049 def test_field_default_default_factory_error(self):
50 msg = "cannot specify both default and default_factory"
51 with self.assertRaisesRegex(ValueError, msg):
52 @dataclass
53 class C:
54 x: int = field(default=1, default_factory=int)
55
56 def test_field_repr(self):
57 int_field = field(default=1, init=True, repr=False)
58 int_field.name = "id"
59 repr_output = repr(int_field)
60 expected_output = "Field(name='id',type=None," \
61 f"default=1,default_factory={MISSING!r}," \
62 "init=True,repr=False,hash=None," \
63 "compare=True,metadata=mappingproxy({})," \
Eric V. Smithc0280532021-04-25 20:42:39 -040064 f"kw_only={MISSING!r}," \
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053065 "_field_type=None)"
66
67 self.assertEqual(repr_output, expected_output)
68
Eric V. Smithf0db54a2017-12-04 16:58:55 -050069 def test_named_init_params(self):
70 @dataclass
71 class C:
72 x: int
73
74 o = C(x=32)
75 self.assertEqual(o.x, 32)
76
77 def test_two_fields_one_default(self):
78 @dataclass
79 class C:
80 x: int
81 y: int = 0
82
83 o = C(3)
84 self.assertEqual((o.x, o.y), (3, 0))
85
86 # Non-defaults following defaults.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class C:
92 x: int = 0
93 y: int
94
95 # A derived class adds a non-default field after a default one.
96 with self.assertRaisesRegex(TypeError,
97 "non-default argument 'y' follows "
98 "default argument"):
99 @dataclass
100 class B:
101 x: int = 0
102
103 @dataclass
104 class C(B):
105 y: int
106
107 # Override a base class field and add a default to
108 # a field which didn't use to have a default.
109 with self.assertRaisesRegex(TypeError,
110 "non-default argument 'y' follows "
111 "default argument"):
112 @dataclass
113 class B:
114 x: int
115 y: int
116
117 @dataclass
118 class C(B):
119 x: int = 0
120
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500121 def test_overwrite_hash(self):
122 # Test that declaring this class isn't an error. It should
123 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500124 @dataclass(frozen=True)
125 class C:
126 x: int
127 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500128 return 301
129 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500130
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500131 # Test that declaring this class isn't an error. It should
132 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500133 @dataclass(frozen=True)
134 class C:
135 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500136 def __eq__(self, other):
137 return False
138 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500139
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500140 # But this one should generate an exception, because with
141 # unsafe_hash=True, it's an error to have a __hash__ defined.
142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 def __hash__(self):
147 pass
148
149 # Creating this class should not generate an exception,
150 # because even though __hash__ exists before @dataclass is
151 # called, (due to __eq__ being defined), since it's None
152 # that's okay.
153 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500154 class C:
155 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500156 def __eq__(self):
157 pass
158 # The generated hash function works as we'd expect.
159 self.assertEqual(hash(C(10)), hash((10,)))
160
161 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400162 # __hash__ exists and is not None, which it would be if it
163 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500164 with self.assertRaisesRegex(TypeError,
165 'Cannot overwrite attribute __hash__'):
166 @dataclass(unsafe_hash=True)
167 class C:
168 x: int
169 def __eq__(self):
170 pass
171 def __hash__(self):
172 pass
173
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500174 def test_overwrite_fields_in_derived_class(self):
175 # Note that x from C1 replaces x in Base, but the order remains
176 # the same as defined in Base.
177 @dataclass
178 class Base:
179 x: Any = 15.0
180 y: int = 0
181
182 @dataclass
183 class C1(Base):
184 z: int = 10
185 x: int = 15
186
187 o = Base()
188 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
189
190 o = C1()
191 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
192
193 o = C1(x=5)
194 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
195
196 def test_field_named_self(self):
197 @dataclass
198 class C:
199 self: str
200 c=C('foo')
201 self.assertEqual(c.self, 'foo')
202
203 # Make sure the first parameter is not named 'self'.
204 sig = inspect.signature(C.__init__)
205 first = next(iter(sig.parameters))
206 self.assertNotEqual('self', first)
207
208 # But we do use 'self' if no field named self.
209 @dataclass
210 class C:
211 selfx: str
212
213 # Make sure the first parameter is named 'self'.
214 sig = inspect.signature(C.__init__)
215 first = next(iter(sig.parameters))
216 self.assertEqual('self', first)
217
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300218 def test_field_named_object(self):
219 @dataclass
220 class C:
221 object: str
222 c = C('foo')
223 self.assertEqual(c.object, 'foo')
224
225 def test_field_named_object_frozen(self):
226 @dataclass(frozen=True)
227 class C:
228 object: str
229 c = C('foo')
230 self.assertEqual(c.object, 'foo')
231
232 def test_field_named_like_builtin(self):
233 # Attribute names can shadow built-in names
234 # since code generation is used.
235 # Ensure that this is not happening.
236 exclusions = {'None', 'True', 'False'}
237 builtins_names = sorted(
238 b for b in builtins.__dict__.keys()
239 if not b.startswith('__') and b not in exclusions
240 )
241 attributes = [(name, str) for name in builtins_names]
242 C = make_dataclass('C', attributes)
243
244 c = C(*[name for name in builtins_names])
245
246 for name in builtins_names:
247 self.assertEqual(getattr(c, name), name)
248
249 def test_field_named_like_builtin_frozen(self):
250 # Attribute names can shadow built-in names
251 # since code generation is used.
252 # Ensure that this is not happening
253 # for frozen data classes.
254 exclusions = {'None', 'True', 'False'}
255 builtins_names = sorted(
256 b for b in builtins.__dict__.keys()
257 if not b.startswith('__') and b not in exclusions
258 )
259 attributes = [(name, str) for name in builtins_names]
260 C = make_dataclass('C', attributes, frozen=True)
261
262 c = C(*[name for name in builtins_names])
263
264 for name in builtins_names:
265 self.assertEqual(getattr(c, name), name)
266
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500267 def test_0_field_compare(self):
268 # Ensure that order=False is the default.
269 @dataclass
270 class C0:
271 pass
272
273 @dataclass(order=False)
274 class C1:
275 pass
276
277 for cls in [C0, C1]:
278 with self.subTest(cls=cls):
279 self.assertEqual(cls(), cls())
280 for idx, fn in enumerate([lambda a, b: a < b,
281 lambda a, b: a <= b,
282 lambda a, b: a > b,
283 lambda a, b: a >= b]):
284 with self.subTest(idx=idx):
285 with self.assertRaisesRegex(TypeError,
286 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
287 fn(cls(), cls())
288
289 @dataclass(order=True)
290 class C:
291 pass
292 self.assertLessEqual(C(), C())
293 self.assertGreaterEqual(C(), C())
294
295 def test_1_field_compare(self):
296 # Ensure that order=False is the default.
297 @dataclass
298 class C0:
299 x: int
300
301 @dataclass(order=False)
302 class C1:
303 x: int
304
305 for cls in [C0, C1]:
306 with self.subTest(cls=cls):
307 self.assertEqual(cls(1), cls(1))
308 self.assertNotEqual(cls(0), cls(1))
309 for idx, fn in enumerate([lambda a, b: a < b,
310 lambda a, b: a <= b,
311 lambda a, b: a > b,
312 lambda a, b: a >= b]):
313 with self.subTest(idx=idx):
314 with self.assertRaisesRegex(TypeError,
315 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
316 fn(cls(0), cls(0))
317
318 @dataclass(order=True)
319 class C:
320 x: int
321 self.assertLess(C(0), C(1))
322 self.assertLessEqual(C(0), C(1))
323 self.assertLessEqual(C(1), C(1))
324 self.assertGreater(C(1), C(0))
325 self.assertGreaterEqual(C(1), C(0))
326 self.assertGreaterEqual(C(1), C(1))
327
328 def test_simple_compare(self):
329 # Ensure that order=False is the default.
330 @dataclass
331 class C0:
332 x: int
333 y: int
334
335 @dataclass(order=False)
336 class C1:
337 x: int
338 y: int
339
340 for cls in [C0, C1]:
341 with self.subTest(cls=cls):
342 self.assertEqual(cls(0, 0), cls(0, 0))
343 self.assertEqual(cls(1, 2), cls(1, 2))
344 self.assertNotEqual(cls(1, 0), cls(0, 0))
345 self.assertNotEqual(cls(1, 0), cls(1, 1))
346 for idx, fn in enumerate([lambda a, b: a < b,
347 lambda a, b: a <= b,
348 lambda a, b: a > b,
349 lambda a, b: a >= b]):
350 with self.subTest(idx=idx):
351 with self.assertRaisesRegex(TypeError,
352 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
353 fn(cls(0, 0), cls(0, 0))
354
355 @dataclass(order=True)
356 class C:
357 x: int
358 y: int
359
360 for idx, fn in enumerate([lambda a, b: a == b,
361 lambda a, b: a <= b,
362 lambda a, b: a >= b]):
363 with self.subTest(idx=idx):
364 self.assertTrue(fn(C(0, 0), C(0, 0)))
365
366 for idx, fn in enumerate([lambda a, b: a < b,
367 lambda a, b: a <= b,
368 lambda a, b: a != b]):
369 with self.subTest(idx=idx):
370 self.assertTrue(fn(C(0, 0), C(0, 1)))
371 self.assertTrue(fn(C(0, 1), C(1, 0)))
372 self.assertTrue(fn(C(1, 0), C(1, 1)))
373
374 for idx, fn in enumerate([lambda a, b: a > b,
375 lambda a, b: a >= b,
376 lambda a, b: a != b]):
377 with self.subTest(idx=idx):
378 self.assertTrue(fn(C(0, 1), C(0, 0)))
379 self.assertTrue(fn(C(1, 0), C(0, 1)))
380 self.assertTrue(fn(C(1, 1), C(1, 0)))
381
382 def test_compare_subclasses(self):
383 # Comparisons fail for subclasses, even if no fields
384 # are added.
385 @dataclass
386 class B:
387 i: int
388
389 @dataclass
390 class C(B):
391 pass
392
393 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
394 (lambda a, b: a != b, True)]):
395 with self.subTest(idx=idx):
396 self.assertEqual(fn(B(0), C(0)), expected)
397
398 for idx, fn in enumerate([lambda a, b: a < b,
399 lambda a, b: a <= b,
400 lambda a, b: a > b,
401 lambda a, b: a >= b]):
402 with self.subTest(idx=idx):
403 with self.assertRaisesRegex(TypeError,
404 "not supported between instances of 'B' and 'C'"):
405 fn(B(0), C(0))
406
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500407 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500408 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500409 for (eq, order, result ) in [
410 (False, False, 'neither'),
411 (False, True, 'exception'),
412 (True, False, 'eq_only'),
413 (True, True, 'both'),
414 ]:
415 with self.subTest(eq=eq, order=order):
416 if result == 'exception':
417 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
418 @dataclass(eq=eq, order=order)
419 class C:
420 pass
421 else:
422 @dataclass(eq=eq, order=order)
423 class C:
424 pass
425
426 if result == 'neither':
427 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500428 self.assertNotIn('__lt__', C.__dict__)
429 self.assertNotIn('__le__', C.__dict__)
430 self.assertNotIn('__gt__', C.__dict__)
431 self.assertNotIn('__ge__', C.__dict__)
432 elif result == 'both':
433 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 self.assertIn('__lt__', C.__dict__)
435 self.assertIn('__le__', C.__dict__)
436 self.assertIn('__gt__', C.__dict__)
437 self.assertIn('__ge__', C.__dict__)
438 elif result == 'eq_only':
439 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500440 self.assertNotIn('__lt__', C.__dict__)
441 self.assertNotIn('__le__', C.__dict__)
442 self.assertNotIn('__gt__', C.__dict__)
443 self.assertNotIn('__ge__', C.__dict__)
444 else:
445 assert False, f'unknown result {result!r}'
446
447 def test_field_no_default(self):
448 @dataclass
449 class C:
450 x: int = field()
451
452 self.assertEqual(C(5).x, 5)
453
454 with self.assertRaisesRegex(TypeError,
455 r"__init__\(\) missing 1 required "
456 "positional argument: 'x'"):
457 C()
458
459 def test_field_default(self):
460 default = object()
461 @dataclass
462 class C:
463 x: object = field(default=default)
464
465 self.assertIs(C.x, default)
466 c = C(10)
467 self.assertEqual(c.x, 10)
468
469 # If we delete the instance attribute, we should then see the
470 # class attribute.
471 del c.x
472 self.assertIs(c.x, default)
473
474 self.assertIs(C().x, default)
475
476 def test_not_in_repr(self):
477 @dataclass
478 class C:
479 x: int = field(repr=False)
480 with self.assertRaises(TypeError):
481 C()
482 c = C(10)
483 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
484
485 @dataclass
486 class C:
487 x: int = field(repr=False)
488 y: int
489 c = C(10, 20)
490 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
491
492 def test_not_in_compare(self):
493 @dataclass
494 class C:
495 x: int = 0
496 y: int = field(compare=False, default=4)
497
498 self.assertEqual(C(), C(0, 20))
499 self.assertEqual(C(1, 10), C(1, 20))
500 self.assertNotEqual(C(3), C(4, 10))
501 self.assertNotEqual(C(3, 10), C(4, 10))
502
503 def test_hash_field_rules(self):
504 # Test all 6 cases of:
505 # hash=True/False/None
506 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500507 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500508 (True, False, 'field' ),
509 (True, True, 'field' ),
510 (False, False, 'absent'),
511 (False, True, 'absent'),
512 (None, False, 'absent'),
513 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500514 ]:
515 with self.subTest(hash=hash_, compare=compare):
516 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500517 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500518 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500519
520 if result == 'field':
521 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500522 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500523 elif result == 'absent':
524 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500525 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500526 else:
527 assert False, f'unknown result {result!r}'
528
529 def test_init_false_no_default(self):
530 # If init=False and no default value, then the field won't be
531 # present in the instance.
532 @dataclass
533 class C:
534 x: int = field(init=False)
535
536 self.assertNotIn('x', C().__dict__)
537
538 @dataclass
539 class C:
540 x: int
541 y: int = 0
542 z: int = field(init=False)
543 t: int = 10
544
545 self.assertNotIn('z', C(0).__dict__)
546 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
547
548 def test_class_marker(self):
549 @dataclass
550 class C:
551 x: int
552 y: str = field(init=False, default=None)
553 z: str = field(repr=False)
554
555 the_fields = fields(C)
556 # the_fields is a tuple of 3 items, each value
557 # is in __annotations__.
558 self.assertIsInstance(the_fields, tuple)
559 for f in the_fields:
560 self.assertIs(type(f), Field)
561 self.assertIn(f.name, C.__annotations__)
562
563 self.assertEqual(len(the_fields), 3)
564
565 self.assertEqual(the_fields[0].name, 'x')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100566 self.assertEqual(the_fields[0].type, int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500567 self.assertFalse(hasattr(C, 'x'))
568 self.assertTrue (the_fields[0].init)
569 self.assertTrue (the_fields[0].repr)
570 self.assertEqual(the_fields[1].name, 'y')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100571 self.assertEqual(the_fields[1].type, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500572 self.assertIsNone(getattr(C, 'y'))
573 self.assertFalse(the_fields[1].init)
574 self.assertTrue (the_fields[1].repr)
575 self.assertEqual(the_fields[2].name, 'z')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100576 self.assertEqual(the_fields[2].type, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500577 self.assertFalse(hasattr(C, 'z'))
578 self.assertTrue (the_fields[2].init)
579 self.assertFalse(the_fields[2].repr)
580
581 def test_field_order(self):
582 @dataclass
583 class B:
584 a: str = 'B:a'
585 b: str = 'B:b'
586 c: str = 'B:c'
587
588 @dataclass
589 class C(B):
590 b: str = 'C:b'
591
592 self.assertEqual([(f.name, f.default) for f in fields(C)],
593 [('a', 'B:a'),
594 ('b', 'C:b'),
595 ('c', 'B:c')])
596
597 @dataclass
598 class D(B):
599 c: str = 'D:c'
600
601 self.assertEqual([(f.name, f.default) for f in fields(D)],
602 [('a', 'B:a'),
603 ('b', 'B:b'),
604 ('c', 'D:c')])
605
606 @dataclass
607 class E(D):
608 a: str = 'E:a'
609 d: str = 'E:d'
610
611 self.assertEqual([(f.name, f.default) for f in fields(E)],
612 [('a', 'E:a'),
613 ('b', 'B:b'),
614 ('c', 'D:c'),
615 ('d', 'E:d')])
616
617 def test_class_attrs(self):
618 # We only have a class attribute if a default value is
619 # specified, either directly or via a field with a default.
620 default = object()
621 @dataclass
622 class C:
623 x: int
624 y: int = field(repr=False)
625 z: object = default
626 t: int = field(default=100)
627
628 self.assertFalse(hasattr(C, 'x'))
629 self.assertFalse(hasattr(C, 'y'))
630 self.assertIs (C.z, default)
631 self.assertEqual(C.t, 100)
632
633 def test_disallowed_mutable_defaults(self):
634 # For the known types, don't allow mutable default values.
635 for typ, empty, non_empty in [(list, [], [1]),
636 (dict, {}, {0:1}),
637 (set, set(), set([1])),
638 ]:
639 with self.subTest(typ=typ):
640 # Can't use a zero-length value.
641 with self.assertRaisesRegex(ValueError,
642 f'mutable default {typ} for field '
643 'x is not allowed'):
644 @dataclass
645 class Point:
646 x: typ = empty
647
648
649 # Nor a non-zero-length value
650 with self.assertRaisesRegex(ValueError,
651 f'mutable default {typ} for field '
652 'y is not allowed'):
653 @dataclass
654 class Point:
655 y: typ = non_empty
656
657 # Check subtypes also fail.
658 class Subclass(typ): pass
659
660 with self.assertRaisesRegex(ValueError,
661 f"mutable default .*Subclass'>"
662 ' for field z is not allowed'
663 ):
664 @dataclass
665 class Point:
666 z: typ = Subclass()
667
668 # Because this is a ClassVar, it can be mutable.
669 @dataclass
670 class C:
671 z: ClassVar[typ] = typ()
672
673 # Because this is a ClassVar, it can be mutable.
674 @dataclass
675 class C:
676 x: ClassVar[typ] = Subclass()
677
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500678 def test_deliberately_mutable_defaults(self):
679 # If a mutable default isn't in the known list of
680 # (list, dict, set), then it's okay.
681 class Mutable:
682 def __init__(self):
683 self.l = []
684
685 @dataclass
686 class C:
687 x: Mutable
688
689 # These 2 instances will share this value of x.
690 lst = Mutable()
691 o1 = C(lst)
692 o2 = C(lst)
693 self.assertEqual(o1, o2)
694 o1.x.l.extend([1, 2])
695 self.assertEqual(o1, o2)
696 self.assertEqual(o1.x.l, [1, 2])
697 self.assertIs(o1.x, o2.x)
698
699 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400700 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500701 @dataclass()
702 class C:
703 x: int
704
705 self.assertEqual(C(42).x, 42)
706
707 def test_not_tuple(self):
708 # Make sure we can't be compared to a tuple.
709 @dataclass
710 class Point:
711 x: int
712 y: int
713 self.assertNotEqual(Point(1, 2), (1, 2))
714
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400715 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500716 @dataclass
717 class C:
718 x: int
719 y: int
720 self.assertNotEqual(Point(1, 3), C(1, 3))
721
Windson yangbe372d72019-04-23 02:45:34 +0800722 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # Test that some of the problems with namedtuple don't happen
724 # here.
725 @dataclass
726 class Point3D:
727 x: int
728 y: int
729 z: int
730
731 @dataclass
732 class Date:
733 year: int
734 month: int
735 day: int
736
737 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
738 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
739
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400740 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200741 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500742 x, y, z = Point3D(4, 5, 6)
743
Eric V. Smith7c99e932018-01-28 19:18:55 -0500744 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500745 # equal.
746 @dataclass
747 class Point3Dv1:
748 x: int = 0
749 y: int = 0
750 z: int = 0
751 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
752
753 def test_function_annotations(self):
754 # Some dummy class and instance to use as a default.
755 class F:
756 pass
757 f = F()
758
759 def validate_class(cls):
760 # First, check __annotations__, even though they're not
761 # function annotations.
Pablo Galindob0544ba2021-04-21 12:41:19 +0100762 self.assertEqual(cls.__annotations__['i'], int)
763 self.assertEqual(cls.__annotations__['j'], str)
764 self.assertEqual(cls.__annotations__['k'], F)
765 self.assertEqual(cls.__annotations__['l'], float)
766 self.assertEqual(cls.__annotations__['z'], complex)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500767
768 # Verify __init__.
769
770 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400771 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500772 self.assertIs(signature.return_annotation, None)
773
774 # Check each parameter.
775 params = iter(signature.parameters.values())
776 param = next(params)
777 # This is testing an internal name, and probably shouldn't be tested.
778 self.assertEqual(param.name, 'self')
779 param = next(params)
780 self.assertEqual(param.name, 'i')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100781 self.assertIs (param.annotation, int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500782 self.assertEqual(param.default, inspect.Parameter.empty)
783 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
784 param = next(params)
785 self.assertEqual(param.name, 'j')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100786 self.assertIs (param.annotation, str)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500787 self.assertEqual(param.default, inspect.Parameter.empty)
788 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
789 param = next(params)
790 self.assertEqual(param.name, 'k')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100791 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400792 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500793 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
794 param = next(params)
795 self.assertEqual(param.name, 'l')
Pablo Galindob0544ba2021-04-21 12:41:19 +0100796 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400797 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500798 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
799 self.assertRaises(StopIteration, next, params)
800
801
802 @dataclass
803 class C:
804 i: int
805 j: str
806 k: F = f
807 l: float=field(default=None)
808 z: complex=field(default=3+4j, init=False)
809
810 validate_class(C)
811
812 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500813 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500814 class C:
815 i: int
816 j: str
817 k: F = f
818 l: float=field(default=None)
819 z: complex=field(default=3+4j, init=False)
820
821 validate_class(C)
822
Eric V. Smith03220fd2017-12-29 13:59:58 -0500823 def test_missing_default(self):
824 # Test that MISSING works the same as a default not being
825 # specified.
826 @dataclass
827 class C:
828 x: int=field(default=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_default_factory(self):
845 # Test that MISSING works the same as a default factory not
846 # being specified (which is really the same as a default not
847 # being specified, too).
848 @dataclass
849 class C:
850 x: int=field(default_factory=MISSING)
851 with self.assertRaisesRegex(TypeError,
852 r'__init__\(\) missing 1 required '
853 'positional argument'):
854 C()
855 self.assertNotIn('x', C.__dict__)
856
857 @dataclass
858 class D:
859 x: int=field(default=MISSING, default_factory=MISSING)
860 with self.assertRaisesRegex(TypeError,
861 r'__init__\(\) missing 1 required '
862 'positional argument'):
863 D()
864 self.assertNotIn('x', D.__dict__)
865
866 def test_missing_repr(self):
867 self.assertIn('MISSING_TYPE object', repr(MISSING))
868
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500869 def test_dont_include_other_annotations(self):
870 @dataclass
871 class C:
872 i: int
873 def foo(self) -> int:
874 return 4
875 @property
876 def bar(self) -> int:
877 return 5
878 self.assertEqual(list(C.__annotations__), ['i'])
879 self.assertEqual(C(10).foo(), 4)
880 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400881 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500882
def test_post_init(self):
    # Just make sure it gets called
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # post-init gets called, but doesn't raise. This is just
    #  checking that self is used correctly.
    C(5)

    # If there's not an __init__, then post-init won't get called.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    # Creating the class won't raise
    C()

    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # Make sure that if we're frozen, post-init can't set
    #  attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()

def test_post_init_super(self):
    # Make sure super() post-init isn't called by default.
    class B:
        def __post_init__(self):
            raise CustomError()

    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(C().x, 5)

    # Now call super(), and it will raise.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # Make sure post-init is called, even if not defined in our
    #  class.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()

961 def test_post_init_staticmethod(self):
962 flag = False
963 @dataclass
964 class C:
965 x: int
966 y: int
967 @staticmethod
968 def __post_init__():
969 nonlocal flag
970 flag = True
971
972 self.assertFalse(flag)
973 c = C(3, 4)
974 self.assertEqual((c.x, c.y), (3, 4))
975 self.assertTrue(flag)
976
977 def test_post_init_classmethod(self):
978 @dataclass
979 class C:
980 flag = False
981 x: int
982 y: int
983 @classmethod
984 def __post_init__(cls):
985 cls.flag = True
986
987 self.assertFalse(C.flag)
988 c = C(3, 4)
989 self.assertEqual((c.x, c.y), (3, 4))
990 self.assertTrue(C.flag)
991
992 def test_class_var(self):
993 # Make sure ClassVars are ignored in __init__, __repr__, etc.
994 @dataclass
995 class C:
996 x: int
997 y: int = 10
998 z: ClassVar[int] = 1000
999 w: ClassVar[int] = 2000
1000 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001001 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001002
1003 c = C(5)
1004 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001005 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001006 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001007 self.assertEqual(c.z, 1000)
1008 self.assertEqual(c.w, 2000)
1009 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001010 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001011 C.z += 1
1012 self.assertEqual(c.z, 1001)
1013 c = C(20)
1014 self.assertEqual((c.x, c.y), (20, 10))
1015 self.assertEqual(c.z, 1001)
1016 self.assertEqual(c.w, 2000)
1017 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001018 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001019
1020 def test_class_var_no_default(self):
1021 # If a ClassVar has no default value, it should not be set on the class.
1022 @dataclass
1023 class C:
1024 x: ClassVar[int]
1025
1026 self.assertNotIn('x', C.__dict__)
1027
1028 def test_class_var_default_factory(self):
1029 # It makes no sense for a ClassVar to have a default factory. When
1030 # would it be called? Call it yourself, since it's class-wide.
1031 with self.assertRaisesRegex(TypeError,
1032 'cannot have a default factory'):
1033 @dataclass
1034 class C:
1035 x: ClassVar[int] = field(default_factory=int)
1036
1037 self.assertNotIn('x', C.__dict__)
1038
1039 def test_class_var_with_default(self):
1040 # If a ClassVar has a default value, it should be set on the class.
1041 @dataclass
1042 class C:
1043 x: ClassVar[int] = 10
1044 self.assertEqual(C.x, 10)
1045
1046 @dataclass
1047 class C:
1048 x: ClassVar[int] = field(default=10)
1049 self.assertEqual(C.x, 10)
1050
1051 def test_class_var_frozen(self):
1052 # Make sure ClassVars work even if we're frozen.
1053 @dataclass(frozen=True)
1054 class C:
1055 x: int
1056 y: int = 10
1057 z: ClassVar[int] = 1000
1058 w: ClassVar[int] = 2000
1059 t: ClassVar[int] = 3000
1060
1061 c = C(5)
1062 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1063 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1064 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1065 self.assertEqual(c.z, 1000)
1066 self.assertEqual(c.w, 2000)
1067 self.assertEqual(c.t, 3000)
1068 # We can still modify the ClassVar, it's only instances that are
1069 # frozen.
1070 C.z += 1
1071 self.assertEqual(c.z, 1001)
1072 c = C(20)
1073 self.assertEqual((c.x, c.y), (20, 10))
1074 self.assertEqual(c.z, 1001)
1075 self.assertEqual(c.w, 2000)
1076 self.assertEqual(c.t, 3000)
1077
1078 def test_init_var_no_default(self):
1079 # If an InitVar has no default value, it should not be set on the class.
1080 @dataclass
1081 class C:
1082 x: InitVar[int]
1083
1084 self.assertNotIn('x', C.__dict__)
1085
1086 def test_init_var_default_factory(self):
1087 # It makes no sense for an InitVar to have a default factory. When
1088 # would it be called? Call it yourself, since it's class-wide.
1089 with self.assertRaisesRegex(TypeError,
1090 'cannot have a default factory'):
1091 @dataclass
1092 class C:
1093 x: InitVar[int] = field(default_factory=int)
1094
1095 self.assertNotIn('x', C.__dict__)
1096
1097 def test_init_var_with_default(self):
1098 # If an InitVar has a default value, it should be set on the class.
1099 @dataclass
1100 class C:
1101 x: InitVar[int] = 10
1102 self.assertEqual(C.x, 10)
1103
1104 @dataclass
1105 class C:
1106 x: InitVar[int] = field(default=10)
1107 self.assertEqual(C.x, 10)
1108
1109 def test_init_var(self):
1110 @dataclass
1111 class C:
1112 x: int = None
1113 init_param: InitVar[int] = None
1114
1115 def __post_init__(self, init_param):
1116 if self.x is None:
1117 self.x = init_param*2
1118
1119 c = C(init_param=10)
1120 self.assertEqual(c.x, 20)
1121
Augusto Hack01ee12b2019-06-02 23:14:48 -03001122 def test_init_var_preserve_type(self):
1123 self.assertEqual(InitVar[int].type, int)
1124
1125 # Make sure the repr is correct.
1126 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001127 self.assertEqual(repr(InitVar[List[int]]),
1128 'dataclasses.InitVar[typing.List[int]]')
Miss Islington (bot)f1dd5ed2021-12-05 13:02:47 -08001129 self.assertEqual(repr(InitVar[list[int]]),
1130 'dataclasses.InitVar[list[int]]')
1131 self.assertEqual(repr(InitVar[int|str]),
1132 'dataclasses.InitVar[int | str]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001133
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001134 def test_init_var_inheritance(self):
1135 # Note that this deliberately tests that a dataclass need not
1136 # have a __post_init__ function if it has an InitVar field.
1137 # It could just be used in a derived class, as shown here.
1138 @dataclass
1139 class Base:
1140 x: int
1141 init_base: InitVar[int]
1142
1143 # We can instantiate by passing the InitVar, even though
1144 # it's not used.
1145 b = Base(0, 10)
1146 self.assertEqual(vars(b), {'x': 0})
1147
1148 @dataclass
1149 class C(Base):
1150 y: int
1151 init_derived: InitVar[int]
1152
1153 def __post_init__(self, init_base, init_derived):
1154 self.x = self.x + init_base
1155 self.y = self.y + init_derived
1156
1157 c = C(10, 11, 50, 51)
1158 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1159
1160 def test_default_factory(self):
1161 # Test a factory that returns a new list.
1162 @dataclass
1163 class C:
1164 x: int
1165 y: list = field(default_factory=list)
1166
1167 c0 = C(3)
1168 c1 = C(3)
1169 self.assertEqual(c0.x, 3)
1170 self.assertEqual(c0.y, [])
1171 self.assertEqual(c0, c1)
1172 self.assertIsNot(c0.y, c1.y)
1173 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1174
1175 # Test a factory that returns a shared list.
1176 l = []
1177 @dataclass
1178 class C:
1179 x: int
1180 y: list = field(default_factory=lambda: l)
1181
1182 c0 = C(3)
1183 c1 = C(3)
1184 self.assertEqual(c0.x, 3)
1185 self.assertEqual(c0.y, [])
1186 self.assertEqual(c0, c1)
1187 self.assertIs(c0.y, c1.y)
1188 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1189
1190 # Test various other field flags.
1191 # repr
1192 @dataclass
1193 class C:
1194 x: list = field(default_factory=list, repr=False)
1195 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1196 self.assertEqual(C().x, [])
1197
1198 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001199 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001200 class C:
1201 x: list = field(default_factory=list, hash=False)
1202 self.assertEqual(astuple(C()), ([],))
1203 self.assertEqual(hash(C()), hash(()))
1204
1205 # init (see also test_default_factory_with_no_init)
1206 @dataclass
1207 class C:
1208 x: list = field(default_factory=list, init=False)
1209 self.assertEqual(astuple(C()), ([],))
1210
1211 # compare
1212 @dataclass
1213 class C:
1214 x: list = field(default_factory=list, compare=False)
1215 self.assertEqual(C(), C([1]))
1216
1217 def test_default_factory_with_no_init(self):
1218 # We need a factory with a side effect.
1219 factory = Mock()
1220
1221 @dataclass
1222 class C:
1223 x: list = field(default_factory=factory, init=False)
1224
1225 # Make sure the default factory is called for each new instance.
1226 C().x
1227 self.assertEqual(factory.call_count, 1)
1228 C().x
1229 self.assertEqual(factory.call_count, 2)
1230
1231 def test_default_factory_not_called_if_value_given(self):
1232 # We need a factory that we can test if it's been called.
1233 factory = Mock()
1234
1235 @dataclass
1236 class C:
1237 x: int = field(default_factory=factory)
1238
1239 # Make sure that if a field has a default factory function,
1240 # it's not called if a value is specified.
1241 C().x
1242 self.assertEqual(factory.call_count, 1)
1243 self.assertEqual(C(10).x, 10)
1244 self.assertEqual(factory.call_count, 1)
1245 C().x
1246 self.assertEqual(factory.call_count, 2)
1247
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001248 def test_default_factory_derived(self):
1249 # See bpo-32896.
1250 @dataclass
1251 class Foo:
1252 x: dict = field(default_factory=dict)
1253
1254 @dataclass
1255 class Bar(Foo):
1256 y: int = 1
1257
1258 self.assertEqual(Foo().x, {})
1259 self.assertEqual(Bar().x, {})
1260 self.assertEqual(Bar().y, 1)
1261
1262 @dataclass
1263 class Baz(Foo):
1264 pass
1265 self.assertEqual(Baz().x, {})
1266
1267 def test_intermediate_non_dataclass(self):
1268 # Test that an intermediate class that defines
1269 # annotations does not define fields.
1270
1271 @dataclass
1272 class A:
1273 x: int
1274
1275 class B(A):
1276 y: int
1277
1278 @dataclass
1279 class C(B):
1280 z: int
1281
1282 c = C(1, 3)
1283 self.assertEqual((c.x, c.z), (1, 3))
1284
1285 # .y was not initialized.
1286 with self.assertRaisesRegex(AttributeError,
1287 'object has no attribute'):
1288 c.y
1289
1290 # And if we again derive a non-dataclass, no fields are added.
1291 class D(C):
1292 t: int
1293 d = D(4, 5)
1294 self.assertEqual((d.x, d.z), (4, 5))
1295
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001296 def test_classvar_default_factory(self):
1297 # It's an error for a ClassVar to have a factory function.
1298 with self.assertRaisesRegex(TypeError,
1299 'cannot have a default factory'):
1300 @dataclass
1301 class C:
1302 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001303
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001304 def test_is_dataclass(self):
1305 class NotDataClass:
1306 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001307
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001308 self.assertFalse(is_dataclass(0))
1309 self.assertFalse(is_dataclass(int))
1310 self.assertFalse(is_dataclass(NotDataClass))
1311 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001312
1313 @dataclass
1314 class C:
1315 x: int
1316
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001317 @dataclass
1318 class D:
1319 d: C
1320 e: int
1321
1322 c = C(10)
1323 d = D(c, 4)
1324
1325 self.assertTrue(is_dataclass(C))
1326 self.assertTrue(is_dataclass(c))
1327 self.assertFalse(is_dataclass(c.x))
1328 self.assertTrue(is_dataclass(d.d))
1329 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001330
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001331 def test_is_dataclass_when_getattr_always_returns(self):
1332 # See bpo-37868.
1333 class A:
1334 def __getattr__(self, key):
1335 return 0
1336 self.assertFalse(is_dataclass(A))
1337 a = A()
1338
1339 # Also test for an instance attribute.
1340 class B:
1341 pass
1342 b = B()
1343 b.__dataclass_fields__ = []
1344
1345 for obj in a, b:
1346 with self.subTest(obj=obj):
1347 self.assertFalse(is_dataclass(obj))
1348
1349 # Indirect tests for _is_dataclass_instance().
1350 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1351 asdict(obj)
1352 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1353 astuple(obj)
1354 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1355 replace(obj, x=0)
1356
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001357 def test_helper_fields_with_class_instance(self):
1358 # Check that we can call fields() on either a class or instance,
1359 # and get back the same thing.
1360 @dataclass
1361 class C:
1362 x: int
1363 y: float
1364
1365 self.assertEqual(fields(C), fields(C(0, 0.0)))
1366
1367 def test_helper_fields_exception(self):
1368 # Check that TypeError is raised if not passed a dataclass or
1369 # instance.
1370 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1371 fields(0)
1372
1373 class C: pass
1374 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1375 fields(C)
1376 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1377 fields(C())
1378
1379 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001380 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001381 @dataclass
1382 class C:
1383 x: int
1384 y: int
1385 c = C(1, 2)
1386
1387 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1388 self.assertEqual(asdict(c), asdict(c))
1389 self.assertIsNot(asdict(c), asdict(c))
1390 c.x = 42
1391 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1392 self.assertIs(type(asdict(c)), dict)
1393
1394 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001395 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001396 @dataclass
1397 class C:
1398 x: int
1399 y: int
1400 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1401 asdict(C)
1402 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1403 asdict(int)
1404
1405 def test_helper_asdict_copy_values(self):
1406 @dataclass
1407 class C:
1408 x: int
1409 y: List[int] = field(default_factory=list)
1410 initial = []
1411 c = C(1, initial)
1412 d = asdict(c)
1413 self.assertEqual(d['y'], initial)
1414 self.assertIsNot(d['y'], initial)
1415 c = C(1)
1416 d = asdict(c)
1417 d['y'].append(1)
1418 self.assertEqual(c.y, [])
1419
1420 def test_helper_asdict_nested(self):
1421 @dataclass
1422 class UserId:
1423 token: int
1424 group: int
1425 @dataclass
1426 class User:
1427 name: str
1428 id: UserId
1429 u = User('Joe', UserId(123, 1))
1430 d = asdict(u)
1431 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1432 self.assertIsNot(asdict(u), asdict(u))
1433 u.id.group = 2
1434 self.assertEqual(asdict(u), {'name': 'Joe',
1435 'id': {'token': 123, 'group': 2}})
1436
1437 def test_helper_asdict_builtin_containers(self):
1438 @dataclass
1439 class User:
1440 name: str
1441 id: int
1442 @dataclass
1443 class GroupList:
1444 id: int
1445 users: List[User]
1446 @dataclass
1447 class GroupTuple:
1448 id: int
1449 users: Tuple[User, ...]
1450 @dataclass
1451 class GroupDict:
1452 id: int
1453 users: Dict[str, User]
1454 a = User('Alice', 1)
1455 b = User('Bob', 2)
1456 gl = GroupList(0, [a, b])
1457 gt = GroupTuple(0, (a, b))
1458 gd = GroupDict(0, {'first': a, 'second': b})
1459 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1460 {'name': 'Bob', 'id': 2}]})
1461 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1462 {'name': 'Bob', 'id': 2})})
1463 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1464 'second': {'name': 'Bob', 'id': 2}}})
1465
Windson yangbe372d72019-04-23 02:45:34 +08001466 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001467 @dataclass
1468 class Child:
1469 d: object
1470
1471 @dataclass
1472 class Parent:
1473 child: Child
1474
1475 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1476 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1477
1478 def test_helper_asdict_factory(self):
1479 @dataclass
1480 class C:
1481 x: int
1482 y: int
1483 c = C(1, 2)
1484 d = asdict(c, dict_factory=OrderedDict)
1485 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1486 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1487 c.x = 42
1488 d = asdict(c, dict_factory=OrderedDict)
1489 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1490 self.assertIs(type(d), OrderedDict)
1491
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001492 def test_helper_asdict_namedtuple(self):
1493 T = namedtuple('T', 'a b c')
1494 @dataclass
1495 class C:
1496 x: str
1497 y: T
1498 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1499
1500 d = asdict(c)
1501 self.assertEqual(d, {'x': 'outer',
1502 'y': T(1,
1503 {'x': 'inner',
1504 'y': T(11, 12, 13)},
1505 2),
1506 }
1507 )
1508
1509 # Now with a dict_factory. OrderedDict is convenient, but
1510 # since it compares to dicts, we also need to have separate
1511 # assertIs tests.
1512 d = asdict(c, dict_factory=OrderedDict)
1513 self.assertEqual(d, {'x': 'outer',
1514 'y': T(1,
1515 {'x': 'inner',
1516 'y': T(11, 12, 13)},
1517 2),
1518 }
1519 )
1520
penguindustin96466302019-05-06 14:57:17 -04001521 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001522 self.assertIs(type(d), OrderedDict)
1523 self.assertIs(type(d['y'][1]), OrderedDict)
1524
1525 def test_helper_asdict_namedtuple_key(self):
1526 # Ensure that a field that contains a dict which has a
1527 # namedtuple as a key works with asdict().
1528
1529 @dataclass
1530 class C:
1531 f: dict
1532 T = namedtuple('T', 'a')
1533
1534 c = C({T('an a'): 0})
1535
1536 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1537
1538 def test_helper_asdict_namedtuple_derived(self):
1539 class T(namedtuple('Tbase', 'a')):
1540 def my_a(self):
1541 return self.a
1542
1543 @dataclass
1544 class C:
1545 f: T
1546
1547 t = T(6)
1548 c = C(t)
1549
1550 d = asdict(c)
1551 self.assertEqual(d, {'f': T(a=6)})
1552 # Make sure that t has been copied, not used directly.
1553 self.assertIsNot(d['f'], t)
1554 self.assertEqual(d['f'].my_a(), 6)
1555
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001556 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001557 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001558 @dataclass
1559 class C:
1560 x: int
1561 y: int = 0
1562 c = C(1)
1563
1564 self.assertEqual(astuple(c), (1, 0))
1565 self.assertEqual(astuple(c), astuple(c))
1566 self.assertIsNot(astuple(c), astuple(c))
1567 c.y = 42
1568 self.assertEqual(astuple(c), (1, 42))
1569 self.assertIs(type(astuple(c)), tuple)
1570
1571 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001572 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001573 @dataclass
1574 class C:
1575 x: int
1576 y: int
1577 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1578 astuple(C)
1579 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1580 astuple(int)
1581
1582 def test_helper_astuple_copy_values(self):
1583 @dataclass
1584 class C:
1585 x: int
1586 y: List[int] = field(default_factory=list)
1587 initial = []
1588 c = C(1, initial)
1589 t = astuple(c)
1590 self.assertEqual(t[1], initial)
1591 self.assertIsNot(t[1], initial)
1592 c = C(1)
1593 t = astuple(c)
1594 t[1].append(1)
1595 self.assertEqual(c.y, [])
1596
1597 def test_helper_astuple_nested(self):
1598 @dataclass
1599 class UserId:
1600 token: int
1601 group: int
1602 @dataclass
1603 class User:
1604 name: str
1605 id: UserId
1606 u = User('Joe', UserId(123, 1))
1607 t = astuple(u)
1608 self.assertEqual(t, ('Joe', (123, 1)))
1609 self.assertIsNot(astuple(u), astuple(u))
1610 u.id.group = 2
1611 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1612
1613 def test_helper_astuple_builtin_containers(self):
1614 @dataclass
1615 class User:
1616 name: str
1617 id: int
1618 @dataclass
1619 class GroupList:
1620 id: int
1621 users: List[User]
1622 @dataclass
1623 class GroupTuple:
1624 id: int
1625 users: Tuple[User, ...]
1626 @dataclass
1627 class GroupDict:
1628 id: int
1629 users: Dict[str, User]
1630 a = User('Alice', 1)
1631 b = User('Bob', 2)
1632 gl = GroupList(0, [a, b])
1633 gt = GroupTuple(0, (a, b))
1634 gd = GroupDict(0, {'first': a, 'second': b})
1635 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1636 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1637 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1638
Windson yangbe372d72019-04-23 02:45:34 +08001639 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001640 @dataclass
1641 class Child:
1642 d: object
1643
1644 @dataclass
1645 class Parent:
1646 child: Child
1647
1648 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1649 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1650
1651 def test_helper_astuple_factory(self):
1652 @dataclass
1653 class C:
1654 x: int
1655 y: int
1656 NT = namedtuple('NT', 'x y')
1657 def nt(lst):
1658 return NT(*lst)
1659 c = C(1, 2)
1660 t = astuple(c, tuple_factory=nt)
1661 self.assertEqual(t, NT(1, 2))
1662 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1663 c.x = 42
1664 t = astuple(c, tuple_factory=nt)
1665 self.assertEqual(t, NT(42, 2))
1666 self.assertIs(type(t), NT)
1667
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001668 def test_helper_astuple_namedtuple(self):
1669 T = namedtuple('T', 'a b c')
1670 @dataclass
1671 class C:
1672 x: str
1673 y: T
1674 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1675
1676 t = astuple(c)
1677 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1678
1679 # Now, using a tuple_factory. list is convenient here.
1680 t = astuple(c, tuple_factory=list)
1681 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1682
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001683 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001684 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001685 }
1686
1687 # Create the class.
1688 cls = type('C', (), cls_dict)
1689
1690 # Make it a dataclass.
1691 cls1 = dataclass(cls)
1692
1693 self.assertEqual(cls1, cls)
1694 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1695
1696 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001697 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001698 'y': field(default=5),
1699 }
1700
1701 # Create the class.
1702 cls = type('C', (), cls_dict)
1703
1704 # Make it a dataclass.
1705 cls1 = dataclass(cls)
1706
1707 self.assertEqual(cls1, cls)
1708 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1709
1710 def test_init_in_order(self):
1711 @dataclass
1712 class C:
1713 a: int
1714 b: int = field()
1715 c: list = field(default_factory=list, init=False)
1716 d: list = field(default_factory=list)
1717 e: int = field(default=4, init=False)
1718 f: int = 4
1719
1720 calls = []
1721 def setattr(self, name, value):
1722 calls.append((name, value))
1723
1724 C.__setattr__ = setattr
1725 c = C(0, 1)
1726 self.assertEqual(('a', 0), calls[0])
1727 self.assertEqual(('b', 1), calls[1])
1728 self.assertEqual(('c', []), calls[2])
1729 self.assertEqual(('d', []), calls[3])
1730 self.assertNotIn(('e', 4), calls)
1731 self.assertEqual(('f', 4), calls[4])
1732
1733 def test_items_in_dicts(self):
1734 @dataclass
1735 class C:
1736 a: int
1737 b: list = field(default_factory=list, init=False)
1738 c: list = field(default_factory=list)
1739 d: int = field(default=4, init=False)
1740 e: int = 0
1741
1742 c = C(0)
1743 # Class dict
1744 self.assertNotIn('a', C.__dict__)
1745 self.assertNotIn('b', C.__dict__)
1746 self.assertNotIn('c', C.__dict__)
1747 self.assertIn('d', C.__dict__)
1748 self.assertEqual(C.d, 4)
1749 self.assertIn('e', C.__dict__)
1750 self.assertEqual(C.e, 0)
1751 # Instance dict
1752 self.assertIn('a', c.__dict__)
1753 self.assertEqual(c.a, 0)
1754 self.assertIn('b', c.__dict__)
1755 self.assertEqual(c.b, [])
1756 self.assertIn('c', c.__dict__)
1757 self.assertEqual(c.c, [])
1758 self.assertNotIn('d', c.__dict__)
1759 self.assertIn('e', c.__dict__)
1760 self.assertEqual(c.e, 0)
1761
1762 def test_alternate_classmethod_constructor(self):
1763 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001764 # alternate constructor. This is mostly an example to show
1765 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001766 @dataclass
1767 class C:
1768 x: int
1769 @classmethod
1770 def from_file(cls, filename):
1771 # In a real example, create a new instance
1772 # and populate 'x' from contents of a file.
1773 value_in_file = 20
1774 return cls(value_in_file)
1775
1776 self.assertEqual(C.from_file('filename').x, 20)
1777
1778 def test_field_metadata_default(self):
1779 # Make sure the default metadata is read-only and of
1780 # zero length.
1781 @dataclass
1782 class C:
1783 i: int
1784
1785 self.assertFalse(fields(C)[0].metadata)
1786 self.assertEqual(len(fields(C)[0].metadata), 0)
1787 with self.assertRaisesRegex(TypeError,
1788 'does not support item assignment'):
1789 fields(C)[0].metadata['test'] = 3
1790
1791 def test_field_metadata_mapping(self):
1792 # Make sure only a mapping can be passed as metadata
1793 # zero length.
1794 with self.assertRaises(TypeError):
1795 @dataclass
1796 class C:
1797 i: int = field(metadata=0)
1798
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001799 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001800 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001801 @dataclass
1802 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001803 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001804 self.assertFalse(fields(C)[0].metadata)
1805 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001806 # Update should work (see bpo-35960).
1807 d['foo'] = 1
1808 self.assertEqual(len(fields(C)[0].metadata), 1)
1809 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001810 with self.assertRaisesRegex(TypeError,
1811 'does not support item assignment'):
1812 fields(C)[0].metadata['test'] = 3
1813
1814 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001815 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001816 @dataclass
1817 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001818 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001819 self.assertEqual(len(fields(C)[0].metadata), 3)
1820 self.assertEqual(fields(C)[0].metadata['test'], 10)
1821 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1822 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001823 # Update should work.
1824 d['foo'] = 1
1825 self.assertEqual(len(fields(C)[0].metadata), 4)
1826 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001827 with self.assertRaises(KeyError):
1828 # Non-existent key.
1829 fields(C)[0].metadata['baz']
1830 with self.assertRaisesRegex(TypeError,
1831 'does not support item assignment'):
1832 fields(C)[0].metadata['test'] = 3
1833
1834 def test_field_metadata_custom_mapping(self):
1835 # Try a custom mapping.
1836 class SimpleNameSpace:
1837 def __init__(self, **kw):
1838 self.__dict__.update(kw)
1839
1840 def __getitem__(self, item):
1841 if item == 'xyzzy':
1842 return 'plugh'
1843 return getattr(self, item)
1844
1845 def __len__(self):
1846 return self.__dict__.__len__()
1847
1848 @dataclass
1849 class C:
1850 i: int = field(metadata=SimpleNameSpace(a=10))
1851
1852 self.assertEqual(len(fields(C)[0].metadata), 1)
1853 self.assertEqual(fields(C)[0].metadata['a'], 10)
1854 with self.assertRaises(AttributeError):
1855 fields(C)[0].metadata['b']
1856 # Make sure we're still talking to our custom mapping.
1857 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1858
1859 def test_generic_dataclasses(self):
1860 T = TypeVar('T')
1861
1862 @dataclass
1863 class LabeledBox(Generic[T]):
1864 content: T
1865 label: str = '<unknown>'
1866
1867 box = LabeledBox(42)
1868 self.assertEqual(box.content, 42)
1869 self.assertEqual(box.label, '<unknown>')
1870
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001871 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001872 Alias = List[LabeledBox[int]]
1873
1874 def test_generic_extending(self):
1875 S = TypeVar('S')
1876 T = TypeVar('T')
1877
1878 @dataclass
1879 class Base(Generic[T, S]):
1880 x: T
1881 y: S
1882
1883 @dataclass
1884 class DataDerived(Base[int, T]):
1885 new_field: str
1886 Alias = DataDerived[str]
1887 c = Alias(0, 'test1', 'test2')
1888 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1889
1890 class NonDataDerived(Base[int, T]):
1891 def new_method(self):
1892 return self.y
1893 Alias = NonDataDerived[float]
1894 c = Alias(10, 1.0)
1895 self.assertEqual(c.new_method(), 1.0)
1896
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001897 def test_generic_dynamic(self):
1898 T = TypeVar('T')
1899
1900 @dataclass
1901 class Parent(Generic[T]):
1902 x: T
1903 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1904 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1905 self.assertIs(Child[int](1, 2).z, None)
1906 self.assertEqual(Child[int](1, 2, 3).z, 3)
1907 self.assertEqual(Child[int](1, 2, 3).other, 42)
1908 # Check that type aliases work correctly.
1909 Alias = Child[T]
1910 self.assertEqual(Alias[int](1, 2).x, 1)
1911 # Check MRO resolution.
1912 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1913
Miss Islington (bot)e086bfe2021-10-09 12:50:45 -07001914 def test_dataclasses_pickleable(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001915 global P, Q, R
1916 @dataclass
1917 class P:
1918 x: int
1919 y: int = 0
1920 @dataclass
1921 class Q:
1922 x: int
1923 y: int = field(default=0, init=False)
1924 @dataclass
1925 class R:
1926 x: int
1927 y: List[int] = field(default_factory=list)
1928 q = Q(1)
1929 q.y = 2
1930 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1931 for sample in samples:
1932 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1933 with self.subTest(sample=sample, proto=proto):
1934 new_sample = pickle.loads(pickle.dumps(sample, proto))
1935 self.assertEqual(sample.x, new_sample.x)
1936 self.assertEqual(sample.y, new_sample.y)
1937 self.assertIsNot(sample, new_sample)
1938 new_sample.x = 42
1939 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1940 self.assertEqual(new_sample.x, another_new_sample.x)
1941 self.assertEqual(sample.y, another_new_sample.y)
1942
Batuhan Taskayac7437e22020-10-21 16:49:22 +03001943 def test_dataclasses_qualnames(self):
1944 @dataclass(order=True, unsafe_hash=True, frozen=True)
1945 class A:
1946 x: int
1947 y: int
1948
1949 self.assertEqual(A.__init__.__name__, "__init__")
1950 for function in (
1951 '__eq__',
1952 '__lt__',
1953 '__le__',
1954 '__gt__',
1955 '__ge__',
1956 '__hash__',
1957 '__init__',
1958 '__repr__',
1959 '__setattr__',
1960 '__delattr__',
1961 ):
1962 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
1963
1964 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
1965 A()
1966
Eric V. Smithea8fc522018-01-27 19:07:40 -05001967
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() to a name that has no annotation in the same class
    # body is an error, even when a base class (dataclass or not) does
    # annotate that name.

    def _assert_unannotated_field_rejected(self, *bases):
        # Build a dataclass C(*bases) whose 'f' is a bare field() and check
        # that the decorator refuses it.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(*bases):
                f = field()

    def test_field_without_annotation(self):
        self._assert_unannotated_field_rejected()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: the base class's annotation must not be picked
        # up.
        self._assert_unannotated_field_rejected(B)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, but the base class is not a dataclass.
        class B:
            f: int

        self._assert_unannotated_field_rejected(B)
2001
2002
class TestDocString(unittest.TestCase):
    # A dataclass without a docstring of its own is given one built from
    # the class name and its constructor signature.

    def assertDocStrEqual(self, a, b):
        # Python 3.6 and 3.7 format inspect.signature() output differently
        # (see bpo-32108), so for the time being compare the two strings
        # with every space removed.
        self.assertEqual(''.join(a.split(' ')), ''.join(b.split(' ')))

    def test_existing_docstring_not_overridden(self):
        # An explicit docstring is left untouched.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        # Union[int, None] is rendered in its Optional spelling.
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        self.assertDocStrEqual(C.__doc__, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory shows up as the placeholder "<factory>".
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        # Non-builtin classes are shown with their module prefix.
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2090
2091
Eric V. Smithea8fc522018-01-27 19:07:40 -05002092class TestInit(unittest.TestCase):
2093 def test_base_has_init(self):
2094 class B:
2095 def __init__(self):
2096 self.z = 100
2097 pass
2098
2099 # Make sure that declaring this class doesn't raise an error.
2100 # The issue is that we can't override __init__ in our class,
2101 # but it should be okay to add __init__ to us if our base has
2102 # an __init__.
2103 @dataclass
2104 class C(B):
2105 x: int = 0
2106 c = C(10)
2107 self.assertEqual(c.x, 10)
2108 self.assertNotIn('z', vars(c))
2109
2110 # Make sure that if we don't add an init, the base __init__
2111 # gets called.
2112 @dataclass(init=False)
2113 class C(B):
2114 x: int = 10
2115 c = C()
2116 self.assertEqual(c.x, 10)
2117 self.assertEqual(c.z, 100)
2118
2119 def test_no_init(self):
2120 dataclass(init=False)
2121 class C:
2122 i: int = 0
2123 self.assertEqual(C().i, 0)
2124
2125 dataclass(init=False)
2126 class C:
2127 i: int = 2
2128 def __init__(self):
2129 self.i = 3
2130 self.assertEqual(C().i, 3)
2131
2132 def test_overwriting_init(self):
2133 # If the class has __init__, use it no matter the value of
2134 # init=.
2135
2136 @dataclass
2137 class C:
2138 x: int
2139 def __init__(self, x):
2140 self.x = 2 * x
2141 self.assertEqual(C(3).x, 6)
2142
2143 @dataclass(init=True)
2144 class C:
2145 x: int
2146 def __init__(self, x):
2147 self.x = 2 * x
2148 self.assertEqual(C(4).x, 8)
2149
2150 @dataclass(init=False)
2151 class C:
2152 x: int
2153 def __init__(self, x):
2154 self.x = 2 * x
2155 self.assertEqual(C(5).x, 10)
2156
Miss Islington (bot)79e9f5a2021-09-02 23:26:53 -07002157 def test_inherit_from_protocol(self):
2158 # Dataclasses inheriting from protocol should preserve their own `__init__`.
2159 # See bpo-45081.
2160
2161 class P(Protocol):
2162 a: int
2163
2164 @dataclass
2165 class C(P):
2166 a: int
2167
2168 self.assertEqual(C(5).a, 5)
2169
2170 @dataclass
2171 class D(P):
2172 def __init__(self, a):
2173 self.a = a * 2
2174
2175 self.assertEqual(D(5).a, 10)
2176
Eric V. Smithea8fc522018-01-27 19:07:40 -05002177
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The expected reprs start with the qualified name of each class,
        # which for classes defined here begins with this method's path.
        prefix = 'TestRepr.test_repr.<locals>.'

        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        self.assertEqual(repr(C(4)), prefix + 'C(x=4, y=10)')

        # Redeclaring an inherited field keeps it in its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), prefix + 'D(x=20, y=10)')

        # Nested dataclasses show the full nesting path.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), prefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), prefix + 'C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ in the class: object's default repr
        # is what you get.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr = f'{__name__}.TestRepr.test_no_repr.<locals>.C object at'
        self.assertIn(default_repr, repr(C(3)))

        # repr=False with a __repr__ in the class: that __repr__ is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ defined in the class body wins whatever repr= says.
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            with self.subTest(decorate=decorate):
                @decorate
                class C:
                    x: int
                    def __repr__(self):
                        return 'x'
                self.assertEqual(repr(C(0)), 'x')
2247
2248
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no __eq__ in the class: equality is identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with an __eq__ in the class: that __eq__ is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A __eq__ defined in the class body wins whatever eq= says.
        def build(decorate, target):
            # Make a class, decorated with `decorate`, whose own __eq__
            # only matches `target`.
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            return C

        for decorate, target in ((dataclass, 3),
                                 (dataclass(eq=True), 4),
                                 (dataclass(eq=False), 5)):
            with self.subTest(target=target):
                C = build(decorate, target)
                self.assertEqual(C(1), target)
                self.assertNotEqual(C(1), 1)
2294
2295
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering can derive the remaining comparisons
        # from a __lt__ written on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the assertions below can only
                # pass if this __lt__ is the one being used.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        comparisons = ('__le__', '__lt__', '__ge__', '__gt__')

        # order=False adds none of the ordering methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in comparisons:
            self.assertNotIn(name, C.__dict__)

        # A __lt__ written in the class is kept, and none of the other
        # ordering methods are added next to it.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in comparisons:
            if name != '__lt__':
                self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace an ordering method that the class
        # already defines, and suggests functools.total_ordering instead.
        def refuses(name):
            return self.assertRaisesRegex(TypeError,
                                          f'Cannot overwrite attribute {name}'
                                          '.*using functools.total_ordering')

        with refuses('__lt__'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with refuses('__le__'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with refuses('__gt__'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with refuses('__ge__'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2371
# Tests for whether @dataclass generates __hash__, leaves it alone, or sets
# it to None, depending on unsafe_hash=, eq=, frozen= and whether the class
# body defines __hash__ itself.
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        # unsafe_hash=True generates a __hash__ equal to the hash of the
        # tuple of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            # (None maps to None, truthy to (3,), falsy to 0.)
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        # Build (or try to build) a class for one combination of decorator
        # arguments and check the outcome.
        #   case:      row number, only used to label the subTest.
        #   with_hash: whether the class body defines its own __hash__.
        #   result:    the expected outcome:
        #     'fn'        -- a generated __hash__ is in the class dict;
        #     ''          -- no __hash__ is added to the class dict;
        #     'none'      -- __hash__ is set to None in the class dict;
        #     'exception' -- decorating raises TypeError.
        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    # (Only checkable when the class body didn't define
                    # __hash__ itself.)
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        # Columns: unsafe_hash, eq, frozen, expected result without a
        # user-defined __hash__, expected result with one.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                  (False, False, False, '', ''),
                  (False, False, True, '', ''),
                  (False, True, False, 'none', ''),
                  (False, True, True, 'fn', ''),
                  (True, False, False, 'fn', 'exception'),
                  (True, False, True, 'fn', 'exception'),
                  (True, True, False, 'fn', 'exception'),
                  (True, True, True, 'fn', 'exception'),
                  ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too. This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None. This is normal Python behavior, not
        # related to dataclasses. Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # With no fields, the generated hash equals the hash of the empty
        # tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With one field, the generated hash equals the hash of a 1-tuple.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument. This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        # A base class with a recognizable __hash__, so the 'base' outcome
        # below can be detected.
        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        # expected: 'unhashable' (hash() raises), 'base' (Base.__hash__ is
        # used), 'object' (object.__hash__ is used), 'tuple' (generated
        # hash of the field tuple).
        for frozen, eq, base, expected in [
            (None, None, object, 'unhashable'),
            (None, None, Base, 'unhashable'),
            (None, False, object, 'object'),
            (None, False, Base, 'base'),
            (None, True, object, 'unhashable'),
            (None, True, Base, 'unhashable'),
            (False, None, object, 'unhashable'),
            (False, None, Base, 'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base, 'base'),
            (False, True, object, 'unhashable'),
            (False, True, Base, 'unhashable'),
            (True, None, object, 'tuple'),
            (True, None, Base, 'tuple'),
            (True, False, object, 'object'),
            (True, False, Base, 'base'),
            (True, True, object, 'tuple'),
            (True, True, Base, 'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # There is no particular hash value to compare
                    # against here, so just check that the function
                    # used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2590
Eric V. Smithea8fc522018-01-27 19:07:40 -05002591
class TestFrozen(unittest.TestCase):
    @staticmethod
    def _through(base, intermediate_class):
        # Return `base` unchanged, or an ordinary (non-dataclass) subclass
        # of it when `intermediate_class` is true.
        if not intermediate_class:
            return base

        class I(base):
            pass
        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # The rejected assignment leaves the value as it was.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        # A frozen dataclass derived from a frozen dataclass rejects writes
        # to both the inherited field and its own.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        for attr, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, attr, value)
        self.assertEqual((d.i, d.j), (0, 10))

    def test_inherit_nonfrozen_from_empty_frozen(self):
        # Even a frozen base with no fields can't have a non-frozen
        # dataclass child.
        @dataclass(frozen=True)
        class C:
            pass

        message = 'cannot inherit non-frozen dataclass from a frozen one'
        with self.assertRaisesRegex(TypeError, message):
            @dataclass
            class D(C):
                j: int

    def test_inherit_nonfrozen_from_empty(self):
        @dataclass
        class C:
            pass

        @dataclass
        class D(C):
            j: int

        d = D(3)
        self.assertIsInstance(d, C)
        self.assertEqual(d.j, 3)

    # The next three tests run both with and without an ordinary
    # (non-dataclass) class between the base and the derived dataclass.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                base = self._through(C, intermediate_class)
                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(base):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                base = self._through(C, intermediate_class)
                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(base):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                @dataclass(frozen=True)
                class D(self._through(C, intermediate_class)):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        # An ordinary subclass of a frozen dataclass can take new
        # attributes...
        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        s.cached = True

        # ...but the frozen fields it inherited still can't be changed.
        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, attr, 5)
        self.assertEqual((s.x, s.y), (3, 10))
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True needs to install __setattr__ and __delattr__, so a
        # class body that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without freezing, a user-defined __setattr__ is kept and is what
        # the generated __init__ ends up going through.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing succeeds while the field value is hashable...
        hash(C(3))

        # ...and raises once it isn't.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2764
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002765
Eric V. Smith7389fd92018-03-19 21:07:51 -04002766class TestSlots(unittest.TestCase):
2767 def test_simple(self):
2768 @dataclass
2769 class C:
2770 __slots__ = ('x',)
2771 x: Any
2772
Eric V. Smith2b75fc22018-03-25 20:37:33 -04002773 # There was a bug where a variable in a slot was assumed to
2774 # also have a default value (of type
2775 # types.MemberDescriptorType).
Eric V. Smith7389fd92018-03-19 21:07:51 -04002776 with self.assertRaisesRegex(TypeError,
Eric V. Smithc42e7aa2018-03-24 23:02:21 -04002777 r"__init__\(\) missing 1 required positional argument: 'x'"):
Eric V. Smith7389fd92018-03-19 21:07:51 -04002778 C()
2779
2780 # We can create an instance, and assign to x.
2781 c = C(10)
2782 self.assertEqual(c.x, 10)
2783 c.x = 5
2784 self.assertEqual(c.x, 5)
2785
2786 # We can't assign to anything else.
2787 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
2788 c.y = 5
2789
2790 def test_derived_added_field(self):
2791 # See bpo-33100.
2792 @dataclass
2793 class Base:
2794 __slots__ = ('x',)
2795 x: Any
2796
2797 @dataclass
2798 class Derived(Base):
2799 x: int
2800 y: int
2801
2802 d = Derived(1, 2)
2803 self.assertEqual((d.x, d.y), (1, 2))
2804
2805 # We can add a new field to the derived instance.
2806 d.z = 10
2807
Yurii Karabasc2419912021-05-01 05:14:30 +03002808 def test_generated_slots(self):
2809 @dataclass(slots=True)
2810 class C:
2811 x: int
2812 y: int
2813
2814 c = C(1, 2)
2815 self.assertEqual((c.x, c.y), (1, 2))
2816
2817 c.x = 3
2818 c.y = 4
2819 self.assertEqual((c.x, c.y), (3, 4))
2820
2821 with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'z'"):
2822 c.z = 5
2823
2824 def test_add_slots_when_slots_exists(self):
2825 with self.assertRaisesRegex(TypeError, '^C already specifies __slots__$'):
2826 @dataclass(slots=True)
2827 class C:
2828 __slots__ = ('x',)
2829 x: int
2830
2831 def test_generated_slots_value(self):
2832 @dataclass(slots=True)
2833 class Base:
2834 x: int
2835
2836 self.assertEqual(Base.__slots__, ('x',))
2837
2838 @dataclass(slots=True)
2839 class Delivered(Base):
2840 y: int
2841
2842 self.assertEqual(Delivered.__slots__, ('x', 'y'))
2843
2844 @dataclass
2845 class AnotherDelivered(Base):
2846 z: int
2847
2848 self.assertTrue('__slots__' not in AnotherDelivered.__dict__)
2849
2850 def test_returns_new_class(self):
2851 class A:
2852 x: int
2853
2854 B = dataclass(A, slots=True)
2855 self.assertIsNot(A, B)
2856
2857 self.assertFalse(hasattr(A, "__slots__"))
2858 self.assertTrue(hasattr(B, "__slots__"))
2859
Eric V. Smith823fbf42021-05-01 13:27:30 -04002860 # Can't be local to test_frozen_pickle.
2861 @dataclass(frozen=True, slots=True)
2862 class FrozenSlotsClass:
2863 foo: str
2864 bar: int
2865
Miss Islington (bot)36971fd2021-10-24 06:29:37 -07002866 @dataclass(frozen=True)
2867 class FrozenWithoutSlotsClass:
2868 foo: str
2869 bar: int
2870
Eric V. Smith823fbf42021-05-01 13:27:30 -04002871 def test_frozen_pickle(self):
2872 # bpo-43999
2873
Miss Islington (bot)36971fd2021-10-24 06:29:37 -07002874 self.assertEqual(self.FrozenSlotsClass.__slots__, ("foo", "bar"))
2875 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
2876 with self.subTest(proto=proto):
2877 obj = self.FrozenSlotsClass("a", 1)
2878 p = pickle.loads(pickle.dumps(obj, protocol=proto))
2879 self.assertIsNot(obj, p)
2880 self.assertEqual(obj, p)
Eric V. Smith823fbf42021-05-01 13:27:30 -04002881
Miss Islington (bot)36971fd2021-10-24 06:29:37 -07002882 obj = self.FrozenWithoutSlotsClass("a", 1)
2883 p = pickle.loads(pickle.dumps(obj, protocol=proto))
2884 self.assertIsNot(obj, p)
2885 self.assertEqual(obj, p)
Yurii Karabasc2419912021-05-01 05:14:30 +03002886
Miss Islington (bot)10343bd2021-11-22 05:47:41 -08002887 def test_slots_with_default_no_init(self):
2888 # Originally reported in bpo-44649.
2889 @dataclass(slots=True)
2890 class A:
2891 a: str
2892 b: str = field(default='b', init=False)
2893
2894 obj = A("a")
2895 self.assertEqual(obj.a, 'a')
2896 self.assertEqual(obj.b, 'b')
2897
2898 def test_slots_with_default_factory_no_init(self):
2899 # Originally reported in bpo-44649.
2900 @dataclass(slots=True)
2901 class A:
2902 a: str
2903 b: str = field(default_factory=lambda:'b', init=False)
2904
2905 obj = A("a")
2906 self.assertEqual(obj.a, 'a')
2907 self.assertEqual(obj.b, 'b')
2908
class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that remembers the name it was bound to (with 'x'
        # appended) and reads as 1 through an instance.
        class Recorder:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Plain descriptor behavior: no dataclass code takes part in
        # initializing the descriptor.
        @dataclass
        class C:
            c: int=Recorder()
        self.assertEqual(C.c.name, 'cx')

        # Same, with a default and init=False.  That is the only case where
        # this really matters: without init=False the descriptor would be
        # overwritten anyway.
        @dataclass
        class C:
            c: int=field(default=Recorder(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors too.
        # This object has __set_name__ but no __get__.
        class Named:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=Named(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.  A __set_name__ stored on the default object
        # itself (not on its type) must not be called.
        class D:
            pass

        default = D()
        default.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=default, init=False)

        default.__set_name__.assert_not_called()

    def test_lookup_on_class(self):
        # See bpo-33175.  A __set_name__ on the default's type is called,
        # exactly once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        D.__set_name__.assert_called_once()
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002979
Eric V. Smith7389fd92018-03-19 21:07:51 -04002980
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002981class TestStringAnnotations(unittest.TestCase):
2982 def test_classvar(self):
Pablo Galindob0544ba2021-04-21 12:41:19 +01002983 # Some expressions recognized as ClassVar really aren't. But
2984 # if you're using string annotations, it's not an exact
2985 # science.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002986 # These tests assume that both "import typing" and "from
2987 # typing import *" have been run in this file.
2988 for typestr in ('ClassVar[int]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002989 'ClassVar [int]',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002990 ' ClassVar [int]',
2991 'ClassVar',
2992 ' ClassVar ',
2993 'typing.ClassVar[int]',
2994 'typing.ClassVar[str]',
2995 ' typing.ClassVar[str]',
2996 'typing .ClassVar[str]',
2997 'typing. ClassVar[str]',
2998 'typing.ClassVar [str]',
2999 'typing.ClassVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01003000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003001 # Not syntactically valid, but these will
Pablo Galindob0544ba2021-04-21 12:41:19 +01003002 # be treated as ClassVars.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003003 'typing.ClassVar.[int]',
3004 'typing.ClassVar+',
3005 ):
3006 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003007 @dataclass
3008 class C:
3009 x: typestr
3010
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003011 # x is a ClassVar, so C() takes no args.
3012 C()
3013
3014 # And it won't appear in the class's dict because it doesn't
3015 # have a default.
3016 self.assertNotIn('x', C.__dict__)
3017
3018 def test_isnt_classvar(self):
3019 for typestr in ('CV',
3020 't.ClassVar',
3021 't.ClassVar[int]',
3022 'typing..ClassVar[int]',
3023 'Classvar',
3024 'Classvar[int]',
3025 'typing.ClassVarx[int]',
3026 'typong.ClassVar[int]',
3027 'dataclasses.ClassVar[int]',
3028 'typingxClassVar[str]',
3029 ):
3030 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003031 @dataclass
3032 class C:
3033 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003034
3035 # x is not a ClassVar, so C() takes one arg.
3036 self.assertEqual(C(10).x, 10)
3037
3038 def test_initvar(self):
3039 # These tests assume that both "import dataclasses" and "from
3040 # dataclasses import *" have been run in this file.
3041 for typestr in ('InitVar[int]',
3042 'InitVar [int]'
3043 ' InitVar [int]',
3044 'InitVar',
3045 ' InitVar ',
3046 'dataclasses.InitVar[int]',
3047 'dataclasses.InitVar[str]',
3048 ' dataclasses.InitVar[str]',
3049 'dataclasses .InitVar[str]',
3050 'dataclasses. InitVar[str]',
3051 'dataclasses.InitVar [str]',
3052 'dataclasses.InitVar [ str]',
Pablo Galindob0544ba2021-04-21 12:41:19 +01003053
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003054 # Not syntactically valid, but these will
3055 # be treated as InitVars.
3056 'dataclasses.InitVar.[int]',
3057 'dataclasses.InitVar+',
3058 ):
3059 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003060 @dataclass
3061 class C:
3062 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003063
3064 # x is an InitVar, so doesn't create a member.
3065 with self.assertRaisesRegex(AttributeError,
3066 "object has no attribute 'x'"):
3067 C(1).x
3068
3069 def test_isnt_initvar(self):
3070 for typestr in ('IV',
3071 'dc.InitVar',
3072 'xdataclasses.xInitVar',
3073 'typing.xInitVar[int]',
3074 ):
3075 with self.subTest(typestr=typestr):
Pablo Galindob0544ba2021-04-21 12:41:19 +01003076 @dataclass
3077 class C:
3078 x: typestr
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003079
3080 # x is not an InitVar, so there will be a member x.
3081 self.assertEqual(C(10).x, 10)
3082
def test_classvar_module_level_import(self):
    from test import (dataclass_module_1, dataclass_module_1_str,
                      dataclass_module_2, dataclass_module_2_str)

    modules = (dataclass_module_1, dataclass_module_1_str,
               dataclass_module_2, dataclass_module_2_str)
    for m in modules:
        with self.subTest(m=m):
            stringy = m.USING_STRINGS

            # ClassVars are interpreted differently depending on whether the
            # module uses string annotations; see the imported modules for
            # details.  In the string case CV takes one argument.
            cv_args = (10,) if stringy else ()
            self.assertEqual(m.CV(*cv_args).cv0, 20)

            # InitVars are likewise interpreted differently; see the
            # imported modules for details.
            iv_obj = m.IV(0, 1, 2, 3, 4)

            # iv0..iv3 are InitVars in every variant, so none of them is
            # an instance attribute.
            for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
                with self.subTest(field_name=field_name):
                    expected = f"object has no attribute '{field_name}'"
                    with self.assertRaisesRegex(AttributeError, expected):
                        getattr(iv_obj, field_name)

            if not stringy:
                # iv4 is treated as an InitVar, so nothing is stored.
                self.assertNotIn('not_iv4', iv_obj.__dict__)
            else:
                # iv4 is treated as a normal field.
                self.assertIn('not_iv4', iv_obj.__dict__)
                self.assertEqual(iv_obj.not_iv4, 4)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003123
def test_text_annotations(self):
    from test import dataclass_textanno as mod

    foo_type = mod.Foo
    # Both the class annotations and the generated __init__ annotations
    # must resolve the textual 'Foo' reference to the real class.
    self.assertEqual(get_type_hints(mod.Bar), {'foo': foo_type})
    init_hints = get_type_hints(mod.Bar.__init__)
    self.assertEqual(init_hints, {'foo': foo_type, 'return': type(None)})
3134
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04003135
class TestMakeDataclass(unittest.TestCase):
    """Tests for dataclasses.make_dataclass()."""

    def test_simple(self):
        spec = [('x', int), ('y', int, field(default=5))]
        extra = {'add_one': lambda self: self.x + 1}
        C = make_dataclass('C', spec, namespace=extra)
        obj = C(10)
        self.assertEqual((obj.x, obj.y), (10, 5))
        self.assertEqual(obj.add_one(), 11)

    def test_no_mutate_namespace(self):
        # The caller's namespace dict must come back untouched.
        namespace = {}
        make_dataclass('C', [('x', int), ('y', int, field(default=5))],
                       namespace=namespace)
        self.assertEqual(namespace, {})

    def test_base(self):
        class Base1:
            pass

        class Base2:
            pass

        C = make_dataclass('C', [('x', int)], bases=(Base1, Base2))
        obj = C(2)
        for cls in (C, Base1, Base2):
            self.assertIsInstance(obj, cls)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int

        class Base2:
            pass

        C = make_dataclass('C', [('y', int)], bases=(Base1, Base2))
        # The inherited field x precedes y, so one argument is not enough.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        obj = C(1, 2)
        for cls in (C, Base1, Base2):
            self.assertIsInstance(obj, cls)
        self.assertEqual((obj.x, obj.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass(
            'C',
            [('x', int), ('y', InitVar[int])],
            namespace={'__post_init__': post_init},
        )
        obj = C(2, 3)
        # y only reaches __post_init__; it is neither stored nor a field.
        self.assertEqual(vars(obj), {'x': 6})
        self.assertEqual(len(fields(obj)), 1)

    def test_class_var(self):
        C = make_dataclass('C', [
            ('x', int),
            ('y', ClassVar[int], 10),
            ('z', ClassVar[int], field(default=20)),
        ])
        obj = C(1)
        self.assertEqual(vars(obj), {'x': 1})
        self.assertEqual(len(fields(obj)), 1)
        self.assertEqual((C.y, C.z), (10, 20))

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20))],
                           init=False)
        # init=False suppresses __init__, but a __repr__ is still generated.
        class_dict = vars(C)
        self.assertNotIn('__init__', class_dict)
        self.assertIn('__repr__', class_dict)

        # Unknown keyword arguments are rejected.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C', [], xxinit=False)

    def test_no_types(self):
        any_ann = 'typing.Any'

        C = make_dataclass('Point', ['x', 'y', 'z'])
        self.assertEqual(vars(C(1, 2, 3)), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, dict.fromkeys('xyz', any_ann))

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        self.assertEqual(vars(C(1, 2, 3)), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__,
                         {'x': any_ann, 'y': int, 'z': any_ann})

    def test_invalid_type_specification(self):
        cases = [
            # Tuples of the wrong length.
            ((), r'Invalid field: '),
            ((1, 2, 3, 4), r'Invalid field: '),
            # Things with no len().
            (float, r'has no len\(\)'),
            (lambda x: x, r'has no len\(\)'),
        ]
        for bad_field, pattern in cases:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, pattern):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        for name in ('a', 'ab'):
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ('for', 'async', 'await', 'as'):
            with self.subTest(field=name):
                # A keyword is rejected wherever it appears in the list.
                for names in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                        make_dataclass('C', names)

    def test_non_identifier_field_names(self):
        for name in ('()', 'x,y', '*', '2@3', '', 'little johnny tables'):
            with self.subTest(field=name):
                for names in (['a', name], [name], [name, 'a']):
                    with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                        make_dataclass('C', names)

    def test_underscore_field_names(self):
        # Unlike namedtuple, dataclass field names may contain underscores
        # anywhere, including a leading one.
        names = ['_', '_a', 'a_a', 'a_']
        make_dataclass('C', names)

    def test_funny_class_names_names(self):
        # types.new_class allows arbitrary class names, so there is no
        # reason for make_dataclass to reject them.
        for classname in ('()', 'x,y', '*', '2@3', ''):
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3299
Eric V. Smithe7adf2b2018-06-07 14:43:59 -04003300class TestReplace(unittest.TestCase):
3301 def test(self):
3302 @dataclass(frozen=True)
3303 class C:
3304 x: int
3305 y: int
3306
3307 c = C(1, 2)
3308 c1 = replace(c, x=3)
3309 self.assertEqual(c1.x, 3)
3310 self.assertEqual(c1.y, 2)
3311
3312 def test_frozen(self):
3313 @dataclass(frozen=True)
3314 class C:
3315 x: int
3316 y: int
3317 z: int = field(init=False, default=10)
3318 t: int = field(init=False, default=100)
3319
3320 c = C(1, 2)
3321 c1 = replace(c, x=3)
3322 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
3323 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
3324
3325
3326 with self.assertRaisesRegex(ValueError, 'init=False'):
3327 replace(c, x=3, z=20, t=50)
3328 with self.assertRaisesRegex(ValueError, 'init=False'):
3329 replace(c, z=20)
3330 replace(c, x=3, z=20, t=50)
3331
3332 # Make sure the result is still frozen.
3333 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
3334 c1.x = 3
3335
3336 # Make sure we can't replace an attribute that doesn't exist,
3337 # if we're also replacing one that does exist. Test this
3338 # here, because setting attributes on frozen instances is
3339 # handled slightly differently from non-frozen ones.
3340 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
3341 "keyword argument 'a'"):
3342 c1 = replace(c, x=20, a=5)
3343
3344 def test_invalid_field_name(self):
3345 @dataclass(frozen=True)
3346 class C:
3347 x: int
3348 y: int
3349
3350 c = C(1, 2)
3351 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
3352 "keyword argument 'z'"):
3353 c1 = replace(c, z=3)
3354
3355 def test_invalid_object(self):
3356 @dataclass(frozen=True)
3357 class C:
3358 x: int
3359 y: int
3360
3361 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
3362 replace(C, x=3)
3363
3364 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
3365 replace(0, x=3)
3366
3367 def test_no_init(self):
3368 @dataclass
3369 class C:
3370 x: int
3371 y: int = field(init=False, default=10)
3372
3373 c = C(1)
3374 c.y = 20
3375
3376 # Make sure y gets the default value.
3377 c1 = replace(c, x=5)
3378 self.assertEqual((c1.x, c1.y), (5, 10))
3379
3380 # Trying to replace y is an error.
3381 with self.assertRaisesRegex(ValueError, 'init=False'):
3382 replace(c, x=2, y=30)
3383
3384 with self.assertRaisesRegex(ValueError, 'init=False'):
3385 replace(c, y=30)
3386
3387 def test_classvar(self):
3388 @dataclass
3389 class C:
3390 x: int
3391 y: ClassVar[int] = 1000
3392
3393 c = C(1)
3394 d = C(2)
3395
3396 self.assertIs(c.y, d.y)
3397 self.assertEqual(c.y, 1000)
3398
3399 # Trying to replace y is an error: can't replace ClassVars.
3400 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
3401 "unexpected keyword argument 'y'"):
3402 replace(c, y=30)
3403
3404 replace(c, x=5)
3405
Dong-hee Na3d70f7a2018-06-23 23:46:32 +09003406 def test_initvar_is_specified(self):
3407 @dataclass
3408 class C:
3409 x: int
3410 y: InitVar[int]
3411
3412 def __post_init__(self, y):
3413 self.x *= y
3414
3415 c = C(1, 10)
3416 self.assertEqual(c.x, 10)
3417 with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
3418 "specified with replace()"):
3419 replace(c, x=3)
3420 c = replace(c, x=3, y=5)
3421 self.assertEqual(c.x, 15)
Srinivas Thatiparthy (శ్రీనివాస్ తాటిపర్తి)dd13c882018-10-19 22:24:50 +05303422
Zackery Spytz75220672021-04-05 13:41:01 -06003423 def test_initvar_with_default_value(self):
3424 @dataclass
3425 class C:
3426 x: int
3427 y: InitVar[int] = None
3428 z: InitVar[int] = 42
3429
3430 def __post_init__(self, y, z):
3431 if y is not None:
3432 self.x += y
3433 if z is not None:
3434 self.x += z
3435
3436 c = C(x=1, y=10, z=1)
3437 self.assertEqual(replace(c), C(x=12))
3438 self.assertEqual(replace(c, y=4), C(x=12, y=4, z=42))
3439 self.assertEqual(replace(c, y=4, z=1), C(x=12, y=4, z=1))
3440
Srinivas Thatiparthy (శ్రీనివాస్ తాటిపర్తి)dd13c882018-10-19 22:24:50 +05303441 def test_recursive_repr(self):
3442 @dataclass
3443 class C:
3444 f: "C"
3445
3446 c = C(None)
3447 c.f = c
3448 self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")
3449
3450 def test_recursive_repr_two_attrs(self):
3451 @dataclass
3452 class C:
3453 f: "C"
3454 g: "C"
3455
3456 c = C(None, None)
3457 c.f = c
3458 c.g = c
3459 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
3460 ".<locals>.C(f=..., g=...)")
3461
3462 def test_recursive_repr_indirection(self):
3463 @dataclass
3464 class C:
3465 f: "D"
3466
3467 @dataclass
3468 class D:
3469 f: "C"
3470
3471 c = C(None)
3472 d = D(None)
3473 c.f = d
3474 d.f = c
3475 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
3476 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
3477 ".<locals>.D(f=...))")
3478
3479 def test_recursive_repr_indirection_two(self):
3480 @dataclass
3481 class C:
3482 f: "D"
3483
3484 @dataclass
3485 class D:
3486 f: "E"
3487
3488 @dataclass
3489 class E:
3490 f: "C"
3491
3492 c = C(None)
3493 d = D(None)
3494 e = E(None)
3495 c.f = d
3496 d.f = e
3497 e.f = c
3498 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
3499 ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
3500 ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
3501 ".<locals>.E(f=...)))")
3502
Srinivas Thatiparthy (శ్రీనివాస్ తాటిపర్తి)dd13c882018-10-19 22:24:50 +05303503 def test_recursive_repr_misc_attrs(self):
3504 @dataclass
3505 class C:
3506 f: "C"
3507 g: int
3508
3509 c = C(None, 1)
3510 c.f = c
3511 self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
3512 ".<locals>.C(f=..., g=1)")
3513
Eric V. Smithe7adf2b2018-06-07 14:43:59 -04003514 ## def test_initvar(self):
3515 ## @dataclass
3516 ## class C:
3517 ## x: int
3518 ## y: InitVar[int]
3519
3520 ## c = C(1, 10)
3521 ## d = C(2, 20)
3522
3523 ## # In our case, replacing an InitVar is a no-op
3524 ## self.assertEqual(c, replace(c, y=5))
3525
3526 ## replace(c, x=5)
3527
Ben Avrahamibef7d292020-10-06 20:40:50 +03003528class TestAbstract(unittest.TestCase):
3529 def test_abc_implementation(self):
3530 class Ordered(abc.ABC):
3531 @abc.abstractmethod
3532 def __lt__(self, other):
3533 pass
3534
3535 @abc.abstractmethod
3536 def __le__(self, other):
3537 pass
3538
3539 @dataclass(order=True)
3540 class Date(Ordered):
3541 year: int
3542 month: 'Month'
3543 day: 'int'
3544
3545 self.assertFalse(inspect.isabstract(Date))
3546 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3547
3548 def test_maintain_abc(self):
3549 class A(abc.ABC):
3550 @abc.abstractmethod
3551 def foo(self):
3552 pass
3553
3554 @dataclass
3555 class Date(A):
3556 year: int
3557 month: 'Month'
3558 day: 'int'
3559
3560 self.assertTrue(inspect.isabstract(Date))
3561 msg = 'class Date with abstract method foo'
3562 self.assertRaisesRegex(TypeError, msg, Date)
3563
Eric V. Smith4e812962018-05-16 11:31:29 -04003564
Brandt Bucher145bf262021-02-26 14:51:55 -08003565class TestMatchArgs(unittest.TestCase):
3566 def test_match_args(self):
3567 @dataclass
3568 class C:
3569 a: int
3570 self.assertEqual(C(42).__match_args__, ('a',))
3571
3572 def test_explicit_match_args(self):
Brandt Bucherf84d5a12021-04-05 19:17:08 -07003573 ma = ()
Brandt Bucher145bf262021-02-26 14:51:55 -08003574 @dataclass
3575 class C:
3576 a: int
3577 __match_args__ = ma
3578 self.assertIs(C(42).__match_args__, ma)
3579
Brandt Bucherd92c59f2021-04-08 12:54:34 -07003580 def test_bpo_43764(self):
3581 @dataclass(repr=False, eq=False, init=False)
3582 class X:
3583 a: int
3584 b: int
3585 c: int
3586 self.assertEqual(X.__match_args__, ("a", "b", "c"))
3587
Eric V. Smith750f4842021-04-10 21:28:42 -04003588 def test_match_args_argument(self):
3589 @dataclass(match_args=False)
3590 class X:
3591 a: int
3592 self.assertNotIn('__match_args__', X.__dict__)
3593
3594 @dataclass(match_args=False)
3595 class Y:
3596 a: int
3597 __match_args__ = ('b',)
3598 self.assertEqual(Y.__match_args__, ('b',))
3599
3600 @dataclass(match_args=False)
3601 class Z(Y):
3602 z: int
3603 self.assertEqual(Z.__match_args__, ('b',))
3604
3605 # Ensure parent dataclass __match_args__ is seen, if child class
3606 # specifies match_args=False.
3607 @dataclass
3608 class A:
3609 a: int
3610 z: int
3611 @dataclass(match_args=False)
3612 class B(A):
3613 b: int
3614 self.assertEqual(B.__match_args__, ('a', 'z'))
3615
3616 def test_make_dataclasses(self):
3617 C = make_dataclass('C', [('x', int), ('y', int)])
3618 self.assertEqual(C.__match_args__, ('x', 'y'))
3619
3620 C = make_dataclass('C', [('x', int), ('y', int)], match_args=True)
3621 self.assertEqual(C.__match_args__, ('x', 'y'))
3622
3623 C = make_dataclass('C', [('x', int), ('y', int)], match_args=False)
3624 self.assertNotIn('__match__args__', C.__dict__)
3625
3626 C = make_dataclass('C', [('x', int), ('y', int)], namespace={'__match_args__': ('z',)})
3627 self.assertEqual(C.__match_args__, ('z',))
3628
Brandt Bucher145bf262021-02-26 14:51:55 -08003629
class TestKeywordArgs(unittest.TestCase):
    """Tests for keyword-only fields: kw_only=, field(kw_only=) and KW_ONLY."""

    def test_no_classvar_kwarg(self):
        # Specifying kw_only on a ClassVar field is an error, whichever
        # value is given and whatever the class-level kw_only setting is.
        msg = 'field a is a ClassVar but specifies kw_only'
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=True)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: ClassVar[int] = field(kw_only=False)

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass(kw_only=True)
            class A:
                a: ClassVar[int] = field(kw_only=False)

    def test_field_marked_as_kwonly(self):
        #######################
        # Using dataclass(kw_only=True)
        @dataclass(kw_only=True)
        class A:
            a: int
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=True)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Using dataclass(kw_only=False)
        @dataclass(kw_only=False)
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass(kw_only=False)
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

        #######################
        # Not specifying dataclass(kw_only)
        @dataclass
        class A:
            a: int
        self.assertFalse(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        @dataclass
        class A:
            a: int = field(kw_only=False)
        self.assertFalse(fields(A)[0].kw_only)

    def test_match_args(self):
        # kw fields don't show up in __match_args__.
        @dataclass(kw_only=True)
        class C:
            a: int
        self.assertEqual(C(a=42).__match_args__, ())

        @dataclass
        class C:
            a: int
            b: int = field(kw_only=True)
        self.assertEqual(C(42, b=10).__match_args__, ('a',))

    def test_KW_ONLY(self):
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

        @dataclass(kw_only=True)
        class B:
            a: int
            _: KW_ONLY
            b: int
            c: int
        B(a=3, b=4, c=5)
        msg = "takes 1 positional argument but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            B(3, 4, 5)

        # Explicitly make a field that follows KW_ONLY be non-keyword-only.
        @dataclass
        class C:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=False)

        # a and c are positional (in that order); b is keyword-only.  The
        # original test repeated C(1, b=3, c=2) twice.  The swapped keyword
        # order C(1, c=2, b=3) now takes the duplicate's place.
        for c in (C(1, 2, b=3),
                  C(1, b=3, c=2),
                  C(1, c=2, b=3),
                  C(c=2, b=3, a=1)):
            self.assertEqual(c.a, 1)
            self.assertEqual(c.b, 3)
            self.assertEqual(c.c, 2)

    def test_KW_ONLY_as_string(self):
        @dataclass
        class A:
            a: int
            _: 'dataclasses.KW_ONLY'
            b: int
            c: int
        A(3, c=5, b=4)
        msg = "takes 2 positional arguments but 4 were given"
        with self.assertRaisesRegex(TypeError, msg):
            A(3, 4, 5)

    def test_KW_ONLY_twice(self):
        msg = "'Y' is KW_ONLY, but KW_ONLY has already been specified"

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                Y: KW_ONLY
                b: int
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                Y: KW_ONLY
                c: int

        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                X: KW_ONLY
                b: int
                c: int
                Y: KW_ONLY

        # But this usage is okay, since it's not using KW_ONLY.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int = field(kw_only=True)

        # And if inheriting, it's okay.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: int
            c: int
        @dataclass
        class B(A):
            _: KW_ONLY
            d: int

        # Make sure the error is raised in a derived class.
        with self.assertRaisesRegex(TypeError, msg):
            @dataclass
            class A:
                a: int
                _: KW_ONLY
                b: int
                c: int
            @dataclass
            class B(A):
                X: KW_ONLY
                d: int
                Y: KW_ONLY

    def test_post_init(self):
        # Keyword-only InitVars are still forwarded to __post_init__.
        @dataclass
        class A:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                raise CustomError(f'{b=} {d=}')
        with self.assertRaisesRegex(CustomError, 'b=3 d=4'):
            A(1, c=2, b=3, d=4)

        @dataclass
        class B:
            a: int
            _: KW_ONLY
            b: InitVar[int]
            c: int
            d: InitVar[int]
            def __post_init__(self, b, d):
                self.a = b
                self.c = d
        b = B(1, c=2, b=3, d=4)
        self.assertEqual(asdict(b), {'a': 3, 'c': 4})

    def test_defaults(self):
        # For kwargs, make sure we can have defaults after non-defaults.
        @dataclass
        class A:
            a: int = 0
            _: KW_ONLY
            b: int
            c: int = 1
            d: int

        a = A(d=4, b=3)
        self.assertEqual(a.a, 0)
        self.assertEqual(a.b, 3)
        self.assertEqual(a.c, 1)
        self.assertEqual(a.d, 4)

        # Make sure we still check for non-kwarg non-defaults not following
        # defaults.
        err_regex = "non-default argument 'z' follows default argument"
        with self.assertRaisesRegex(TypeError, err_regex):
            @dataclass
            class A:
                a: int = 0
                z: int
                _: KW_ONLY
                b: int
                c: int = 1
                d: int

    def test_make_dataclass(self):
        A = make_dataclass("A", ['a'], kw_only=True)
        self.assertTrue(fields(A)[0].kw_only)

        B = make_dataclass("B",
                           ['a', ('b', int, field(kw_only=False))],
                           kw_only=True)
        self.assertTrue(fields(B)[0].kw_only)
        self.assertFalse(fields(B)[1].kw_only)
3902
3903
if __name__ == '__main__':
    # Run every test case in this module when executed as a script.
    unittest.main()