blob: 0bfed41b369d19c6b084b5be77e7b5e1fd318b07 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050011import unittest
Batuhan Taskaya044a1042020-10-06 23:03:02 +030012from textwrap import dedent
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010014from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Yury Selivanovd219cc42019-12-09 09:54:20 -050015from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050016from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050017from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050018
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040019import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
20import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
21
Eric V. Smithf0db54a2017-12-04 16:58:55 -050022# Just any custom exception we can catch.
class CustomError(Exception):
    """Test-only exception that tests raise on purpose and then catch."""

25class TestCase(unittest.TestCase):
26 def test_no_fields(self):
27 @dataclass
28 class C:
29 pass
30
31 o = C()
32 self.assertEqual(len(fields(C)), 0)
33
Eric V. Smith56970b82018-03-22 16:28:48 -040034 def test_no_fields_but_member_variable(self):
35 @dataclass
36 class C:
37 i = 0
38
39 o = C()
40 self.assertEqual(len(fields(C)), 0)
41
Eric V. Smithf0db54a2017-12-04 16:58:55 -050042 def test_one_field_no_default(self):
43 @dataclass
44 class C:
45 x: int
46
47 o = C(42)
48 self.assertEqual(o.x, 42)
49
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053050 def test_field_default_default_factory_error(self):
51 msg = "cannot specify both default and default_factory"
52 with self.assertRaisesRegex(ValueError, msg):
53 @dataclass
54 class C:
55 x: int = field(default=1, default_factory=int)
56
57 def test_field_repr(self):
58 int_field = field(default=1, init=True, repr=False)
59 int_field.name = "id"
60 repr_output = repr(int_field)
61 expected_output = "Field(name='id',type=None," \
62 f"default=1,default_factory={MISSING!r}," \
63 "init=True,repr=False,hash=None," \
64 "compare=True,metadata=mappingproxy({})," \
65 "_field_type=None)"
66
67 self.assertEqual(repr_output, expected_output)
68
Eric V. Smithf0db54a2017-12-04 16:58:55 -050069 def test_named_init_params(self):
70 @dataclass
71 class C:
72 x: int
73
74 o = C(x=32)
75 self.assertEqual(o.x, 32)
76
77 def test_two_fields_one_default(self):
78 @dataclass
79 class C:
80 x: int
81 y: int = 0
82
83 o = C(3)
84 self.assertEqual((o.x, o.y), (3, 0))
85
86 # Non-defaults following defaults.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class C:
92 x: int = 0
93 y: int
94
95 # A derived class adds a non-default field after a default one.
96 with self.assertRaisesRegex(TypeError,
97 "non-default argument 'y' follows "
98 "default argument"):
99 @dataclass
100 class B:
101 x: int = 0
102
103 @dataclass
104 class C(B):
105 y: int
106
107 # Override a base class field and add a default to
108 # a field which didn't use to have a default.
109 with self.assertRaisesRegex(TypeError,
110 "non-default argument 'y' follows "
111 "default argument"):
112 @dataclass
113 class B:
114 x: int
115 y: int
116
117 @dataclass
118 class C(B):
119 x: int = 0
120
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500121 def test_overwrite_hash(self):
122 # Test that declaring this class isn't an error. It should
123 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500124 @dataclass(frozen=True)
125 class C:
126 x: int
127 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500128 return 301
129 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500130
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500131 # Test that declaring this class isn't an error. It should
132 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500133 @dataclass(frozen=True)
134 class C:
135 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500136 def __eq__(self, other):
137 return False
138 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500139
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500140 # But this one should generate an exception, because with
141 # unsafe_hash=True, it's an error to have a __hash__ defined.
142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 def __hash__(self):
147 pass
148
149 # Creating this class should not generate an exception,
150 # because even though __hash__ exists before @dataclass is
151 # called, (due to __eq__ being defined), since it's None
152 # that's okay.
153 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500154 class C:
155 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500156 def __eq__(self):
157 pass
158 # The generated hash function works as we'd expect.
159 self.assertEqual(hash(C(10)), hash((10,)))
160
161 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400162 # __hash__ exists and is not None, which it would be if it
163 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500164 with self.assertRaisesRegex(TypeError,
165 'Cannot overwrite attribute __hash__'):
166 @dataclass(unsafe_hash=True)
167 class C:
168 x: int
169 def __eq__(self):
170 pass
171 def __hash__(self):
172 pass
173
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500174 def test_overwrite_fields_in_derived_class(self):
175 # Note that x from C1 replaces x in Base, but the order remains
176 # the same as defined in Base.
177 @dataclass
178 class Base:
179 x: Any = 15.0
180 y: int = 0
181
182 @dataclass
183 class C1(Base):
184 z: int = 10
185 x: int = 15
186
187 o = Base()
188 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
189
190 o = C1()
191 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
192
193 o = C1(x=5)
194 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
195
196 def test_field_named_self(self):
197 @dataclass
198 class C:
199 self: str
200 c=C('foo')
201 self.assertEqual(c.self, 'foo')
202
203 # Make sure the first parameter is not named 'self'.
204 sig = inspect.signature(C.__init__)
205 first = next(iter(sig.parameters))
206 self.assertNotEqual('self', first)
207
208 # But we do use 'self' if no field named self.
209 @dataclass
210 class C:
211 selfx: str
212
213 # Make sure the first parameter is named 'self'.
214 sig = inspect.signature(C.__init__)
215 first = next(iter(sig.parameters))
216 self.assertEqual('self', first)
217
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300218 def test_field_named_object(self):
219 @dataclass
220 class C:
221 object: str
222 c = C('foo')
223 self.assertEqual(c.object, 'foo')
224
225 def test_field_named_object_frozen(self):
226 @dataclass(frozen=True)
227 class C:
228 object: str
229 c = C('foo')
230 self.assertEqual(c.object, 'foo')
231
232 def test_field_named_like_builtin(self):
233 # Attribute names can shadow built-in names
234 # since code generation is used.
235 # Ensure that this is not happening.
236 exclusions = {'None', 'True', 'False'}
237 builtins_names = sorted(
238 b for b in builtins.__dict__.keys()
239 if not b.startswith('__') and b not in exclusions
240 )
241 attributes = [(name, str) for name in builtins_names]
242 C = make_dataclass('C', attributes)
243
244 c = C(*[name for name in builtins_names])
245
246 for name in builtins_names:
247 self.assertEqual(getattr(c, name), name)
248
249 def test_field_named_like_builtin_frozen(self):
250 # Attribute names can shadow built-in names
251 # since code generation is used.
252 # Ensure that this is not happening
253 # for frozen data classes.
254 exclusions = {'None', 'True', 'False'}
255 builtins_names = sorted(
256 b for b in builtins.__dict__.keys()
257 if not b.startswith('__') and b not in exclusions
258 )
259 attributes = [(name, str) for name in builtins_names]
260 C = make_dataclass('C', attributes, frozen=True)
261
262 c = C(*[name for name in builtins_names])
263
264 for name in builtins_names:
265 self.assertEqual(getattr(c, name), name)
266
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500267 def test_0_field_compare(self):
268 # Ensure that order=False is the default.
269 @dataclass
270 class C0:
271 pass
272
273 @dataclass(order=False)
274 class C1:
275 pass
276
277 for cls in [C0, C1]:
278 with self.subTest(cls=cls):
279 self.assertEqual(cls(), cls())
280 for idx, fn in enumerate([lambda a, b: a < b,
281 lambda a, b: a <= b,
282 lambda a, b: a > b,
283 lambda a, b: a >= b]):
284 with self.subTest(idx=idx):
285 with self.assertRaisesRegex(TypeError,
286 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
287 fn(cls(), cls())
288
289 @dataclass(order=True)
290 class C:
291 pass
292 self.assertLessEqual(C(), C())
293 self.assertGreaterEqual(C(), C())
294
295 def test_1_field_compare(self):
296 # Ensure that order=False is the default.
297 @dataclass
298 class C0:
299 x: int
300
301 @dataclass(order=False)
302 class C1:
303 x: int
304
305 for cls in [C0, C1]:
306 with self.subTest(cls=cls):
307 self.assertEqual(cls(1), cls(1))
308 self.assertNotEqual(cls(0), cls(1))
309 for idx, fn in enumerate([lambda a, b: a < b,
310 lambda a, b: a <= b,
311 lambda a, b: a > b,
312 lambda a, b: a >= b]):
313 with self.subTest(idx=idx):
314 with self.assertRaisesRegex(TypeError,
315 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
316 fn(cls(0), cls(0))
317
318 @dataclass(order=True)
319 class C:
320 x: int
321 self.assertLess(C(0), C(1))
322 self.assertLessEqual(C(0), C(1))
323 self.assertLessEqual(C(1), C(1))
324 self.assertGreater(C(1), C(0))
325 self.assertGreaterEqual(C(1), C(0))
326 self.assertGreaterEqual(C(1), C(1))
327
328 def test_simple_compare(self):
329 # Ensure that order=False is the default.
330 @dataclass
331 class C0:
332 x: int
333 y: int
334
335 @dataclass(order=False)
336 class C1:
337 x: int
338 y: int
339
340 for cls in [C0, C1]:
341 with self.subTest(cls=cls):
342 self.assertEqual(cls(0, 0), cls(0, 0))
343 self.assertEqual(cls(1, 2), cls(1, 2))
344 self.assertNotEqual(cls(1, 0), cls(0, 0))
345 self.assertNotEqual(cls(1, 0), cls(1, 1))
346 for idx, fn in enumerate([lambda a, b: a < b,
347 lambda a, b: a <= b,
348 lambda a, b: a > b,
349 lambda a, b: a >= b]):
350 with self.subTest(idx=idx):
351 with self.assertRaisesRegex(TypeError,
352 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
353 fn(cls(0, 0), cls(0, 0))
354
355 @dataclass(order=True)
356 class C:
357 x: int
358 y: int
359
360 for idx, fn in enumerate([lambda a, b: a == b,
361 lambda a, b: a <= b,
362 lambda a, b: a >= b]):
363 with self.subTest(idx=idx):
364 self.assertTrue(fn(C(0, 0), C(0, 0)))
365
366 for idx, fn in enumerate([lambda a, b: a < b,
367 lambda a, b: a <= b,
368 lambda a, b: a != b]):
369 with self.subTest(idx=idx):
370 self.assertTrue(fn(C(0, 0), C(0, 1)))
371 self.assertTrue(fn(C(0, 1), C(1, 0)))
372 self.assertTrue(fn(C(1, 0), C(1, 1)))
373
374 for idx, fn in enumerate([lambda a, b: a > b,
375 lambda a, b: a >= b,
376 lambda a, b: a != b]):
377 with self.subTest(idx=idx):
378 self.assertTrue(fn(C(0, 1), C(0, 0)))
379 self.assertTrue(fn(C(1, 0), C(0, 1)))
380 self.assertTrue(fn(C(1, 1), C(1, 0)))
381
382 def test_compare_subclasses(self):
383 # Comparisons fail for subclasses, even if no fields
384 # are added.
385 @dataclass
386 class B:
387 i: int
388
389 @dataclass
390 class C(B):
391 pass
392
393 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
394 (lambda a, b: a != b, True)]):
395 with self.subTest(idx=idx):
396 self.assertEqual(fn(B(0), C(0)), expected)
397
398 for idx, fn in enumerate([lambda a, b: a < b,
399 lambda a, b: a <= b,
400 lambda a, b: a > b,
401 lambda a, b: a >= b]):
402 with self.subTest(idx=idx):
403 with self.assertRaisesRegex(TypeError,
404 "not supported between instances of 'B' and 'C'"):
405 fn(B(0), C(0))
406
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500407 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500408 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500409 for (eq, order, result ) in [
410 (False, False, 'neither'),
411 (False, True, 'exception'),
412 (True, False, 'eq_only'),
413 (True, True, 'both'),
414 ]:
415 with self.subTest(eq=eq, order=order):
416 if result == 'exception':
417 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
418 @dataclass(eq=eq, order=order)
419 class C:
420 pass
421 else:
422 @dataclass(eq=eq, order=order)
423 class C:
424 pass
425
426 if result == 'neither':
427 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500428 self.assertNotIn('__lt__', C.__dict__)
429 self.assertNotIn('__le__', C.__dict__)
430 self.assertNotIn('__gt__', C.__dict__)
431 self.assertNotIn('__ge__', C.__dict__)
432 elif result == 'both':
433 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 self.assertIn('__lt__', C.__dict__)
435 self.assertIn('__le__', C.__dict__)
436 self.assertIn('__gt__', C.__dict__)
437 self.assertIn('__ge__', C.__dict__)
438 elif result == 'eq_only':
439 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500440 self.assertNotIn('__lt__', C.__dict__)
441 self.assertNotIn('__le__', C.__dict__)
442 self.assertNotIn('__gt__', C.__dict__)
443 self.assertNotIn('__ge__', C.__dict__)
444 else:
445 assert False, f'unknown result {result!r}'
446
447 def test_field_no_default(self):
448 @dataclass
449 class C:
450 x: int = field()
451
452 self.assertEqual(C(5).x, 5)
453
454 with self.assertRaisesRegex(TypeError,
455 r"__init__\(\) missing 1 required "
456 "positional argument: 'x'"):
457 C()
458
459 def test_field_default(self):
460 default = object()
461 @dataclass
462 class C:
463 x: object = field(default=default)
464
465 self.assertIs(C.x, default)
466 c = C(10)
467 self.assertEqual(c.x, 10)
468
469 # If we delete the instance attribute, we should then see the
470 # class attribute.
471 del c.x
472 self.assertIs(c.x, default)
473
474 self.assertIs(C().x, default)
475
476 def test_not_in_repr(self):
477 @dataclass
478 class C:
479 x: int = field(repr=False)
480 with self.assertRaises(TypeError):
481 C()
482 c = C(10)
483 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
484
485 @dataclass
486 class C:
487 x: int = field(repr=False)
488 y: int
489 c = C(10, 20)
490 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
491
492 def test_not_in_compare(self):
493 @dataclass
494 class C:
495 x: int = 0
496 y: int = field(compare=False, default=4)
497
498 self.assertEqual(C(), C(0, 20))
499 self.assertEqual(C(1, 10), C(1, 20))
500 self.assertNotEqual(C(3), C(4, 10))
501 self.assertNotEqual(C(3, 10), C(4, 10))
502
503 def test_hash_field_rules(self):
504 # Test all 6 cases of:
505 # hash=True/False/None
506 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500507 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500508 (True, False, 'field' ),
509 (True, True, 'field' ),
510 (False, False, 'absent'),
511 (False, True, 'absent'),
512 (None, False, 'absent'),
513 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500514 ]:
515 with self.subTest(hash=hash_, compare=compare):
516 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500517 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500518 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500519
520 if result == 'field':
521 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500522 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500523 elif result == 'absent':
524 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500525 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500526 else:
527 assert False, f'unknown result {result!r}'
528
529 def test_init_false_no_default(self):
530 # If init=False and no default value, then the field won't be
531 # present in the instance.
532 @dataclass
533 class C:
534 x: int = field(init=False)
535
536 self.assertNotIn('x', C().__dict__)
537
538 @dataclass
539 class C:
540 x: int
541 y: int = 0
542 z: int = field(init=False)
543 t: int = 10
544
545 self.assertNotIn('z', C(0).__dict__)
546 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
547
548 def test_class_marker(self):
549 @dataclass
550 class C:
551 x: int
552 y: str = field(init=False, default=None)
553 z: str = field(repr=False)
554
555 the_fields = fields(C)
556 # the_fields is a tuple of 3 items, each value
557 # is in __annotations__.
558 self.assertIsInstance(the_fields, tuple)
559 for f in the_fields:
560 self.assertIs(type(f), Field)
561 self.assertIn(f.name, C.__annotations__)
562
563 self.assertEqual(len(the_fields), 3)
564
565 self.assertEqual(the_fields[0].name, 'x')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300566 self.assertEqual(the_fields[0].type, 'int')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500567 self.assertFalse(hasattr(C, 'x'))
568 self.assertTrue (the_fields[0].init)
569 self.assertTrue (the_fields[0].repr)
570 self.assertEqual(the_fields[1].name, 'y')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300571 self.assertEqual(the_fields[1].type, 'str')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500572 self.assertIsNone(getattr(C, 'y'))
573 self.assertFalse(the_fields[1].init)
574 self.assertTrue (the_fields[1].repr)
575 self.assertEqual(the_fields[2].name, 'z')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300576 self.assertEqual(the_fields[2].type, 'str')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500577 self.assertFalse(hasattr(C, 'z'))
578 self.assertTrue (the_fields[2].init)
579 self.assertFalse(the_fields[2].repr)
580
581 def test_field_order(self):
582 @dataclass
583 class B:
584 a: str = 'B:a'
585 b: str = 'B:b'
586 c: str = 'B:c'
587
588 @dataclass
589 class C(B):
590 b: str = 'C:b'
591
592 self.assertEqual([(f.name, f.default) for f in fields(C)],
593 [('a', 'B:a'),
594 ('b', 'C:b'),
595 ('c', 'B:c')])
596
597 @dataclass
598 class D(B):
599 c: str = 'D:c'
600
601 self.assertEqual([(f.name, f.default) for f in fields(D)],
602 [('a', 'B:a'),
603 ('b', 'B:b'),
604 ('c', 'D:c')])
605
606 @dataclass
607 class E(D):
608 a: str = 'E:a'
609 d: str = 'E:d'
610
611 self.assertEqual([(f.name, f.default) for f in fields(E)],
612 [('a', 'E:a'),
613 ('b', 'B:b'),
614 ('c', 'D:c'),
615 ('d', 'E:d')])
616
617 def test_class_attrs(self):
618 # We only have a class attribute if a default value is
619 # specified, either directly or via a field with a default.
620 default = object()
621 @dataclass
622 class C:
623 x: int
624 y: int = field(repr=False)
625 z: object = default
626 t: int = field(default=100)
627
628 self.assertFalse(hasattr(C, 'x'))
629 self.assertFalse(hasattr(C, 'y'))
630 self.assertIs (C.z, default)
631 self.assertEqual(C.t, 100)
632
633 def test_disallowed_mutable_defaults(self):
634 # For the known types, don't allow mutable default values.
635 for typ, empty, non_empty in [(list, [], [1]),
636 (dict, {}, {0:1}),
637 (set, set(), set([1])),
638 ]:
639 with self.subTest(typ=typ):
640 # Can't use a zero-length value.
641 with self.assertRaisesRegex(ValueError,
642 f'mutable default {typ} for field '
643 'x is not allowed'):
644 @dataclass
645 class Point:
646 x: typ = empty
647
648
649 # Nor a non-zero-length value
650 with self.assertRaisesRegex(ValueError,
651 f'mutable default {typ} for field '
652 'y is not allowed'):
653 @dataclass
654 class Point:
655 y: typ = non_empty
656
657 # Check subtypes also fail.
658 class Subclass(typ): pass
659
660 with self.assertRaisesRegex(ValueError,
661 f"mutable default .*Subclass'>"
662 ' for field z is not allowed'
663 ):
664 @dataclass
665 class Point:
666 z: typ = Subclass()
667
668 # Because this is a ClassVar, it can be mutable.
669 @dataclass
670 class C:
671 z: ClassVar[typ] = typ()
672
673 # Because this is a ClassVar, it can be mutable.
674 @dataclass
675 class C:
676 x: ClassVar[typ] = Subclass()
677
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500678 def test_deliberately_mutable_defaults(self):
679 # If a mutable default isn't in the known list of
680 # (list, dict, set), then it's okay.
681 class Mutable:
682 def __init__(self):
683 self.l = []
684
685 @dataclass
686 class C:
687 x: Mutable
688
689 # These 2 instances will share this value of x.
690 lst = Mutable()
691 o1 = C(lst)
692 o2 = C(lst)
693 self.assertEqual(o1, o2)
694 o1.x.l.extend([1, 2])
695 self.assertEqual(o1, o2)
696 self.assertEqual(o1.x.l, [1, 2])
697 self.assertIs(o1.x, o2.x)
698
699 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400700 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500701 @dataclass()
702 class C:
703 x: int
704
705 self.assertEqual(C(42).x, 42)
706
707 def test_not_tuple(self):
708 # Make sure we can't be compared to a tuple.
709 @dataclass
710 class Point:
711 x: int
712 y: int
713 self.assertNotEqual(Point(1, 2), (1, 2))
714
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400715 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500716 @dataclass
717 class C:
718 x: int
719 y: int
720 self.assertNotEqual(Point(1, 3), C(1, 3))
721
Windson yangbe372d72019-04-23 02:45:34 +0800722 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # Test that some of the problems with namedtuple don't happen
724 # here.
725 @dataclass
726 class Point3D:
727 x: int
728 y: int
729 z: int
730
731 @dataclass
732 class Date:
733 year: int
734 month: int
735 day: int
736
737 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
738 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
739
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400740 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200741 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500742 x, y, z = Point3D(4, 5, 6)
743
Eric V. Smith7c99e932018-01-28 19:18:55 -0500744 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500745 # equal.
746 @dataclass
747 class Point3Dv1:
748 x: int = 0
749 y: int = 0
750 z: int = 0
751 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
752
753 def test_function_annotations(self):
754 # Some dummy class and instance to use as a default.
755 class F:
756 pass
757 f = F()
758
759 def validate_class(cls):
760 # First, check __annotations__, even though they're not
761 # function annotations.
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300762 self.assertEqual(cls.__annotations__['i'], 'int')
763 self.assertEqual(cls.__annotations__['j'], 'str')
764 self.assertEqual(cls.__annotations__['k'], 'F')
765 self.assertEqual(cls.__annotations__['l'], 'float')
766 self.assertEqual(cls.__annotations__['z'], 'complex')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500767
768 # Verify __init__.
769
770 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400771 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500772 self.assertIs(signature.return_annotation, None)
773
774 # Check each parameter.
775 params = iter(signature.parameters.values())
776 param = next(params)
777 # This is testing an internal name, and probably shouldn't be tested.
778 self.assertEqual(param.name, 'self')
779 param = next(params)
780 self.assertEqual(param.name, 'i')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300781 self.assertIs (param.annotation, 'int')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500782 self.assertEqual(param.default, inspect.Parameter.empty)
783 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
784 param = next(params)
785 self.assertEqual(param.name, 'j')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300786 self.assertIs (param.annotation, 'str')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500787 self.assertEqual(param.default, inspect.Parameter.empty)
788 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
789 param = next(params)
790 self.assertEqual(param.name, 'k')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300791 self.assertIs (param.annotation, 'F')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400792 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500793 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
794 param = next(params)
795 self.assertEqual(param.name, 'l')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300796 self.assertIs (param.annotation, 'float')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400797 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500798 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
799 self.assertRaises(StopIteration, next, params)
800
801
802 @dataclass
803 class C:
804 i: int
805 j: str
806 k: F = f
807 l: float=field(default=None)
808 z: complex=field(default=3+4j, init=False)
809
810 validate_class(C)
811
812 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500813 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500814 class C:
815 i: int
816 j: str
817 k: F = f
818 l: float=field(default=None)
819 z: complex=field(default=3+4j, init=False)
820
821 validate_class(C)
822
Eric V. Smith03220fd2017-12-29 13:59:58 -0500823 def test_missing_default(self):
824 # Test that MISSING works the same as a default not being
825 # specified.
826 @dataclass
827 class C:
828 x: int=field(default=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_default_factory(self):
845 # Test that MISSING works the same as a default factory not
846 # being specified (which is really the same as a default not
847 # being specified, too).
848 @dataclass
849 class C:
850 x: int=field(default_factory=MISSING)
851 with self.assertRaisesRegex(TypeError,
852 r'__init__\(\) missing 1 required '
853 'positional argument'):
854 C()
855 self.assertNotIn('x', C.__dict__)
856
857 @dataclass
858 class D:
859 x: int=field(default=MISSING, default_factory=MISSING)
860 with self.assertRaisesRegex(TypeError,
861 r'__init__\(\) missing 1 required '
862 'positional argument'):
863 D()
864 self.assertNotIn('x', D.__dict__)
865
866 def test_missing_repr(self):
867 self.assertIn('MISSING_TYPE object', repr(MISSING))
868
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500869 def test_dont_include_other_annotations(self):
870 @dataclass
871 class C:
872 i: int
873 def foo(self) -> int:
874 return 4
875 @property
876 def bar(self) -> int:
877 return 5
878 self.assertEqual(list(C.__annotations__), ['i'])
879 self.assertEqual(C(10).foo(), 4)
880 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400881 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500882
883 def test_post_init(self):
884 # Just make sure it gets called
885 @dataclass
886 class C:
887 def __post_init__(self):
888 raise CustomError()
889 with self.assertRaises(CustomError):
890 C()
891
892 @dataclass
893 class C:
894 i: int = 10
895 def __post_init__(self):
896 if self.i == 10:
897 raise CustomError()
898 with self.assertRaises(CustomError):
899 C()
900 # post-init gets called, but doesn't raise. This is just
901 # checking that self is used correctly.
902 C(5)
903
904 # If there's not an __init__, then post-init won't get called.
905 @dataclass(init=False)
906 class C:
907 def __post_init__(self):
908 raise CustomError()
909 # Creating the class won't raise
910 C()
911
912 @dataclass
913 class C:
914 x: int = 0
915 def __post_init__(self):
916 self.x *= 2
917 self.assertEqual(C().x, 0)
918 self.assertEqual(C(2).x, 4)
919
Mike53f7a7c2017-12-14 14:04:53 +0300920 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500921 # attributes.
922 @dataclass(frozen=True)
923 class C:
924 x: int = 0
925 def __post_init__(self):
926 self.x *= 2
927 with self.assertRaises(FrozenInstanceError):
928 C()
929
930 def test_post_init_super(self):
931 # Make sure super() post-init isn't called by default.
932 class B:
933 def __post_init__(self):
934 raise CustomError()
935
936 @dataclass
937 class C(B):
938 def __post_init__(self):
939 self.x = 5
940
941 self.assertEqual(C().x, 5)
942
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400943 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500944 @dataclass
945 class C(B):
946 def __post_init__(self):
947 super().__post_init__()
948
949 with self.assertRaises(CustomError):
950 C()
951
952 # Make sure post-init is called, even if not defined in our
953 # class.
954 @dataclass
955 class C(B):
956 pass
957
958 with self.assertRaises(CustomError):
959 C()
960
961 def test_post_init_staticmethod(self):
962 flag = False
963 @dataclass
964 class C:
965 x: int
966 y: int
967 @staticmethod
968 def __post_init__():
969 nonlocal flag
970 flag = True
971
972 self.assertFalse(flag)
973 c = C(3, 4)
974 self.assertEqual((c.x, c.y), (3, 4))
975 self.assertTrue(flag)
976
977 def test_post_init_classmethod(self):
978 @dataclass
979 class C:
980 flag = False
981 x: int
982 y: int
983 @classmethod
984 def __post_init__(cls):
985 cls.flag = True
986
987 self.assertFalse(C.flag)
988 c = C(3, 4)
989 self.assertEqual((c.x, c.y), (3, 4))
990 self.assertTrue(C.flag)
991
992 def test_class_var(self):
993 # Make sure ClassVars are ignored in __init__, __repr__, etc.
994 @dataclass
995 class C:
996 x: int
997 y: int = 10
998 z: ClassVar[int] = 1000
999 w: ClassVar[int] = 2000
1000 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001001 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001002
1003 c = C(5)
1004 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001005 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001006 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001007 self.assertEqual(c.z, 1000)
1008 self.assertEqual(c.w, 2000)
1009 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001010 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001011 C.z += 1
1012 self.assertEqual(c.z, 1001)
1013 c = C(20)
1014 self.assertEqual((c.x, c.y), (20, 10))
1015 self.assertEqual(c.z, 1001)
1016 self.assertEqual(c.w, 2000)
1017 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001018 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001019
1020 def test_class_var_no_default(self):
1021 # If a ClassVar has no default value, it should not be set on the class.
1022 @dataclass
1023 class C:
1024 x: ClassVar[int]
1025
1026 self.assertNotIn('x', C.__dict__)
1027
1028 def test_class_var_default_factory(self):
1029 # It makes no sense for a ClassVar to have a default factory. When
1030 # would it be called? Call it yourself, since it's class-wide.
1031 with self.assertRaisesRegex(TypeError,
1032 'cannot have a default factory'):
1033 @dataclass
1034 class C:
1035 x: ClassVar[int] = field(default_factory=int)
1036
1037 self.assertNotIn('x', C.__dict__)
1038
1039 def test_class_var_with_default(self):
1040 # If a ClassVar has a default value, it should be set on the class.
1041 @dataclass
1042 class C:
1043 x: ClassVar[int] = 10
1044 self.assertEqual(C.x, 10)
1045
1046 @dataclass
1047 class C:
1048 x: ClassVar[int] = field(default=10)
1049 self.assertEqual(C.x, 10)
1050
1051 def test_class_var_frozen(self):
1052 # Make sure ClassVars work even if we're frozen.
1053 @dataclass(frozen=True)
1054 class C:
1055 x: int
1056 y: int = 10
1057 z: ClassVar[int] = 1000
1058 w: ClassVar[int] = 2000
1059 t: ClassVar[int] = 3000
1060
1061 c = C(5)
1062 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1063 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1064 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1065 self.assertEqual(c.z, 1000)
1066 self.assertEqual(c.w, 2000)
1067 self.assertEqual(c.t, 3000)
1068 # We can still modify the ClassVar, it's only instances that are
1069 # frozen.
1070 C.z += 1
1071 self.assertEqual(c.z, 1001)
1072 c = C(20)
1073 self.assertEqual((c.x, c.y), (20, 10))
1074 self.assertEqual(c.z, 1001)
1075 self.assertEqual(c.w, 2000)
1076 self.assertEqual(c.t, 3000)
1077
1078 def test_init_var_no_default(self):
1079 # If an InitVar has no default value, it should not be set on the class.
1080 @dataclass
1081 class C:
1082 x: InitVar[int]
1083
1084 self.assertNotIn('x', C.__dict__)
1085
1086 def test_init_var_default_factory(self):
1087 # It makes no sense for an InitVar to have a default factory. When
1088 # would it be called? Call it yourself, since it's class-wide.
1089 with self.assertRaisesRegex(TypeError,
1090 'cannot have a default factory'):
1091 @dataclass
1092 class C:
1093 x: InitVar[int] = field(default_factory=int)
1094
1095 self.assertNotIn('x', C.__dict__)
1096
1097 def test_init_var_with_default(self):
1098 # If an InitVar has a default value, it should be set on the class.
1099 @dataclass
1100 class C:
1101 x: InitVar[int] = 10
1102 self.assertEqual(C.x, 10)
1103
1104 @dataclass
1105 class C:
1106 x: InitVar[int] = field(default=10)
1107 self.assertEqual(C.x, 10)
1108
1109 def test_init_var(self):
1110 @dataclass
1111 class C:
1112 x: int = None
1113 init_param: InitVar[int] = None
1114
1115 def __post_init__(self, init_param):
1116 if self.x is None:
1117 self.x = init_param*2
1118
1119 c = C(init_param=10)
1120 self.assertEqual(c.x, 20)
1121
Augusto Hack01ee12b2019-06-02 23:14:48 -03001122 def test_init_var_preserve_type(self):
1123 self.assertEqual(InitVar[int].type, int)
1124
1125 # Make sure the repr is correct.
1126 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001127 self.assertEqual(repr(InitVar[List[int]]),
1128 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001129
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001130 def test_init_var_inheritance(self):
1131 # Note that this deliberately tests that a dataclass need not
1132 # have a __post_init__ function if it has an InitVar field.
1133 # It could just be used in a derived class, as shown here.
1134 @dataclass
1135 class Base:
1136 x: int
1137 init_base: InitVar[int]
1138
1139 # We can instantiate by passing the InitVar, even though
1140 # it's not used.
1141 b = Base(0, 10)
1142 self.assertEqual(vars(b), {'x': 0})
1143
1144 @dataclass
1145 class C(Base):
1146 y: int
1147 init_derived: InitVar[int]
1148
1149 def __post_init__(self, init_base, init_derived):
1150 self.x = self.x + init_base
1151 self.y = self.y + init_derived
1152
1153 c = C(10, 11, 50, 51)
1154 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1155
1156 def test_default_factory(self):
1157 # Test a factory that returns a new list.
1158 @dataclass
1159 class C:
1160 x: int
1161 y: list = field(default_factory=list)
1162
1163 c0 = C(3)
1164 c1 = C(3)
1165 self.assertEqual(c0.x, 3)
1166 self.assertEqual(c0.y, [])
1167 self.assertEqual(c0, c1)
1168 self.assertIsNot(c0.y, c1.y)
1169 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1170
1171 # Test a factory that returns a shared list.
1172 l = []
1173 @dataclass
1174 class C:
1175 x: int
1176 y: list = field(default_factory=lambda: l)
1177
1178 c0 = C(3)
1179 c1 = C(3)
1180 self.assertEqual(c0.x, 3)
1181 self.assertEqual(c0.y, [])
1182 self.assertEqual(c0, c1)
1183 self.assertIs(c0.y, c1.y)
1184 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1185
1186 # Test various other field flags.
1187 # repr
1188 @dataclass
1189 class C:
1190 x: list = field(default_factory=list, repr=False)
1191 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1192 self.assertEqual(C().x, [])
1193
1194 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001195 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001196 class C:
1197 x: list = field(default_factory=list, hash=False)
1198 self.assertEqual(astuple(C()), ([],))
1199 self.assertEqual(hash(C()), hash(()))
1200
1201 # init (see also test_default_factory_with_no_init)
1202 @dataclass
1203 class C:
1204 x: list = field(default_factory=list, init=False)
1205 self.assertEqual(astuple(C()), ([],))
1206
1207 # compare
1208 @dataclass
1209 class C:
1210 x: list = field(default_factory=list, compare=False)
1211 self.assertEqual(C(), C([1]))
1212
1213 def test_default_factory_with_no_init(self):
1214 # We need a factory with a side effect.
1215 factory = Mock()
1216
1217 @dataclass
1218 class C:
1219 x: list = field(default_factory=factory, init=False)
1220
1221 # Make sure the default factory is called for each new instance.
1222 C().x
1223 self.assertEqual(factory.call_count, 1)
1224 C().x
1225 self.assertEqual(factory.call_count, 2)
1226
1227 def test_default_factory_not_called_if_value_given(self):
1228 # We need a factory that we can test if it's been called.
1229 factory = Mock()
1230
1231 @dataclass
1232 class C:
1233 x: int = field(default_factory=factory)
1234
1235 # Make sure that if a field has a default factory function,
1236 # it's not called if a value is specified.
1237 C().x
1238 self.assertEqual(factory.call_count, 1)
1239 self.assertEqual(C(10).x, 10)
1240 self.assertEqual(factory.call_count, 1)
1241 C().x
1242 self.assertEqual(factory.call_count, 2)
1243
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001244 def test_default_factory_derived(self):
1245 # See bpo-32896.
1246 @dataclass
1247 class Foo:
1248 x: dict = field(default_factory=dict)
1249
1250 @dataclass
1251 class Bar(Foo):
1252 y: int = 1
1253
1254 self.assertEqual(Foo().x, {})
1255 self.assertEqual(Bar().x, {})
1256 self.assertEqual(Bar().y, 1)
1257
1258 @dataclass
1259 class Baz(Foo):
1260 pass
1261 self.assertEqual(Baz().x, {})
1262
1263 def test_intermediate_non_dataclass(self):
1264 # Test that an intermediate class that defines
1265 # annotations does not define fields.
1266
1267 @dataclass
1268 class A:
1269 x: int
1270
1271 class B(A):
1272 y: int
1273
1274 @dataclass
1275 class C(B):
1276 z: int
1277
1278 c = C(1, 3)
1279 self.assertEqual((c.x, c.z), (1, 3))
1280
1281 # .y was not initialized.
1282 with self.assertRaisesRegex(AttributeError,
1283 'object has no attribute'):
1284 c.y
1285
1286 # And if we again derive a non-dataclass, no fields are added.
1287 class D(C):
1288 t: int
1289 d = D(4, 5)
1290 self.assertEqual((d.x, d.z), (4, 5))
1291
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001292 def test_classvar_default_factory(self):
1293 # It's an error for a ClassVar to have a factory function.
1294 with self.assertRaisesRegex(TypeError,
1295 'cannot have a default factory'):
1296 @dataclass
1297 class C:
1298 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001299
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001300 def test_is_dataclass(self):
1301 class NotDataClass:
1302 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001303
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001304 self.assertFalse(is_dataclass(0))
1305 self.assertFalse(is_dataclass(int))
1306 self.assertFalse(is_dataclass(NotDataClass))
1307 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001308
1309 @dataclass
1310 class C:
1311 x: int
1312
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001313 @dataclass
1314 class D:
1315 d: C
1316 e: int
1317
1318 c = C(10)
1319 d = D(c, 4)
1320
1321 self.assertTrue(is_dataclass(C))
1322 self.assertTrue(is_dataclass(c))
1323 self.assertFalse(is_dataclass(c.x))
1324 self.assertTrue(is_dataclass(d.d))
1325 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001326
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001327 def test_is_dataclass_when_getattr_always_returns(self):
1328 # See bpo-37868.
1329 class A:
1330 def __getattr__(self, key):
1331 return 0
1332 self.assertFalse(is_dataclass(A))
1333 a = A()
1334
1335 # Also test for an instance attribute.
1336 class B:
1337 pass
1338 b = B()
1339 b.__dataclass_fields__ = []
1340
1341 for obj in a, b:
1342 with self.subTest(obj=obj):
1343 self.assertFalse(is_dataclass(obj))
1344
1345 # Indirect tests for _is_dataclass_instance().
1346 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1347 asdict(obj)
1348 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1349 astuple(obj)
1350 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1351 replace(obj, x=0)
1352
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001353 def test_helper_fields_with_class_instance(self):
1354 # Check that we can call fields() on either a class or instance,
1355 # and get back the same thing.
1356 @dataclass
1357 class C:
1358 x: int
1359 y: float
1360
1361 self.assertEqual(fields(C), fields(C(0, 0.0)))
1362
1363 def test_helper_fields_exception(self):
1364 # Check that TypeError is raised if not passed a dataclass or
1365 # instance.
1366 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1367 fields(0)
1368
1369 class C: pass
1370 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1371 fields(C)
1372 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1373 fields(C())
1374
1375 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001376 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001377 @dataclass
1378 class C:
1379 x: int
1380 y: int
1381 c = C(1, 2)
1382
1383 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1384 self.assertEqual(asdict(c), asdict(c))
1385 self.assertIsNot(asdict(c), asdict(c))
1386 c.x = 42
1387 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1388 self.assertIs(type(asdict(c)), dict)
1389
1390 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001391 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001392 @dataclass
1393 class C:
1394 x: int
1395 y: int
1396 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1397 asdict(C)
1398 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1399 asdict(int)
1400
1401 def test_helper_asdict_copy_values(self):
1402 @dataclass
1403 class C:
1404 x: int
1405 y: List[int] = field(default_factory=list)
1406 initial = []
1407 c = C(1, initial)
1408 d = asdict(c)
1409 self.assertEqual(d['y'], initial)
1410 self.assertIsNot(d['y'], initial)
1411 c = C(1)
1412 d = asdict(c)
1413 d['y'].append(1)
1414 self.assertEqual(c.y, [])
1415
1416 def test_helper_asdict_nested(self):
1417 @dataclass
1418 class UserId:
1419 token: int
1420 group: int
1421 @dataclass
1422 class User:
1423 name: str
1424 id: UserId
1425 u = User('Joe', UserId(123, 1))
1426 d = asdict(u)
1427 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1428 self.assertIsNot(asdict(u), asdict(u))
1429 u.id.group = 2
1430 self.assertEqual(asdict(u), {'name': 'Joe',
1431 'id': {'token': 123, 'group': 2}})
1432
1433 def test_helper_asdict_builtin_containers(self):
1434 @dataclass
1435 class User:
1436 name: str
1437 id: int
1438 @dataclass
1439 class GroupList:
1440 id: int
1441 users: List[User]
1442 @dataclass
1443 class GroupTuple:
1444 id: int
1445 users: Tuple[User, ...]
1446 @dataclass
1447 class GroupDict:
1448 id: int
1449 users: Dict[str, User]
1450 a = User('Alice', 1)
1451 b = User('Bob', 2)
1452 gl = GroupList(0, [a, b])
1453 gt = GroupTuple(0, (a, b))
1454 gd = GroupDict(0, {'first': a, 'second': b})
1455 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1456 {'name': 'Bob', 'id': 2}]})
1457 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1458 {'name': 'Bob', 'id': 2})})
1459 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1460 'second': {'name': 'Bob', 'id': 2}}})
1461
Windson yangbe372d72019-04-23 02:45:34 +08001462 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001463 @dataclass
1464 class Child:
1465 d: object
1466
1467 @dataclass
1468 class Parent:
1469 child: Child
1470
1471 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1472 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1473
1474 def test_helper_asdict_factory(self):
1475 @dataclass
1476 class C:
1477 x: int
1478 y: int
1479 c = C(1, 2)
1480 d = asdict(c, dict_factory=OrderedDict)
1481 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1482 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1483 c.x = 42
1484 d = asdict(c, dict_factory=OrderedDict)
1485 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1486 self.assertIs(type(d), OrderedDict)
1487
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001488 def test_helper_asdict_namedtuple(self):
1489 T = namedtuple('T', 'a b c')
1490 @dataclass
1491 class C:
1492 x: str
1493 y: T
1494 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1495
1496 d = asdict(c)
1497 self.assertEqual(d, {'x': 'outer',
1498 'y': T(1,
1499 {'x': 'inner',
1500 'y': T(11, 12, 13)},
1501 2),
1502 }
1503 )
1504
1505 # Now with a dict_factory. OrderedDict is convenient, but
1506 # since it compares to dicts, we also need to have separate
1507 # assertIs tests.
1508 d = asdict(c, dict_factory=OrderedDict)
1509 self.assertEqual(d, {'x': 'outer',
1510 'y': T(1,
1511 {'x': 'inner',
1512 'y': T(11, 12, 13)},
1513 2),
1514 }
1515 )
1516
penguindustin96466302019-05-06 14:57:17 -04001517 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001518 self.assertIs(type(d), OrderedDict)
1519 self.assertIs(type(d['y'][1]), OrderedDict)
1520
1521 def test_helper_asdict_namedtuple_key(self):
1522 # Ensure that a field that contains a dict which has a
1523 # namedtuple as a key works with asdict().
1524
1525 @dataclass
1526 class C:
1527 f: dict
1528 T = namedtuple('T', 'a')
1529
1530 c = C({T('an a'): 0})
1531
1532 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1533
1534 def test_helper_asdict_namedtuple_derived(self):
1535 class T(namedtuple('Tbase', 'a')):
1536 def my_a(self):
1537 return self.a
1538
1539 @dataclass
1540 class C:
1541 f: T
1542
1543 t = T(6)
1544 c = C(t)
1545
1546 d = asdict(c)
1547 self.assertEqual(d, {'f': T(a=6)})
1548 # Make sure that t has been copied, not used directly.
1549 self.assertIsNot(d['f'], t)
1550 self.assertEqual(d['f'].my_a(), 6)
1551
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001552 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001553 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001554 @dataclass
1555 class C:
1556 x: int
1557 y: int = 0
1558 c = C(1)
1559
1560 self.assertEqual(astuple(c), (1, 0))
1561 self.assertEqual(astuple(c), astuple(c))
1562 self.assertIsNot(astuple(c), astuple(c))
1563 c.y = 42
1564 self.assertEqual(astuple(c), (1, 42))
1565 self.assertIs(type(astuple(c)), tuple)
1566
1567 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001568 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001569 @dataclass
1570 class C:
1571 x: int
1572 y: int
1573 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1574 astuple(C)
1575 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1576 astuple(int)
1577
1578 def test_helper_astuple_copy_values(self):
1579 @dataclass
1580 class C:
1581 x: int
1582 y: List[int] = field(default_factory=list)
1583 initial = []
1584 c = C(1, initial)
1585 t = astuple(c)
1586 self.assertEqual(t[1], initial)
1587 self.assertIsNot(t[1], initial)
1588 c = C(1)
1589 t = astuple(c)
1590 t[1].append(1)
1591 self.assertEqual(c.y, [])
1592
1593 def test_helper_astuple_nested(self):
1594 @dataclass
1595 class UserId:
1596 token: int
1597 group: int
1598 @dataclass
1599 class User:
1600 name: str
1601 id: UserId
1602 u = User('Joe', UserId(123, 1))
1603 t = astuple(u)
1604 self.assertEqual(t, ('Joe', (123, 1)))
1605 self.assertIsNot(astuple(u), astuple(u))
1606 u.id.group = 2
1607 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1608
1609 def test_helper_astuple_builtin_containers(self):
1610 @dataclass
1611 class User:
1612 name: str
1613 id: int
1614 @dataclass
1615 class GroupList:
1616 id: int
1617 users: List[User]
1618 @dataclass
1619 class GroupTuple:
1620 id: int
1621 users: Tuple[User, ...]
1622 @dataclass
1623 class GroupDict:
1624 id: int
1625 users: Dict[str, User]
1626 a = User('Alice', 1)
1627 b = User('Bob', 2)
1628 gl = GroupList(0, [a, b])
1629 gt = GroupTuple(0, (a, b))
1630 gd = GroupDict(0, {'first': a, 'second': b})
1631 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1632 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1633 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1634
Windson yangbe372d72019-04-23 02:45:34 +08001635 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001636 @dataclass
1637 class Child:
1638 d: object
1639
1640 @dataclass
1641 class Parent:
1642 child: Child
1643
1644 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1645 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1646
1647 def test_helper_astuple_factory(self):
1648 @dataclass
1649 class C:
1650 x: int
1651 y: int
1652 NT = namedtuple('NT', 'x y')
1653 def nt(lst):
1654 return NT(*lst)
1655 c = C(1, 2)
1656 t = astuple(c, tuple_factory=nt)
1657 self.assertEqual(t, NT(1, 2))
1658 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1659 c.x = 42
1660 t = astuple(c, tuple_factory=nt)
1661 self.assertEqual(t, NT(42, 2))
1662 self.assertIs(type(t), NT)
1663
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001664 def test_helper_astuple_namedtuple(self):
1665 T = namedtuple('T', 'a b c')
1666 @dataclass
1667 class C:
1668 x: str
1669 y: T
1670 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1671
1672 t = astuple(c)
1673 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1674
1675 # Now, using a tuple_factory. list is convenient here.
1676 t = astuple(c, tuple_factory=list)
1677 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1678
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001679 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001680 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001681 }
1682
1683 # Create the class.
1684 cls = type('C', (), cls_dict)
1685
1686 # Make it a dataclass.
1687 cls1 = dataclass(cls)
1688
1689 self.assertEqual(cls1, cls)
1690 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1691
1692 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001693 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001694 'y': field(default=5),
1695 }
1696
1697 # Create the class.
1698 cls = type('C', (), cls_dict)
1699
1700 # Make it a dataclass.
1701 cls1 = dataclass(cls)
1702
1703 self.assertEqual(cls1, cls)
1704 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1705
1706 def test_init_in_order(self):
1707 @dataclass
1708 class C:
1709 a: int
1710 b: int = field()
1711 c: list = field(default_factory=list, init=False)
1712 d: list = field(default_factory=list)
1713 e: int = field(default=4, init=False)
1714 f: int = 4
1715
1716 calls = []
1717 def setattr(self, name, value):
1718 calls.append((name, value))
1719
1720 C.__setattr__ = setattr
1721 c = C(0, 1)
1722 self.assertEqual(('a', 0), calls[0])
1723 self.assertEqual(('b', 1), calls[1])
1724 self.assertEqual(('c', []), calls[2])
1725 self.assertEqual(('d', []), calls[3])
1726 self.assertNotIn(('e', 4), calls)
1727 self.assertEqual(('f', 4), calls[4])
1728
1729 def test_items_in_dicts(self):
1730 @dataclass
1731 class C:
1732 a: int
1733 b: list = field(default_factory=list, init=False)
1734 c: list = field(default_factory=list)
1735 d: int = field(default=4, init=False)
1736 e: int = 0
1737
1738 c = C(0)
1739 # Class dict
1740 self.assertNotIn('a', C.__dict__)
1741 self.assertNotIn('b', C.__dict__)
1742 self.assertNotIn('c', C.__dict__)
1743 self.assertIn('d', C.__dict__)
1744 self.assertEqual(C.d, 4)
1745 self.assertIn('e', C.__dict__)
1746 self.assertEqual(C.e, 0)
1747 # Instance dict
1748 self.assertIn('a', c.__dict__)
1749 self.assertEqual(c.a, 0)
1750 self.assertIn('b', c.__dict__)
1751 self.assertEqual(c.b, [])
1752 self.assertIn('c', c.__dict__)
1753 self.assertEqual(c.c, [])
1754 self.assertNotIn('d', c.__dict__)
1755 self.assertIn('e', c.__dict__)
1756 self.assertEqual(c.e, 0)
1757
1758 def test_alternate_classmethod_constructor(self):
1759 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001760 # alternate constructor. This is mostly an example to show
1761 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001762 @dataclass
1763 class C:
1764 x: int
1765 @classmethod
1766 def from_file(cls, filename):
1767 # In a real example, create a new instance
1768 # and populate 'x' from contents of a file.
1769 value_in_file = 20
1770 return cls(value_in_file)
1771
1772 self.assertEqual(C.from_file('filename').x, 20)
1773
1774 def test_field_metadata_default(self):
1775 # Make sure the default metadata is read-only and of
1776 # zero length.
1777 @dataclass
1778 class C:
1779 i: int
1780
1781 self.assertFalse(fields(C)[0].metadata)
1782 self.assertEqual(len(fields(C)[0].metadata), 0)
1783 with self.assertRaisesRegex(TypeError,
1784 'does not support item assignment'):
1785 fields(C)[0].metadata['test'] = 3
1786
1787 def test_field_metadata_mapping(self):
1788 # Make sure only a mapping can be passed as metadata
1789 # zero length.
1790 with self.assertRaises(TypeError):
1791 @dataclass
1792 class C:
1793 i: int = field(metadata=0)
1794
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001795 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001796 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001797 @dataclass
1798 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001799 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001800 self.assertFalse(fields(C)[0].metadata)
1801 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001802 # Update should work (see bpo-35960).
1803 d['foo'] = 1
1804 self.assertEqual(len(fields(C)[0].metadata), 1)
1805 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001806 with self.assertRaisesRegex(TypeError,
1807 'does not support item assignment'):
1808 fields(C)[0].metadata['test'] = 3
1809
1810 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001811 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001812 @dataclass
1813 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001814 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001815 self.assertEqual(len(fields(C)[0].metadata), 3)
1816 self.assertEqual(fields(C)[0].metadata['test'], 10)
1817 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1818 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001819 # Update should work.
1820 d['foo'] = 1
1821 self.assertEqual(len(fields(C)[0].metadata), 4)
1822 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001823 with self.assertRaises(KeyError):
1824 # Non-existent key.
1825 fields(C)[0].metadata['baz']
1826 with self.assertRaisesRegex(TypeError,
1827 'does not support item assignment'):
1828 fields(C)[0].metadata['test'] = 3
1829
1830 def test_field_metadata_custom_mapping(self):
1831 # Try a custom mapping.
1832 class SimpleNameSpace:
1833 def __init__(self, **kw):
1834 self.__dict__.update(kw)
1835
1836 def __getitem__(self, item):
1837 if item == 'xyzzy':
1838 return 'plugh'
1839 return getattr(self, item)
1840
1841 def __len__(self):
1842 return self.__dict__.__len__()
1843
1844 @dataclass
1845 class C:
1846 i: int = field(metadata=SimpleNameSpace(a=10))
1847
1848 self.assertEqual(len(fields(C)[0].metadata), 1)
1849 self.assertEqual(fields(C)[0].metadata['a'], 10)
1850 with self.assertRaises(AttributeError):
1851 fields(C)[0].metadata['b']
1852 # Make sure we're still talking to our custom mapping.
1853 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1854
1855 def test_generic_dataclasses(self):
1856 T = TypeVar('T')
1857
1858 @dataclass
1859 class LabeledBox(Generic[T]):
1860 content: T
1861 label: str = '<unknown>'
1862
1863 box = LabeledBox(42)
1864 self.assertEqual(box.content, 42)
1865 self.assertEqual(box.label, '<unknown>')
1866
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001867 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001868 Alias = List[LabeledBox[int]]
1869
1870 def test_generic_extending(self):
1871 S = TypeVar('S')
1872 T = TypeVar('T')
1873
1874 @dataclass
1875 class Base(Generic[T, S]):
1876 x: T
1877 y: S
1878
1879 @dataclass
1880 class DataDerived(Base[int, T]):
1881 new_field: str
1882 Alias = DataDerived[str]
1883 c = Alias(0, 'test1', 'test2')
1884 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1885
1886 class NonDataDerived(Base[int, T]):
1887 def new_method(self):
1888 return self.y
1889 Alias = NonDataDerived[float]
1890 c = Alias(10, 1.0)
1891 self.assertEqual(c.new_method(), 1.0)
1892
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001893 def test_generic_dynamic(self):
1894 T = TypeVar('T')
1895
1896 @dataclass
1897 class Parent(Generic[T]):
1898 x: T
1899 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1900 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1901 self.assertIs(Child[int](1, 2).z, None)
1902 self.assertEqual(Child[int](1, 2, 3).z, 3)
1903 self.assertEqual(Child[int](1, 2, 3).other, 42)
1904 # Check that type aliases work correctly.
1905 Alias = Child[T]
1906 self.assertEqual(Alias[int](1, 2).x, 1)
1907 # Check MRO resolution.
1908 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1909
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001910 def test_dataclassses_pickleable(self):
1911 global P, Q, R
1912 @dataclass
1913 class P:
1914 x: int
1915 y: int = 0
1916 @dataclass
1917 class Q:
1918 x: int
1919 y: int = field(default=0, init=False)
1920 @dataclass
1921 class R:
1922 x: int
1923 y: List[int] = field(default_factory=list)
1924 q = Q(1)
1925 q.y = 2
1926 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1927 for sample in samples:
1928 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1929 with self.subTest(sample=sample, proto=proto):
1930 new_sample = pickle.loads(pickle.dumps(sample, proto))
1931 self.assertEqual(sample.x, new_sample.x)
1932 self.assertEqual(sample.y, new_sample.y)
1933 self.assertIsNot(sample, new_sample)
1934 new_sample.x = 42
1935 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1936 self.assertEqual(new_sample.x, another_new_sample.x)
1937 self.assertEqual(sample.y, another_new_sample.y)
1938
Batuhan Taskayac7437e22020-10-21 16:49:22 +03001939 def test_dataclasses_qualnames(self):
1940 @dataclass(order=True, unsafe_hash=True, frozen=True)
1941 class A:
1942 x: int
1943 y: int
1944
1945 self.assertEqual(A.__init__.__name__, "__init__")
1946 for function in (
1947 '__eq__',
1948 '__lt__',
1949 '__le__',
1950 '__gt__',
1951 '__ge__',
1952 '__hash__',
1953 '__init__',
1954 '__repr__',
1955 '__setattr__',
1956 '__delattr__',
1957 ):
1958 self.assertEqual(getattr(A, function).__qualname__, f"TestCase.test_dataclasses_qualnames.<locals>.A.{function}")
1959
1960 with self.assertRaisesRegex(TypeError, r"A\.__init__\(\) missing"):
1961 A()
1962
Eric V. Smithea8fc522018-01-27 19:07:40 -05001963
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() to a name that has no annotation in the class body
    # itself is an error, even when a base class annotates that name.

    def test_field_without_annotation(self):
        message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, message):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        message = "'f' is a field but has no type annotation"

        @dataclass
        class Base:
            f: int

        # The base class's annotation must not be borrowed.
        with self.assertRaisesRegex(TypeError, message):
            @dataclass
            class Derived(Base):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, except the annotated base is a plain class.
        message = "'f' is a field but has no type annotation"

        class Base:
            f: int

        with self.assertRaisesRegex(TypeError, message):
            @dataclass
            class Derived(Base):
                f = field()
1997
1998
class TestDocString(unittest.TestCase):
    def assertDocStrEqual(self, a, b):
        # The generated docstring comes from inspect.signature(), whose
        # spacing differs between 3.6 and 3.7 (bpo-32108), so the two
        # strings are compared with every space removed.
        def squash(text):
            return text.replace(' ', '')
        self.assertEqual(squash(a), squash(b))

    def test_existing_docstring_not_overridden(self):
        # A docstring written by the user is kept as is.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        expected = "C()"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        expected = "C(x:int)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        expected = "C(x:int, y:int)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        expected = "C(x:int, y:int, z:str)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        expected = "C(x:int=3)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        expected = "C(x:Optional[int]=None)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        expected = "C(x:List[int])"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory shows up as the "<factory>" placeholder.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        expected = "C(x:List[int]=<factory>)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        expected = "C(x:collections.deque)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        expected = "C(x:collections.deque=<factory>)"
        self.assertDocStrEqual(C.__doc__, expected)
2086
2087
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise. An __init__ defined in the
        # class itself is never overwritten, but adding one when only the
        # base class defines __init__ is fine.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call B.__init__.
        self.assertNotIn('z', vars(c))

        # When no __init__ is generated, the base class's __init__ runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must really be applied (note the '@'): a bare
        # dataclass(init=False) expression would leave C a plain class and
        # this test would not exercise init=False at all.
        @dataclass(init=False)
        class C:
            i: int = 0
        # No __init__ is generated, so instances see the class default.
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # init=False together with a user-defined __init__: the user's
        # __init__ is the one that runs.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2152
2153
class TestRepr(unittest.TestCase):
    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields come first; the qualified class name is used.
        self.assertEqual(repr(C(4)), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        @dataclass
        class D(C):
            x: int = 20
        # Redeclaring 'x' keeps its position but picks up the new default.
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        # Nested dataclasses report their full nested qualified name.
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no user __repr__: object's default repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_repr)

        # repr=False with a user __repr__: the user's method is kept.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A class-defined __repr__ always wins, whatever repr= says.
        for decorate in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            with self.subTest(decorate=decorate):
                @decorate
                class C:
                    x: int
                    def __repr__(self):
                        return 'x'
                self.assertEqual(repr(C(0)), 'x')
2223
2224
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: comparison falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class-defined __eq__ is used regardless of the eq= argument.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorate, magic in variants:
            with self.subTest(magic=magic):
                @decorate
                class C:
                    x: int
                    def __eq__(self, other):
                        return other == magic
                self.assertEqual(C(1), magic)
                self.assertNotEqual(C(1), 1)
2270
2271
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering must be able to derive the remaining
        # comparisons from a user-written __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the assertions below can only
                # pass if this method is the one actually being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # order=False: no comparison methods are generated at all.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ does not cause the other three to appear.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True must generate all four comparison methods, so a class
        # that already defines any one of them is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2347
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        # unsafe_hash=True hashes the tuple of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # None stays None; anything else becomes a non-bool object with
            # the same truthiness.
            if value is None:
                return None
            return (3,) if value else 0

        def check(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a dataclass with the given flags (with or without a
            # class-defined __hash__) and verify the outcome. `result` is
            # 'fn' (a generated __hash__), '' (nothing added), 'none'
            # (__hash__ set to None) or 'exception' (decorator raises).
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result == 'exception':
                    # Only a class-defined __hash__ can collide.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    return

                if with_hash:
                    @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                    class C:
                        def __hash__(self):
                            return 0
                else:
                    @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                    class C:
                        pass

                if result == 'fn':
                    # The decorator installed a generated function.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                elif result == '':
                    # Nothing was added; this is only observable when the
                    # class did not define __hash__ itself.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)
                elif result == 'none':
                    # __hash__ was explicitly set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                else:
                    assert False, f'unknown result {result!r}'

        # One row per unsafe_hash / eq / frozen combination, giving the
        # expected result without and with a class-defined __hash__.
        rows = [
            # unsafe_hash  eq     frozen  no __hash__  with __hash__
            (False,        False, False,  '',          ''),
            (False,        False, True,   '',          ''),
            (False,        True,  False,  'none',      ''),
            (False,        True,  True,   'fn',        ''),
            (True,         False, False,  'fn',        'exception'),
            (True,         False, True,   'fn',        'exception'),
            (True,         True,  False,  'fn',        'exception'),
            (True,         True,  True,   'fn',        'exception'),
        ]
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate(rows, 1):
            outcomes = ((False, res_no_defined_hash), (True, res_defined_hash))
            for with_hash, result in outcomes:
                check(case, unsafe_hash, eq, frozen, with_hash, result)

            # Repeat with non-bool truth values, to make sure the
            # data-driven table in the decorator handles them as well.
            for with_hash, result in outcomes:
                check(case, non_bool(unsafe_hash), non_bool(eq),
                      non_bool(frozen), with_hash, result)

    def test_eq_only(self):
        # Defining __eq__ in a class body makes Python set __hash__ to
        # None. That is normal Python behaviour, unrelated to dataclasses,
        # and the decorator must not interfere with it (see bpo=32546).
        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # With unsafe_hash=True the class becomes hashable as well.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # Even with eq=True, the class's own __eq__ is the one used.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less class hashes like the empty tuple.
        for decorate in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorate
            class C:
                pass
            self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A single-field class hashes like a 1-tuple of its value.
        for decorate in (dataclass(frozen=True), dataclass(unsafe_hash=True)):
            @decorate
            class C:
                x: int
            self.assertEqual(hash(C(4)), hash((4,)))
            self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Exercise the decorator while leaving frozen= and/or eq= out
        # entirely. If a @dataclass parameter is renamed, or the default
        # hashing rules change, this pins the default hashability.

        class Base:
            def __hash__(self):
                return 301

        # None in the table means "do not pass this argument at all".
        for frozen, eq, base, expected in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
        ]:
            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                options = {}
                if frozen is not None:
                    options['frozen'] = frozen
                if eq is not None:
                    options['eq'] = eq
                # With no options at all, use the bare decorator form.
                decorate = dataclass(**options) if options else dataclass

                @decorate
                class C(base):
                    i: int

                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)
                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)
                elif expected == 'object':
                    # object.__hash__ is identity-based, so the hash value
                    # itself proves little. Check instead that the
                    # inherited function is object's own.
                    self.assertIs(C.__hash__, object.__hash__)
                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))
                else:
                    assert False, f'unknown value for expected={expected!r}'
2566
Eric V. Smithea8fc522018-01-27 19:07:40 -05002567
class TestFrozen(unittest.TestCase):
    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        # Assignment is rejected and the value is left unchanged.
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        # A frozen dataclass derived from a frozen dataclass freezes both
        # the inherited and the newly added fields.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        for name, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, name, new_value)
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    # The inheritance tests below run twice: once with an ordinary
    # (non-dataclass) class in between, and once deriving directly.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class Parent(C):
                        pass
                else:
                    Parent = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(Parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class Parent(C):
                        pass
                else:
                    Parent = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(Parent):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from ordinary classes.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class Parent(C):
                        pass
                else:
                    Parent = C

                @dataclass(frozen=True)
                class D(Parent):
                    i: int

                instance = D(10)
                with self.assertRaises(FrozenInstanceError):
                    instance.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953. A plain subclass of a frozen dataclass may add new
        # attributes, but the dataclass's own fields stay frozen.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, name, 5)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__, so a
        # class that defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without freezing, a user __setattr__ is kept and is what the
        # generated __init__ ends up calling.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing succeeds when the field value is itself hashable...
        hash(C(3))

        # ...and fails when it is not.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2716
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002717
class TestSlots(unittest.TestCase):
    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: a name listed in __slots__ was once assumed to
        # have a default (the slot's types.MemberDescriptorType object),
        # which would have made 'x' optional in __init__.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # The slot behaves like an ordinary, writable attribute.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # Names that are not listed in __slots__ cannot be assigned.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100: a dataclass that derives from a slotted dataclass
        # and redeclares 'x' plus adds 'y' must still work.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # Derived declares no __slots__ of its own, so new attributes can
        # be added to its instances.
        d.z = 10
2759
class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that records the (suffixed) name it was bound to.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Ordinary descriptor behaviour: no dataclass code is involved in
        # initializing the descriptor here.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # A default combined with init=False is the only case where this
        # really matters; with init=True the generated __init__ would
        # overwrite the descriptor anyway.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should also work on objects that are
        # not descriptors. D below defines only __set_name__ (no __get__).
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175: a __set_name__ attached to the instance rather
        # than to its type must not be called.
        class D:
            pass

        marker = D()
        marker.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=marker, init=False)

        self.assertEqual(marker.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175: a __set_name__ defined on the type is called once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002830
Eric V. Smith7389fd92018-03-19 21:07:51 -04002831
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002832class TestStringAnnotations(unittest.TestCase):
2833 def test_classvar(self):
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002834 # These tests assume that both "import typing" and "from
2835 # typing import *" have been run in this file.
2836 for typestr in ('ClassVar[int]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002837 'ClassVar [int]',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002838 ' ClassVar [int]',
2839 'ClassVar',
2840 ' ClassVar ',
2841 'typing.ClassVar[int]',
2842 'typing.ClassVar[str]',
2843 ' typing.ClassVar[str]',
2844 'typing .ClassVar[str]',
2845 'typing. ClassVar[str]',
2846 'typing.ClassVar [str]',
2847 'typing.ClassVar [ str]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002848 # Double stringified
2849 '"typing.ClassVar[int]"',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002850 # Not syntactically valid, but these will
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002851 # be treated as ClassVars.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002852 'typing.ClassVar.[int]',
2853 'typing.ClassVar+',
2854 ):
2855 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002856 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002857 # x is a ClassVar, so C() takes no args.
2858 C()
2859
2860 # And it won't appear in the class's dict because it doesn't
2861 # have a default.
2862 self.assertNotIn('x', C.__dict__)
2863
2864 def test_isnt_classvar(self):
2865 for typestr in ('CV',
2866 't.ClassVar',
2867 't.ClassVar[int]',
2868 'typing..ClassVar[int]',
2869 'Classvar',
2870 'Classvar[int]',
2871 'typing.ClassVarx[int]',
2872 'typong.ClassVar[int]',
2873 'dataclasses.ClassVar[int]',
2874 'typingxClassVar[str]',
2875 ):
2876 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002877 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002878
2879 # x is not a ClassVar, so C() takes one arg.
2880 self.assertEqual(C(10).x, 10)
2881
2882 def test_initvar(self):
2883 # These tests assume that both "import dataclasses" and "from
2884 # dataclasses import *" have been run in this file.
2885 for typestr in ('InitVar[int]',
2886 'InitVar [int]'
2887 ' InitVar [int]',
2888 'InitVar',
2889 ' InitVar ',
2890 'dataclasses.InitVar[int]',
2891 'dataclasses.InitVar[str]',
2892 ' dataclasses.InitVar[str]',
2893 'dataclasses .InitVar[str]',
2894 'dataclasses. InitVar[str]',
2895 'dataclasses.InitVar [str]',
2896 'dataclasses.InitVar [ str]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002897 # Double stringified
2898 '"dataclasses.InitVar[int]"',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002899 # Not syntactically valid, but these will
2900 # be treated as InitVars.
2901 'dataclasses.InitVar.[int]',
2902 'dataclasses.InitVar+',
2903 ):
2904 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002905 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
2906
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002907
2908 # x is an InitVar, so doesn't create a member.
2909 with self.assertRaisesRegex(AttributeError,
2910 "object has no attribute 'x'"):
2911 C(1).x
2912
2913 def test_isnt_initvar(self):
2914 for typestr in ('IV',
2915 'dc.InitVar',
2916 'xdataclasses.xInitVar',
2917 'typing.xInitVar[int]',
2918 ):
2919 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002920 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002921
2922 # x is not an InitVar, so there will be a member x.
2923 self.assertEqual(C(10).x, 10)
2924
2925 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002926 from test import dataclass_module_1
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002927 from test import dataclass_module_2
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002928
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002929 for m in (dataclass_module_1,
2930 dataclass_module_2):
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002931 with self.subTest(m=m):
2932 # There's a difference in how the ClassVars are
2933 # interpreted when using string annotations or
2934 # not. See the imported modules for details.
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002935 c = m.CV(10)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002936 self.assertEqual(c.cv0, 20)
2937
2938
2939 # There's a difference in how the InitVars are
2940 # interpreted when using string annotations or
2941 # not. See the imported modules for details.
2942 c = m.IV(0, 1, 2, 3, 4)
2943
2944 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2945 with self.subTest(field_name=field_name):
2946 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2947 # Since field_name is an InitVar, it's
2948 # not an instance field.
2949 getattr(c, field_name)
2950
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002951 # iv4 is interpreted as a normal field.
2952 self.assertIn('not_iv4', c.__dict__)
2953 self.assertEqual(c.not_iv4, 4)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002954
Yury Selivanovd219cc42019-12-09 09:54:20 -05002955 def test_text_annotations(self):
2956 from test import dataclass_textanno
2957
2958 self.assertEqual(
2959 get_type_hints(dataclass_textanno.Bar),
2960 {'foo': dataclass_textanno.Foo})
2961 self.assertEqual(
2962 get_type_hints(dataclass_textanno.Bar.__init__),
2963 {'foo': dataclass_textanno.Foo,
2964 'return': type(None)})
2965
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002966
class TestMakeDataclass(unittest.TestCase):
    # Tests for make_dataclass(): building a dataclass from a class name and
    # a list of field specifications instead of a class statement.

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        # Entries in namespace become attributes of the generated class.
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        # Fields inherited from a dataclass base come first in __init__,
        # so C needs both x (from Base1) and y.
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        # An InitVar is passed to __post_init__ but is neither stored on
        # the instance nor reported by fields().
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        # ClassVar entries become class attributes, not instance fields.
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        # Fields given as bare names are annotated with the string
        # 'typing.Any'; explicitly typed fields keep their type.
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        # Tuples that are neither (name, type) nor (name, type, spec).
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x: x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    def test_duplicate_field_names(self):
        # The loop variable is not called "field" so it doesn't shadow
        # dataclasses.field, which other tests in this class use.
        for field_name in ['a', 'ab']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [field_name, 'a', field_name])

    def test_keyword_field_names(self):
        for field_name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name, 'a'])

    def test_non_identifier_field_names(self):
        for field_name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=field_name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3130
class TestReplace(unittest.TestCase):
    # Tests for dataclasses.replace().  The test_recursive_repr* tests
    # below exercise __repr__ rather than replace(); their expected strings
    # embed this class's qualname ("TestReplace."), so they must stay here
    # unless those strings are updated too.

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # Naming any init=False field in replace() is an error, whether
        # or not an init=True field is also being replaced.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # Neither the dataclass itself nor a non-dataclass object is
        # accepted; replace() needs a dataclass instance.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                    "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # The parentheses are escaped so the literal "replace()" is matched;
        # unescaped they would form an empty regex group and match nothing.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                    r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3340
Ben Avrahamibef7d292020-10-06 20:40:50 +03003341class TestAbstract(unittest.TestCase):
3342 def test_abc_implementation(self):
3343 class Ordered(abc.ABC):
3344 @abc.abstractmethod
3345 def __lt__(self, other):
3346 pass
3347
3348 @abc.abstractmethod
3349 def __le__(self, other):
3350 pass
3351
3352 @dataclass(order=True)
3353 class Date(Ordered):
3354 year: int
3355 month: 'Month'
3356 day: 'int'
3357
3358 self.assertFalse(inspect.isabstract(Date))
3359 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3360
3361 def test_maintain_abc(self):
3362 class A(abc.ABC):
3363 @abc.abstractmethod
3364 def foo(self):
3365 pass
3366
3367 @dataclass
3368 class Date(A):
3369 year: int
3370 month: 'Month'
3371 day: 'int'
3372
3373 self.assertTrue(inspect.isabstract(Date))
3374 msg = 'class Date with abstract method foo'
3375 self.assertRaisesRegex(TypeError, msg, Date)
3376
Eric V. Smith4e812962018-05-16 11:31:29 -04003377
Brandt Bucher145bf262021-02-26 14:51:55 -08003378class TestMatchArgs(unittest.TestCase):
3379 def test_match_args(self):
3380 @dataclass
3381 class C:
3382 a: int
3383 self.assertEqual(C(42).__match_args__, ('a',))
3384
3385 def test_explicit_match_args(self):
3386 ma = []
3387 @dataclass
3388 class C:
3389 a: int
3390 __match_args__ = ma
3391 self.assertIs(C(42).__match_args__, ma)
3392
3393
if __name__ == '__main__':
    # Allow running this module directly as a script; unittest.main()
    # loads and runs the TestCase classes defined in this module.
    unittest.main()