blob: 7c1d9c568f4ef6feb047b2ec9df9982840ed0214 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050011import unittest
Batuhan Taskaya044a1042020-10-06 23:03:02 +030012from textwrap import dedent
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010014from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Yury Selivanovd219cc42019-12-09 09:54:20 -050015from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050016from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050017from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050018
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040019import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
20import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
21
class CustomError(Exception):
    # A private exception type that tests raise on purpose and then catch,
    # so it can never be confused with an error raised by dataclasses itself.
    pass
24
25class TestCase(unittest.TestCase):
26 def test_no_fields(self):
27 @dataclass
28 class C:
29 pass
30
31 o = C()
32 self.assertEqual(len(fields(C)), 0)
33
Eric V. Smith56970b82018-03-22 16:28:48 -040034 def test_no_fields_but_member_variable(self):
35 @dataclass
36 class C:
37 i = 0
38
39 o = C()
40 self.assertEqual(len(fields(C)), 0)
41
Eric V. Smithf0db54a2017-12-04 16:58:55 -050042 def test_one_field_no_default(self):
43 @dataclass
44 class C:
45 x: int
46
47 o = C(42)
48 self.assertEqual(o.x, 42)
49
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053050 def test_field_default_default_factory_error(self):
51 msg = "cannot specify both default and default_factory"
52 with self.assertRaisesRegex(ValueError, msg):
53 @dataclass
54 class C:
55 x: int = field(default=1, default_factory=int)
56
57 def test_field_repr(self):
58 int_field = field(default=1, init=True, repr=False)
59 int_field.name = "id"
60 repr_output = repr(int_field)
61 expected_output = "Field(name='id',type=None," \
62 f"default=1,default_factory={MISSING!r}," \
63 "init=True,repr=False,hash=None," \
64 "compare=True,metadata=mappingproxy({})," \
65 "_field_type=None)"
66
67 self.assertEqual(repr_output, expected_output)
68
Eric V. Smithf0db54a2017-12-04 16:58:55 -050069 def test_named_init_params(self):
70 @dataclass
71 class C:
72 x: int
73
74 o = C(x=32)
75 self.assertEqual(o.x, 32)
76
77 def test_two_fields_one_default(self):
78 @dataclass
79 class C:
80 x: int
81 y: int = 0
82
83 o = C(3)
84 self.assertEqual((o.x, o.y), (3, 0))
85
86 # Non-defaults following defaults.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class C:
92 x: int = 0
93 y: int
94
95 # A derived class adds a non-default field after a default one.
96 with self.assertRaisesRegex(TypeError,
97 "non-default argument 'y' follows "
98 "default argument"):
99 @dataclass
100 class B:
101 x: int = 0
102
103 @dataclass
104 class C(B):
105 y: int
106
107 # Override a base class field and add a default to
108 # a field which didn't use to have a default.
109 with self.assertRaisesRegex(TypeError,
110 "non-default argument 'y' follows "
111 "default argument"):
112 @dataclass
113 class B:
114 x: int
115 y: int
116
117 @dataclass
118 class C(B):
119 x: int = 0
120
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500121 def test_overwrite_hash(self):
122 # Test that declaring this class isn't an error. It should
123 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500124 @dataclass(frozen=True)
125 class C:
126 x: int
127 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500128 return 301
129 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500130
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500131 # Test that declaring this class isn't an error. It should
132 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500133 @dataclass(frozen=True)
134 class C:
135 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500136 def __eq__(self, other):
137 return False
138 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500139
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500140 # But this one should generate an exception, because with
141 # unsafe_hash=True, it's an error to have a __hash__ defined.
142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 def __hash__(self):
147 pass
148
149 # Creating this class should not generate an exception,
150 # because even though __hash__ exists before @dataclass is
151 # called, (due to __eq__ being defined), since it's None
152 # that's okay.
153 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500154 class C:
155 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500156 def __eq__(self):
157 pass
158 # The generated hash function works as we'd expect.
159 self.assertEqual(hash(C(10)), hash((10,)))
160
161 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400162 # __hash__ exists and is not None, which it would be if it
163 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500164 with self.assertRaisesRegex(TypeError,
165 'Cannot overwrite attribute __hash__'):
166 @dataclass(unsafe_hash=True)
167 class C:
168 x: int
169 def __eq__(self):
170 pass
171 def __hash__(self):
172 pass
173
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500174 def test_overwrite_fields_in_derived_class(self):
175 # Note that x from C1 replaces x in Base, but the order remains
176 # the same as defined in Base.
177 @dataclass
178 class Base:
179 x: Any = 15.0
180 y: int = 0
181
182 @dataclass
183 class C1(Base):
184 z: int = 10
185 x: int = 15
186
187 o = Base()
188 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
189
190 o = C1()
191 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
192
193 o = C1(x=5)
194 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
195
196 def test_field_named_self(self):
197 @dataclass
198 class C:
199 self: str
200 c=C('foo')
201 self.assertEqual(c.self, 'foo')
202
203 # Make sure the first parameter is not named 'self'.
204 sig = inspect.signature(C.__init__)
205 first = next(iter(sig.parameters))
206 self.assertNotEqual('self', first)
207
208 # But we do use 'self' if no field named self.
209 @dataclass
210 class C:
211 selfx: str
212
213 # Make sure the first parameter is named 'self'.
214 sig = inspect.signature(C.__init__)
215 first = next(iter(sig.parameters))
216 self.assertEqual('self', first)
217
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300218 def test_field_named_object(self):
219 @dataclass
220 class C:
221 object: str
222 c = C('foo')
223 self.assertEqual(c.object, 'foo')
224
225 def test_field_named_object_frozen(self):
226 @dataclass(frozen=True)
227 class C:
228 object: str
229 c = C('foo')
230 self.assertEqual(c.object, 'foo')
231
232 def test_field_named_like_builtin(self):
233 # Attribute names can shadow built-in names
234 # since code generation is used.
235 # Ensure that this is not happening.
236 exclusions = {'None', 'True', 'False'}
237 builtins_names = sorted(
238 b for b in builtins.__dict__.keys()
239 if not b.startswith('__') and b not in exclusions
240 )
241 attributes = [(name, str) for name in builtins_names]
242 C = make_dataclass('C', attributes)
243
244 c = C(*[name for name in builtins_names])
245
246 for name in builtins_names:
247 self.assertEqual(getattr(c, name), name)
248
249 def test_field_named_like_builtin_frozen(self):
250 # Attribute names can shadow built-in names
251 # since code generation is used.
252 # Ensure that this is not happening
253 # for frozen data classes.
254 exclusions = {'None', 'True', 'False'}
255 builtins_names = sorted(
256 b for b in builtins.__dict__.keys()
257 if not b.startswith('__') and b not in exclusions
258 )
259 attributes = [(name, str) for name in builtins_names]
260 C = make_dataclass('C', attributes, frozen=True)
261
262 c = C(*[name for name in builtins_names])
263
264 for name in builtins_names:
265 self.assertEqual(getattr(c, name), name)
266
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500267 def test_0_field_compare(self):
268 # Ensure that order=False is the default.
269 @dataclass
270 class C0:
271 pass
272
273 @dataclass(order=False)
274 class C1:
275 pass
276
277 for cls in [C0, C1]:
278 with self.subTest(cls=cls):
279 self.assertEqual(cls(), cls())
280 for idx, fn in enumerate([lambda a, b: a < b,
281 lambda a, b: a <= b,
282 lambda a, b: a > b,
283 lambda a, b: a >= b]):
284 with self.subTest(idx=idx):
285 with self.assertRaisesRegex(TypeError,
286 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
287 fn(cls(), cls())
288
289 @dataclass(order=True)
290 class C:
291 pass
292 self.assertLessEqual(C(), C())
293 self.assertGreaterEqual(C(), C())
294
295 def test_1_field_compare(self):
296 # Ensure that order=False is the default.
297 @dataclass
298 class C0:
299 x: int
300
301 @dataclass(order=False)
302 class C1:
303 x: int
304
305 for cls in [C0, C1]:
306 with self.subTest(cls=cls):
307 self.assertEqual(cls(1), cls(1))
308 self.assertNotEqual(cls(0), cls(1))
309 for idx, fn in enumerate([lambda a, b: a < b,
310 lambda a, b: a <= b,
311 lambda a, b: a > b,
312 lambda a, b: a >= b]):
313 with self.subTest(idx=idx):
314 with self.assertRaisesRegex(TypeError,
315 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
316 fn(cls(0), cls(0))
317
318 @dataclass(order=True)
319 class C:
320 x: int
321 self.assertLess(C(0), C(1))
322 self.assertLessEqual(C(0), C(1))
323 self.assertLessEqual(C(1), C(1))
324 self.assertGreater(C(1), C(0))
325 self.assertGreaterEqual(C(1), C(0))
326 self.assertGreaterEqual(C(1), C(1))
327
328 def test_simple_compare(self):
329 # Ensure that order=False is the default.
330 @dataclass
331 class C0:
332 x: int
333 y: int
334
335 @dataclass(order=False)
336 class C1:
337 x: int
338 y: int
339
340 for cls in [C0, C1]:
341 with self.subTest(cls=cls):
342 self.assertEqual(cls(0, 0), cls(0, 0))
343 self.assertEqual(cls(1, 2), cls(1, 2))
344 self.assertNotEqual(cls(1, 0), cls(0, 0))
345 self.assertNotEqual(cls(1, 0), cls(1, 1))
346 for idx, fn in enumerate([lambda a, b: a < b,
347 lambda a, b: a <= b,
348 lambda a, b: a > b,
349 lambda a, b: a >= b]):
350 with self.subTest(idx=idx):
351 with self.assertRaisesRegex(TypeError,
352 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
353 fn(cls(0, 0), cls(0, 0))
354
355 @dataclass(order=True)
356 class C:
357 x: int
358 y: int
359
360 for idx, fn in enumerate([lambda a, b: a == b,
361 lambda a, b: a <= b,
362 lambda a, b: a >= b]):
363 with self.subTest(idx=idx):
364 self.assertTrue(fn(C(0, 0), C(0, 0)))
365
366 for idx, fn in enumerate([lambda a, b: a < b,
367 lambda a, b: a <= b,
368 lambda a, b: a != b]):
369 with self.subTest(idx=idx):
370 self.assertTrue(fn(C(0, 0), C(0, 1)))
371 self.assertTrue(fn(C(0, 1), C(1, 0)))
372 self.assertTrue(fn(C(1, 0), C(1, 1)))
373
374 for idx, fn in enumerate([lambda a, b: a > b,
375 lambda a, b: a >= b,
376 lambda a, b: a != b]):
377 with self.subTest(idx=idx):
378 self.assertTrue(fn(C(0, 1), C(0, 0)))
379 self.assertTrue(fn(C(1, 0), C(0, 1)))
380 self.assertTrue(fn(C(1, 1), C(1, 0)))
381
382 def test_compare_subclasses(self):
383 # Comparisons fail for subclasses, even if no fields
384 # are added.
385 @dataclass
386 class B:
387 i: int
388
389 @dataclass
390 class C(B):
391 pass
392
393 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
394 (lambda a, b: a != b, True)]):
395 with self.subTest(idx=idx):
396 self.assertEqual(fn(B(0), C(0)), expected)
397
398 for idx, fn in enumerate([lambda a, b: a < b,
399 lambda a, b: a <= b,
400 lambda a, b: a > b,
401 lambda a, b: a >= b]):
402 with self.subTest(idx=idx):
403 with self.assertRaisesRegex(TypeError,
404 "not supported between instances of 'B' and 'C'"):
405 fn(B(0), C(0))
406
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500407 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500408 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500409 for (eq, order, result ) in [
410 (False, False, 'neither'),
411 (False, True, 'exception'),
412 (True, False, 'eq_only'),
413 (True, True, 'both'),
414 ]:
415 with self.subTest(eq=eq, order=order):
416 if result == 'exception':
417 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
418 @dataclass(eq=eq, order=order)
419 class C:
420 pass
421 else:
422 @dataclass(eq=eq, order=order)
423 class C:
424 pass
425
426 if result == 'neither':
427 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500428 self.assertNotIn('__lt__', C.__dict__)
429 self.assertNotIn('__le__', C.__dict__)
430 self.assertNotIn('__gt__', C.__dict__)
431 self.assertNotIn('__ge__', C.__dict__)
432 elif result == 'both':
433 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 self.assertIn('__lt__', C.__dict__)
435 self.assertIn('__le__', C.__dict__)
436 self.assertIn('__gt__', C.__dict__)
437 self.assertIn('__ge__', C.__dict__)
438 elif result == 'eq_only':
439 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500440 self.assertNotIn('__lt__', C.__dict__)
441 self.assertNotIn('__le__', C.__dict__)
442 self.assertNotIn('__gt__', C.__dict__)
443 self.assertNotIn('__ge__', C.__dict__)
444 else:
445 assert False, f'unknown result {result!r}'
446
447 def test_field_no_default(self):
448 @dataclass
449 class C:
450 x: int = field()
451
452 self.assertEqual(C(5).x, 5)
453
454 with self.assertRaisesRegex(TypeError,
455 r"__init__\(\) missing 1 required "
456 "positional argument: 'x'"):
457 C()
458
459 def test_field_default(self):
460 default = object()
461 @dataclass
462 class C:
463 x: object = field(default=default)
464
465 self.assertIs(C.x, default)
466 c = C(10)
467 self.assertEqual(c.x, 10)
468
469 # If we delete the instance attribute, we should then see the
470 # class attribute.
471 del c.x
472 self.assertIs(c.x, default)
473
474 self.assertIs(C().x, default)
475
476 def test_not_in_repr(self):
477 @dataclass
478 class C:
479 x: int = field(repr=False)
480 with self.assertRaises(TypeError):
481 C()
482 c = C(10)
483 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
484
485 @dataclass
486 class C:
487 x: int = field(repr=False)
488 y: int
489 c = C(10, 20)
490 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
491
492 def test_not_in_compare(self):
493 @dataclass
494 class C:
495 x: int = 0
496 y: int = field(compare=False, default=4)
497
498 self.assertEqual(C(), C(0, 20))
499 self.assertEqual(C(1, 10), C(1, 20))
500 self.assertNotEqual(C(3), C(4, 10))
501 self.assertNotEqual(C(3, 10), C(4, 10))
502
503 def test_hash_field_rules(self):
504 # Test all 6 cases of:
505 # hash=True/False/None
506 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500507 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500508 (True, False, 'field' ),
509 (True, True, 'field' ),
510 (False, False, 'absent'),
511 (False, True, 'absent'),
512 (None, False, 'absent'),
513 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500514 ]:
515 with self.subTest(hash=hash_, compare=compare):
516 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500517 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500518 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500519
520 if result == 'field':
521 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500522 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500523 elif result == 'absent':
524 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500525 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500526 else:
527 assert False, f'unknown result {result!r}'
528
529 def test_init_false_no_default(self):
530 # If init=False and no default value, then the field won't be
531 # present in the instance.
532 @dataclass
533 class C:
534 x: int = field(init=False)
535
536 self.assertNotIn('x', C().__dict__)
537
538 @dataclass
539 class C:
540 x: int
541 y: int = 0
542 z: int = field(init=False)
543 t: int = 10
544
545 self.assertNotIn('z', C(0).__dict__)
546 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
547
548 def test_class_marker(self):
549 @dataclass
550 class C:
551 x: int
552 y: str = field(init=False, default=None)
553 z: str = field(repr=False)
554
555 the_fields = fields(C)
556 # the_fields is a tuple of 3 items, each value
557 # is in __annotations__.
558 self.assertIsInstance(the_fields, tuple)
559 for f in the_fields:
560 self.assertIs(type(f), Field)
561 self.assertIn(f.name, C.__annotations__)
562
563 self.assertEqual(len(the_fields), 3)
564
565 self.assertEqual(the_fields[0].name, 'x')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300566 self.assertEqual(the_fields[0].type, 'int')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500567 self.assertFalse(hasattr(C, 'x'))
568 self.assertTrue (the_fields[0].init)
569 self.assertTrue (the_fields[0].repr)
570 self.assertEqual(the_fields[1].name, 'y')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300571 self.assertEqual(the_fields[1].type, 'str')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500572 self.assertIsNone(getattr(C, 'y'))
573 self.assertFalse(the_fields[1].init)
574 self.assertTrue (the_fields[1].repr)
575 self.assertEqual(the_fields[2].name, 'z')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300576 self.assertEqual(the_fields[2].type, 'str')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500577 self.assertFalse(hasattr(C, 'z'))
578 self.assertTrue (the_fields[2].init)
579 self.assertFalse(the_fields[2].repr)
580
581 def test_field_order(self):
582 @dataclass
583 class B:
584 a: str = 'B:a'
585 b: str = 'B:b'
586 c: str = 'B:c'
587
588 @dataclass
589 class C(B):
590 b: str = 'C:b'
591
592 self.assertEqual([(f.name, f.default) for f in fields(C)],
593 [('a', 'B:a'),
594 ('b', 'C:b'),
595 ('c', 'B:c')])
596
597 @dataclass
598 class D(B):
599 c: str = 'D:c'
600
601 self.assertEqual([(f.name, f.default) for f in fields(D)],
602 [('a', 'B:a'),
603 ('b', 'B:b'),
604 ('c', 'D:c')])
605
606 @dataclass
607 class E(D):
608 a: str = 'E:a'
609 d: str = 'E:d'
610
611 self.assertEqual([(f.name, f.default) for f in fields(E)],
612 [('a', 'E:a'),
613 ('b', 'B:b'),
614 ('c', 'D:c'),
615 ('d', 'E:d')])
616
617 def test_class_attrs(self):
618 # We only have a class attribute if a default value is
619 # specified, either directly or via a field with a default.
620 default = object()
621 @dataclass
622 class C:
623 x: int
624 y: int = field(repr=False)
625 z: object = default
626 t: int = field(default=100)
627
628 self.assertFalse(hasattr(C, 'x'))
629 self.assertFalse(hasattr(C, 'y'))
630 self.assertIs (C.z, default)
631 self.assertEqual(C.t, 100)
632
633 def test_disallowed_mutable_defaults(self):
634 # For the known types, don't allow mutable default values.
635 for typ, empty, non_empty in [(list, [], [1]),
636 (dict, {}, {0:1}),
637 (set, set(), set([1])),
638 ]:
639 with self.subTest(typ=typ):
640 # Can't use a zero-length value.
641 with self.assertRaisesRegex(ValueError,
642 f'mutable default {typ} for field '
643 'x is not allowed'):
644 @dataclass
645 class Point:
646 x: typ = empty
647
648
649 # Nor a non-zero-length value
650 with self.assertRaisesRegex(ValueError,
651 f'mutable default {typ} for field '
652 'y is not allowed'):
653 @dataclass
654 class Point:
655 y: typ = non_empty
656
657 # Check subtypes also fail.
658 class Subclass(typ): pass
659
660 with self.assertRaisesRegex(ValueError,
661 f"mutable default .*Subclass'>"
662 ' for field z is not allowed'
663 ):
664 @dataclass
665 class Point:
666 z: typ = Subclass()
667
668 # Because this is a ClassVar, it can be mutable.
669 @dataclass
670 class C:
671 z: ClassVar[typ] = typ()
672
673 # Because this is a ClassVar, it can be mutable.
674 @dataclass
675 class C:
676 x: ClassVar[typ] = Subclass()
677
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500678 def test_deliberately_mutable_defaults(self):
679 # If a mutable default isn't in the known list of
680 # (list, dict, set), then it's okay.
681 class Mutable:
682 def __init__(self):
683 self.l = []
684
685 @dataclass
686 class C:
687 x: Mutable
688
689 # These 2 instances will share this value of x.
690 lst = Mutable()
691 o1 = C(lst)
692 o2 = C(lst)
693 self.assertEqual(o1, o2)
694 o1.x.l.extend([1, 2])
695 self.assertEqual(o1, o2)
696 self.assertEqual(o1.x.l, [1, 2])
697 self.assertIs(o1.x, o2.x)
698
699 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400700 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500701 @dataclass()
702 class C:
703 x: int
704
705 self.assertEqual(C(42).x, 42)
706
707 def test_not_tuple(self):
708 # Make sure we can't be compared to a tuple.
709 @dataclass
710 class Point:
711 x: int
712 y: int
713 self.assertNotEqual(Point(1, 2), (1, 2))
714
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400715 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500716 @dataclass
717 class C:
718 x: int
719 y: int
720 self.assertNotEqual(Point(1, 3), C(1, 3))
721
Windson yangbe372d72019-04-23 02:45:34 +0800722 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # Test that some of the problems with namedtuple don't happen
724 # here.
725 @dataclass
726 class Point3D:
727 x: int
728 y: int
729 z: int
730
731 @dataclass
732 class Date:
733 year: int
734 month: int
735 day: int
736
737 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
738 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
739
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400740 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200741 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500742 x, y, z = Point3D(4, 5, 6)
743
Eric V. Smith7c99e932018-01-28 19:18:55 -0500744 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500745 # equal.
746 @dataclass
747 class Point3Dv1:
748 x: int = 0
749 y: int = 0
750 z: int = 0
751 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
752
753 def test_function_annotations(self):
754 # Some dummy class and instance to use as a default.
755 class F:
756 pass
757 f = F()
758
759 def validate_class(cls):
760 # First, check __annotations__, even though they're not
761 # function annotations.
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300762 self.assertEqual(cls.__annotations__['i'], 'int')
763 self.assertEqual(cls.__annotations__['j'], 'str')
764 self.assertEqual(cls.__annotations__['k'], 'F')
765 self.assertEqual(cls.__annotations__['l'], 'float')
766 self.assertEqual(cls.__annotations__['z'], 'complex')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500767
768 # Verify __init__.
769
770 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400771 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500772 self.assertIs(signature.return_annotation, None)
773
774 # Check each parameter.
775 params = iter(signature.parameters.values())
776 param = next(params)
777 # This is testing an internal name, and probably shouldn't be tested.
778 self.assertEqual(param.name, 'self')
779 param = next(params)
780 self.assertEqual(param.name, 'i')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300781 self.assertIs (param.annotation, 'int')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500782 self.assertEqual(param.default, inspect.Parameter.empty)
783 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
784 param = next(params)
785 self.assertEqual(param.name, 'j')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300786 self.assertIs (param.annotation, 'str')
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500787 self.assertEqual(param.default, inspect.Parameter.empty)
788 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
789 param = next(params)
790 self.assertEqual(param.name, 'k')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300791 self.assertIs (param.annotation, 'F')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400792 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500793 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
794 param = next(params)
795 self.assertEqual(param.name, 'l')
Batuhan Taskaya044a1042020-10-06 23:03:02 +0300796 self.assertIs (param.annotation, 'float')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400797 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500798 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
799 self.assertRaises(StopIteration, next, params)
800
801
802 @dataclass
803 class C:
804 i: int
805 j: str
806 k: F = f
807 l: float=field(default=None)
808 z: complex=field(default=3+4j, init=False)
809
810 validate_class(C)
811
812 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500813 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500814 class C:
815 i: int
816 j: str
817 k: F = f
818 l: float=field(default=None)
819 z: complex=field(default=3+4j, init=False)
820
821 validate_class(C)
822
Eric V. Smith03220fd2017-12-29 13:59:58 -0500823 def test_missing_default(self):
824 # Test that MISSING works the same as a default not being
825 # specified.
826 @dataclass
827 class C:
828 x: int=field(default=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_default_factory(self):
845 # Test that MISSING works the same as a default factory not
846 # being specified (which is really the same as a default not
847 # being specified, too).
848 @dataclass
849 class C:
850 x: int=field(default_factory=MISSING)
851 with self.assertRaisesRegex(TypeError,
852 r'__init__\(\) missing 1 required '
853 'positional argument'):
854 C()
855 self.assertNotIn('x', C.__dict__)
856
857 @dataclass
858 class D:
859 x: int=field(default=MISSING, default_factory=MISSING)
860 with self.assertRaisesRegex(TypeError,
861 r'__init__\(\) missing 1 required '
862 'positional argument'):
863 D()
864 self.assertNotIn('x', D.__dict__)
865
866 def test_missing_repr(self):
867 self.assertIn('MISSING_TYPE object', repr(MISSING))
868
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500869 def test_dont_include_other_annotations(self):
870 @dataclass
871 class C:
872 i: int
873 def foo(self) -> int:
874 return 4
875 @property
876 def bar(self) -> int:
877 return 5
878 self.assertEqual(list(C.__annotations__), ['i'])
879 self.assertEqual(C(10).foo(), 4)
880 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400881 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500882
883 def test_post_init(self):
884 # Just make sure it gets called
885 @dataclass
886 class C:
887 def __post_init__(self):
888 raise CustomError()
889 with self.assertRaises(CustomError):
890 C()
891
892 @dataclass
893 class C:
894 i: int = 10
895 def __post_init__(self):
896 if self.i == 10:
897 raise CustomError()
898 with self.assertRaises(CustomError):
899 C()
900 # post-init gets called, but doesn't raise. This is just
901 # checking that self is used correctly.
902 C(5)
903
904 # If there's not an __init__, then post-init won't get called.
905 @dataclass(init=False)
906 class C:
907 def __post_init__(self):
908 raise CustomError()
909 # Creating the class won't raise
910 C()
911
912 @dataclass
913 class C:
914 x: int = 0
915 def __post_init__(self):
916 self.x *= 2
917 self.assertEqual(C().x, 0)
918 self.assertEqual(C(2).x, 4)
919
Mike53f7a7c2017-12-14 14:04:53 +0300920 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500921 # attributes.
922 @dataclass(frozen=True)
923 class C:
924 x: int = 0
925 def __post_init__(self):
926 self.x *= 2
927 with self.assertRaises(FrozenInstanceError):
928 C()
929
930 def test_post_init_super(self):
931 # Make sure super() post-init isn't called by default.
932 class B:
933 def __post_init__(self):
934 raise CustomError()
935
936 @dataclass
937 class C(B):
938 def __post_init__(self):
939 self.x = 5
940
941 self.assertEqual(C().x, 5)
942
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400943 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500944 @dataclass
945 class C(B):
946 def __post_init__(self):
947 super().__post_init__()
948
949 with self.assertRaises(CustomError):
950 C()
951
952 # Make sure post-init is called, even if not defined in our
953 # class.
954 @dataclass
955 class C(B):
956 pass
957
958 with self.assertRaises(CustomError):
959 C()
960
961 def test_post_init_staticmethod(self):
962 flag = False
963 @dataclass
964 class C:
965 x: int
966 y: int
967 @staticmethod
968 def __post_init__():
969 nonlocal flag
970 flag = True
971
972 self.assertFalse(flag)
973 c = C(3, 4)
974 self.assertEqual((c.x, c.y), (3, 4))
975 self.assertTrue(flag)
976
977 def test_post_init_classmethod(self):
978 @dataclass
979 class C:
980 flag = False
981 x: int
982 y: int
983 @classmethod
984 def __post_init__(cls):
985 cls.flag = True
986
987 self.assertFalse(C.flag)
988 c = C(3, 4)
989 self.assertEqual((c.x, c.y), (3, 4))
990 self.assertTrue(C.flag)
991
992 def test_class_var(self):
993 # Make sure ClassVars are ignored in __init__, __repr__, etc.
994 @dataclass
995 class C:
996 x: int
997 y: int = 10
998 z: ClassVar[int] = 1000
999 w: ClassVar[int] = 2000
1000 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001001 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001002
1003 c = C(5)
1004 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001005 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001006 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001007 self.assertEqual(c.z, 1000)
1008 self.assertEqual(c.w, 2000)
1009 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001010 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001011 C.z += 1
1012 self.assertEqual(c.z, 1001)
1013 c = C(20)
1014 self.assertEqual((c.x, c.y), (20, 10))
1015 self.assertEqual(c.z, 1001)
1016 self.assertEqual(c.w, 2000)
1017 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001018 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001019
def test_class_var_no_default(self):
    # A ClassVar declared without a value must not be set on the class.
    @dataclass
    class C:
        x: ClassVar[int]

    self.assertNotIn('x', C.__dict__)

def test_class_var_default_factory(self):
    # It makes no sense for a ClassVar to have a default factory: when
    # would it be called?  Call it yourself, since it's class-wide.
    # The decorator rejects it.  Because the decorator raises, the name C
    # is never bound, so there is nothing left to inspect afterwards.
    with self.assertRaisesRegex(TypeError,
                                'cannot have a default factory'):
        @dataclass
        class C:
            x: ClassVar[int] = field(default_factory=int)

def test_class_var_with_default(self):
    # A ClassVar with a default value must be set on the class, whether
    # the value is written directly or via field(default=...).
    @dataclass
    class C:
        x: ClassVar[int] = 10
    self.assertEqual(C.x, 10)

    @dataclass
    class C:
        x: ClassVar[int] = field(default=10)
    self.assertEqual(C.x, 10)
1050
def test_class_var_frozen(self):
    # ClassVars keep working on a frozen dataclass: freezing applies to
    # instances only, the class itself stays writable.
    @dataclass(frozen=True)
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    inst = C(5)
    self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)           # x and y ...
    self.assertEqual(len(C.__annotations__), 5)   # ... plus three ClassVars.
    for name, value in [('z', 1000), ('w', 2000), ('t', 3000)]:
        self.assertEqual(getattr(inst, name), value)

    C.z += 1
    self.assertEqual(inst.z, 1001)

    inst = C(20)
    self.assertEqual((inst.x, inst.y), (20, 10))
    for name, value in [('z', 1001), ('w', 2000), ('t', 3000)]:
        self.assertEqual(getattr(inst, name), value)
1077
def test_init_var_no_default(self):
    # An InitVar declared without a value must not be set on the class.
    @dataclass
    class C:
        x: InitVar[int]

    self.assertNotIn('x', C.__dict__)

def test_init_var_default_factory(self):
    # Unlike a ClassVar, an InitVar is not class-wide: it is only passed
    # to __init__/__post_init__.  A default factory for it is still
    # rejected by the decorator.  Because the decorator raises, the name C
    # is never bound, so there is nothing left to inspect afterwards.
    with self.assertRaisesRegex(TypeError,
                                'cannot have a default factory'):
        @dataclass
        class C:
            x: InitVar[int] = field(default_factory=int)

def test_init_var_with_default(self):
    # An InitVar with a default value must be set on the class, whether
    # the value is written directly or via field(default=...).
    @dataclass
    class C:
        x: InitVar[int] = 10
    self.assertEqual(C.x, 10)

    @dataclass
    class C:
        x: InitVar[int] = field(default=10)
    self.assertEqual(C.x, 10)
1108
def test_init_var(self):
    # An InitVar reaches __post_init__ as an argument.
    @dataclass
    class C:
        x: int = None
        init_param: InitVar[int] = None

        def __post_init__(self, init_param):
            if self.x is None:
                self.x = init_param*2

    self.assertEqual(C(init_param=10).x, 20)

def test_init_var_preserve_type(self):
    # InitVar[T] keeps T available and has a readable repr.
    self.assertEqual(InitVar[int].type, int)
    for init_var, text in [
            (InitVar[int], 'dataclasses.InitVar[int]'),
            (InitVar[List[int]], 'dataclasses.InitVar[typing.List[int]]'),
    ]:
        self.assertEqual(repr(init_var), text)

def test_init_var_inheritance(self):
    # A dataclass with an InitVar does not need its own __post_init__:
    # the value may only be consumed by a derived class, as shown here.
    @dataclass
    class Base:
        x: int
        init_base: InitVar[int]

    # Base accepts the InitVar even though nothing uses it, and does not
    # keep it on the instance.
    self.assertEqual(vars(Base(0, 10)), {'x': 0})

    @dataclass
    class C(Base):
        y: int
        init_derived: InitVar[int]

        def __post_init__(self, init_base, init_derived):
            self.x += init_base
            self.y += init_derived

    # Positional order is x, init_base, y, init_derived.
    self.assertEqual(vars(C(10, 11, 50, 51)), {'x': 21, 'y': 101})
1155
def test_default_factory(self):
    # A factory that builds a fresh list for every instance.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIsNot(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # A factory that hands back the very same list each time.
    shared = []
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: shared)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIs(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # default_factory combined with each of the other field() flags.
    # repr=False
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash=False
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init=False (test_default_factory_with_no_init covers this further)
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare=False
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1212
def test_default_factory_with_no_init(self):
    # Even with init=False the factory runs once per new instance; a
    # Mock lets us count the calls.
    factory = Mock()

    @dataclass
    class C:
        x: list = field(default_factory=factory, init=False)

    for expected_calls in (1, 2):
        C().x
        self.assertEqual(factory.call_count, expected_calls)

def test_default_factory_not_called_if_value_given(self):
    # The factory runs only when the caller leaves the value out.
    factory = Mock()

    @dataclass
    class C:
        x: int = field(default_factory=factory)

    C().x
    self.assertEqual(factory.call_count, 1)
    self.assertEqual(C(10).x, 10)
    self.assertEqual(factory.call_count, 1)    # Value given: no call.
    C().x
    self.assertEqual(factory.call_count, 2)

def test_default_factory_derived(self):
    # bpo-32896: a default_factory field inherited by a derived
    # dataclass must keep working.
    @dataclass
    class Foo:
        x: dict = field(default_factory=dict)

    @dataclass
    class Bar(Foo):
        y: int = 1

    self.assertEqual(Foo().x, {})
    self.assertEqual(Bar().x, {})
    self.assertEqual(Bar().y, 1)

    @dataclass
    class Baz(Foo):
        pass
    self.assertEqual(Baz().x, {})
1262
def test_intermediate_non_dataclass(self):
    # Annotations on a plain (non-dataclass) class in the middle of the
    # hierarchy must not turn into fields.
    @dataclass
    class A:
        x: int

    class B(A):
        y: int

    @dataclass
    class C(B):
        z: int

    c = C(1, 3)
    self.assertEqual((c.x, c.z), (1, 3))

    # y is not a field, so __init__ never assigned it.
    with self.assertRaisesRegex(AttributeError,
                                'object has no attribute'):
        c.y

    # Deriving a plain class from a dataclass likewise adds no fields.
    class D(C):
        t: int
    d = D(4, 5)
    self.assertEqual((d.x, d.z), (4, 5))

def test_classvar_default_factory(self):
    # A default factory on a ClassVar is rejected.
    with self.assertRaisesRegex(TypeError,
                                'cannot have a default factory'):
        @dataclass
        class C:
            x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001299
def test_is_dataclass(self):
    class NotDataClass:
        pass

    for candidate in (0, int, NotDataClass, NotDataClass()):
        self.assertFalse(is_dataclass(candidate))

    @dataclass
    class C:
        x: int

    @dataclass
    class D:
        d: C
        e: int

    c = C(10)
    d = D(c, 4)

    # Dataclass types and their instances qualify; plain field values do not.
    self.assertTrue(is_dataclass(C))
    self.assertTrue(is_dataclass(c))
    self.assertFalse(is_dataclass(c.x))
    self.assertTrue(is_dataclass(d.d))
    self.assertFalse(is_dataclass(d.e))

def test_is_dataclass_when_getattr_always_returns(self):
    # bpo-37868: an object that answers every attribute lookup, or one
    # that has __dataclass_fields__ only on the instance, is not a
    # dataclass.
    class A:
        def __getattr__(self, key):
            return 0
    self.assertFalse(is_dataclass(A))

    class B:
        pass
    b = B()
    b.__dataclass_fields__ = []

    message = 'should be called on dataclass instances'
    for obj in (A(), b):
        with self.subTest(obj=obj):
            self.assertFalse(is_dataclass(obj))

            # The helpers reach _is_dataclass_instance() and reject obj.
            with self.assertRaisesRegex(TypeError, message):
                asdict(obj)
            with self.assertRaisesRegex(TypeError, message):
                astuple(obj)
            with self.assertRaisesRegex(TypeError, message):
                replace(obj, x=0)
1352
def test_helper_fields_with_class_instance(self):
    # fields() answers identically for a class and for one of its instances.
    @dataclass
    class C:
        x: int
        y: float

    self.assertEqual(fields(C), fields(C(0, 0.0)))

def test_helper_fields_exception(self):
    # fields() raises TypeError for anything that is not a dataclass
    # type or dataclass instance.
    class C: pass
    for bad in (0, C, C()):
        with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
            fields(bad)

def test_helper_asdict(self):
    # asdict() builds a brand-new plain dict on every call.
    @dataclass
    class C:
        x: int
        y: int
    obj = C(1, 2)

    self.assertEqual(asdict(obj), {'x': 1, 'y': 2})
    self.assertEqual(asdict(obj), asdict(obj))
    self.assertIsNot(asdict(obj), asdict(obj))
    obj.x = 42
    self.assertEqual(asdict(obj), {'x': 42, 'y': 2})
    self.assertIs(type(asdict(obj)), dict)

def test_helper_asdict_raises_on_classes(self):
    # asdict() accepts instances only, never class objects.
    @dataclass
    class C:
        x: int
        y: int
    for not_an_instance in (C, int):
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            asdict(not_an_instance)
1400
def test_helper_asdict_copy_values(self):
    # The dict produced by asdict() must not share mutable values with
    # the instance it came from.
    @dataclass
    class C:
        x: int
        y: List[int] = field(default_factory=list)

    initial = []
    as_dict = asdict(C(1, initial))
    self.assertEqual(as_dict['y'], initial)
    self.assertIsNot(as_dict['y'], initial)

    obj = C(1)
    asdict(obj)['y'].append(1)
    self.assertEqual(obj.y, [])

def test_helper_asdict_nested(self):
    # A field holding another dataclass is converted recursively.
    @dataclass
    class UserId:
        token: int
        group: int
    @dataclass
    class User:
        name: str
        id: UserId

    user = User('Joe', UserId(123, 1))
    self.assertEqual(asdict(user),
                     {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
    self.assertIsNot(asdict(user), asdict(user))
    user.id.group = 2
    self.assertEqual(asdict(user), {'name': 'Joe',
                                    'id': {'token': 123, 'group': 2}})

def test_helper_asdict_builtin_containers(self):
    # Dataclasses inside lists, tuples and dict values are converted,
    # while each container keeps its own type.
    @dataclass
    class User:
        name: str
        id: int
    @dataclass
    class GroupList:
        id: int
        users: List[User]
    @dataclass
    class GroupTuple:
        id: int
        users: Tuple[User, ...]
    @dataclass
    class GroupDict:
        id: int
        users: Dict[str, User]

    alice, bob = User('Alice', 1), User('Bob', 2)
    alice_dict = {'name': 'Alice', 'id': 1}
    bob_dict = {'name': 'Bob', 'id': 2}
    self.assertEqual(asdict(GroupList(0, [alice, bob])),
                     {'id': 0, 'users': [alice_dict, bob_dict]})
    self.assertEqual(asdict(GroupTuple(0, (alice, bob))),
                     {'id': 0, 'users': (alice_dict, bob_dict)})
    self.assertEqual(asdict(GroupDict(0, {'first': alice, 'second': bob})),
                     {'id': 0, 'users': {'first': alice_dict,
                                         'second': bob_dict}})

def test_helper_asdict_builtin_object_containers(self):
    # Containers stored in an `object`-typed field are handled too.
    @dataclass
    class Child:
        d: object

    @dataclass
    class Parent:
        child: Child

    for payload, expected in (([1], {'child': {'d': [1]}}),
                              ({1: 2}, {'child': {'d': {1: 2}}})):
        self.assertEqual(asdict(Parent(Child(payload))), expected)

def test_helper_asdict_factory(self):
    # dict_factory chooses the mapping type asdict() builds.
    @dataclass
    class C:
        x: int
        y: int
    obj = C(1, 2)

    result = asdict(obj, dict_factory=OrderedDict)
    self.assertEqual(result, OrderedDict([('x', 1), ('y', 2)]))
    self.assertIsNot(result, asdict(obj, dict_factory=OrderedDict))
    obj.x = 42
    result = asdict(obj, dict_factory=OrderedDict)
    self.assertEqual(result, OrderedDict([('x', 42), ('y', 2)]))
    self.assertIs(type(result), OrderedDict)
1487
def test_helper_asdict_namedtuple(self):
    # Namedtuples are rebuilt with their own type, and dataclasses inside
    # them are converted as well.
    T = namedtuple('T', 'a b c')
    @dataclass
    class C:
        x: str
        y: T
    obj = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
    expected = {'x': 'outer',
                'y': T(1,
                       {'x': 'inner',
                        'y': T(11, 12, 13)},
                       2),
                }

    self.assertEqual(asdict(obj), expected)

    # The same again with a dict_factory.  OrderedDict compares equal to
    # a plain dict, so the result types are asserted separately below.
    result = asdict(obj, dict_factory=OrderedDict)
    self.assertEqual(result, expected)

    # Both the outer mapping and the one nested in the tuple must be
    # OrderedDicts.
    self.assertIs(type(result), OrderedDict)
    self.assertIs(type(result['y'][1]), OrderedDict)

def test_helper_asdict_namedtuple_key(self):
    # A dict field keyed by namedtuples survives asdict().
    @dataclass
    class C:
        f: dict
    T = namedtuple('T', 'a')

    self.assertEqual(asdict(C({T('an a'): 0})), {'f': {T(a='an a'): 0}})

def test_helper_asdict_namedtuple_derived(self):
    # A namedtuple subclass keeps its type (and methods) and is copied
    # rather than reused.
    class T(namedtuple('Tbase', 'a')):
        def my_a(self):
            return self.a

    @dataclass
    class C:
        f: T

    original = T(6)
    result = asdict(C(original))
    self.assertEqual(result, {'f': T(a=6)})
    self.assertIsNot(result['f'], original)
    self.assertEqual(result['f'].my_a(), 6)
1551
def test_helper_astuple(self):
    # astuple() builds a brand-new plain tuple on every call.
    @dataclass
    class C:
        x: int
        y: int = 0
    obj = C(1)

    self.assertEqual(astuple(obj), (1, 0))
    self.assertEqual(astuple(obj), astuple(obj))
    self.assertIsNot(astuple(obj), astuple(obj))
    obj.y = 42
    self.assertEqual(astuple(obj), (1, 42))
    self.assertIs(type(astuple(obj)), tuple)

def test_helper_astuple_raises_on_classes(self):
    # astuple() accepts instances only, never class objects.
    @dataclass
    class C:
        x: int
        y: int
    for not_an_instance in (C, int):
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            astuple(not_an_instance)

def test_helper_astuple_copy_values(self):
    # The tuple produced by astuple() must not share mutable values with
    # the instance it came from.
    @dataclass
    class C:
        x: int
        y: List[int] = field(default_factory=list)

    initial = []
    as_tuple = astuple(C(1, initial))
    self.assertEqual(as_tuple[1], initial)
    self.assertIsNot(as_tuple[1], initial)

    obj = C(1)
    astuple(obj)[1].append(1)
    self.assertEqual(obj.y, [])

def test_helper_astuple_nested(self):
    # A field holding another dataclass is converted recursively.
    @dataclass
    class UserId:
        token: int
        group: int
    @dataclass
    class User:
        name: str
        id: UserId

    user = User('Joe', UserId(123, 1))
    self.assertEqual(astuple(user), ('Joe', (123, 1)))
    self.assertIsNot(astuple(user), astuple(user))
    user.id.group = 2
    self.assertEqual(astuple(user), ('Joe', (123, 2)))

def test_helper_astuple_builtin_containers(self):
    # Dataclasses inside lists, tuples and dict values are converted,
    # while each container keeps its own type.
    @dataclass
    class User:
        name: str
        id: int
    @dataclass
    class GroupList:
        id: int
        users: List[User]
    @dataclass
    class GroupTuple:
        id: int
        users: Tuple[User, ...]
    @dataclass
    class GroupDict:
        id: int
        users: Dict[str, User]

    alice, bob = User('Alice', 1), User('Bob', 2)
    self.assertEqual(astuple(GroupList(0, [alice, bob])),
                     (0, [('Alice', 1), ('Bob', 2)]))
    self.assertEqual(astuple(GroupTuple(0, (alice, bob))),
                     (0, (('Alice', 1), ('Bob', 2))))
    self.assertEqual(astuple(GroupDict(0, {'first': alice, 'second': bob})),
                     (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))

def test_helper_astuple_builtin_object_containers(self):
    # Containers stored in an `object`-typed field are handled too.
    @dataclass
    class Child:
        d: object

    @dataclass
    class Parent:
        child: Child

    for payload, expected in (([1], (([1],),)),
                              ({1: 2}, (({1: 2},),))):
        self.assertEqual(astuple(Parent(Child(payload))), expected)

def test_helper_astuple_factory(self):
    # tuple_factory is called with the field values and decides the
    # resulting sequence type.
    @dataclass
    class C:
        x: int
        y: int
    NT = namedtuple('NT', 'x y')
    def make_nt(values):
        return NT(*values)

    obj = C(1, 2)
    result = astuple(obj, tuple_factory=make_nt)
    self.assertEqual(result, NT(1, 2))
    self.assertIsNot(result, astuple(obj, tuple_factory=make_nt))
    obj.x = 42
    result = astuple(obj, tuple_factory=make_nt)
    self.assertEqual(result, NT(42, 2))
    self.assertIs(type(result), NT)

def test_helper_astuple_namedtuple(self):
    # Dataclasses nested inside namedtuples are converted too.
    T = namedtuple('T', 'a b c')
    @dataclass
    class C:
        x: str
        y: T
    obj = C('outer', T(1, C('inner', T(11, 12, 13)), 2))

    self.assertEqual(astuple(obj),
                     ('outer', T(1, ('inner', (11, 12, 13)), 2)))

    # With tuple_factory=list, converted dataclasses become lists while
    # the namedtuples keep their own type.
    self.assertEqual(astuple(obj, tuple_factory=list),
                     ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1678
def test_dynamic_class_creation(self):
    # dataclass() also works on a class created with type().
    cls_dict = {'__annotations__': {'x': int, 'y': int},
                }

    # Create the class.
    cls = type('C', (), cls_dict)

    # Make it a dataclass; the decorator returns the same class object.
    cls1 = dataclass(cls)

    self.assertEqual(cls1, cls)
    self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})

def test_dynamic_class_creation_using_field(self):
    # As above, with a field() default supplied through the class dict.
    cls_dict = {'__annotations__': {'x': int, 'y': int},
                'y': field(default=5),
                }

    # Create the class.
    cls = type('C', (), cls_dict)

    # Make it a dataclass.
    cls1 = dataclass(cls)

    self.assertEqual(cls1, cls)
    self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})

def test_init_in_order(self):
    # The generated __init__ assigns fields in declaration order and
    # never assigns the init=False field that has a plain default (e).
    @dataclass
    class C:
        a: int
        b: int = field()
        c: list = field(default_factory=list, init=False)
        d: list = field(default_factory=list)
        e: int = field(default=4, init=False)
        f: int = 4

    calls = []
    # Records every attribute assignment __init__ performs.  Named so it
    # does not shadow the builtin setattr().
    def recording_setattr(self, name, value):
        calls.append((name, value))

    C.__setattr__ = recording_setattr
    c = C(0, 1)
    self.assertEqual(('a', 0), calls[0])
    self.assertEqual(('b', 1), calls[1])
    self.assertEqual(('c', []), calls[2])
    self.assertEqual(('d', []), calls[3])
    self.assertNotIn(('e', 4), calls)
    self.assertEqual(('f', 4), calls[4])

def test_items_in_dicts(self):
    # Plain defaults stay on the class; values set by __init__ (including
    # factory-built ones) land in the instance __dict__.
    @dataclass
    class C:
        a: int
        b: list = field(default_factory=list, init=False)
        c: list = field(default_factory=list)
        d: int = field(default=4, init=False)
        e: int = 0

    c = C(0)
    # Class dict
    self.assertNotIn('a', C.__dict__)
    self.assertNotIn('b', C.__dict__)
    self.assertNotIn('c', C.__dict__)
    self.assertIn('d', C.__dict__)
    self.assertEqual(C.d, 4)
    self.assertIn('e', C.__dict__)
    self.assertEqual(C.e, 0)
    # Instance dict
    self.assertIn('a', c.__dict__)
    self.assertEqual(c.a, 0)
    self.assertIn('b', c.__dict__)
    self.assertEqual(c.b, [])
    self.assertIn('c', c.__dict__)
    self.assertEqual(c.c, [])
    self.assertNotIn('d', c.__dict__)
    self.assertIn('e', c.__dict__)
    self.assertEqual(c.e, 0)

def test_alternate_classmethod_constructor(self):
    # __post_init__ only receives InitVar values.  Input that needs
    # custom processing first (here, a filename whose contents supply x)
    # fits naturally in a classmethod alternate constructor.  This is
    # mostly an example to show how to use this technique.
    @dataclass
    class C:
        x: int
        @classmethod
        def from_file(cls, filename):
            # In a real example, create a new instance
            # and populate 'x' from contents of a file.
            value_in_file = 20
            return cls(value_in_file)

    self.assertEqual(C.from_file('filename').x, 20)
1773
def test_field_metadata_default(self):
    # Without metadata=, a field's metadata is an empty, read-only mapping.
    @dataclass
    class C:
        i: int

    self.assertFalse(fields(C)[0].metadata)
    self.assertEqual(len(fields(C)[0].metadata), 0)
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3

def test_field_metadata_mapping(self):
    # Only a mapping can be passed as metadata.
    with self.assertRaises(TypeError):
        @dataclass
        class C:
            i: int = field(metadata=0)

    # Make sure an empty dict works.
    d = {}
    @dataclass
    class C:
        i: int = field(metadata=d)
    self.assertFalse(fields(C)[0].metadata)
    self.assertEqual(len(fields(C)[0].metadata), 0)
    # Updating the caller's dict shows through (see bpo-35960): the field
    # exposes a read-only view of it, not a copy.
    d['foo'] = 1
    self.assertEqual(len(fields(C)[0].metadata), 1)
    self.assertEqual(fields(C)[0].metadata['foo'], 1)
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3

    # Make sure a non-empty dict works.
    d = {'test': 10, 'bar': '42', 3: 'three'}
    @dataclass
    class C:
        i: int = field(metadata=d)
    self.assertEqual(len(fields(C)[0].metadata), 3)
    self.assertEqual(fields(C)[0].metadata['test'], 10)
    self.assertEqual(fields(C)[0].metadata['bar'], '42')
    self.assertEqual(fields(C)[0].metadata[3], 'three')
    # Update should work.
    d['foo'] = 1
    self.assertEqual(len(fields(C)[0].metadata), 4)
    self.assertEqual(fields(C)[0].metadata['foo'], 1)
    with self.assertRaises(KeyError):
        # Non-existent key.
        fields(C)[0].metadata['baz']
    with self.assertRaisesRegex(TypeError,
                                'does not support item assignment'):
        fields(C)[0].metadata['test'] = 3

def test_field_metadata_custom_mapping(self):
    # Any object with __getitem__/__len__ can serve as metadata, and
    # lookups go through it.
    class SimpleNameSpace:
        def __init__(self, **kw):
            self.__dict__.update(kw)

        def __getitem__(self, item):
            if item == 'xyzzy':
                return 'plugh'
            return getattr(self, item)

        def __len__(self):
            return self.__dict__.__len__()

    @dataclass
    class C:
        i: int = field(metadata=SimpleNameSpace(a=10))

    self.assertEqual(len(fields(C)[0].metadata), 1)
    self.assertEqual(fields(C)[0].metadata['a'], 10)
    with self.assertRaises(AttributeError):
        fields(C)[0].metadata['b']
    # Make sure we're still talking to our custom mapping.
    self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1854
def test_generic_dataclasses(self):
    # A dataclass may also be a Generic.
    T = TypeVar('T')

    @dataclass
    class LabeledBox(Generic[T]):
        content: T
        label: str = '<unknown>'

    box = LabeledBox(42)
    self.assertEqual((box.content, box.label), (42, '<unknown>'))

    # The resulting class can still be subscripted and used in aliases.
    Alias = List[LabeledBox[int]]

def test_generic_extending(self):
    # Both dataclass and plain subclasses of a parametrized generic
    # dataclass can be specialized further.
    S = TypeVar('S')
    T = TypeVar('T')

    @dataclass
    class Base(Generic[T, S]):
        x: T
        y: S

    @dataclass
    class DataDerived(Base[int, T]):
        new_field: str
    instance = DataDerived[str](0, 'test1', 'test2')
    self.assertEqual(astuple(instance), (0, 'test1', 'test2'))

    class NonDataDerived(Base[int, T]):
        def new_method(self):
            return self.y
    instance = NonDataDerived[float](10, 1.0)
    self.assertEqual(instance.new_method(), 1.0)

def test_generic_dynamic(self):
    # make_dataclass() can build a generic dataclass from generic bases.
    T = TypeVar('T')

    @dataclass
    class Parent(Generic[T]):
        x: T
    Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
                           bases=(Parent[int], Generic[T]), namespace={'other': 42})
    IntChild = Child[int]
    self.assertIs(IntChild(1, 2).z, None)
    self.assertEqual(IntChild(1, 2, 3).z, 3)
    self.assertEqual(IntChild(1, 2, 3).other, 42)
    # Type aliases built from the class must work as well.
    Alias = Child[T]
    self.assertEqual(Alias[int](1, 2).x, 1)
    # And the MRO must resolve cleanly.
    self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1909
def test_dataclassses_pickleable(self):
    # pickle looks classes up by qualified name, so these must be module
    # globals rather than locals of this method.
    global P, Q, R
    @dataclass
    class P:
        x: int
        y: int = 0
    @dataclass
    class Q:
        x: int
        y: int = field(default=0, init=False)
    @dataclass
    class R:
        x: int
        y: List[int] = field(default_factory=list)

    q = Q(1)
    q.y = 2
    samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
    protocols = range(pickle.HIGHEST_PROTOCOL + 1)
    for sample in samples:
        for proto in protocols:
            with self.subTest(sample=sample, proto=proto):
                restored = pickle.loads(pickle.dumps(sample, proto))
                self.assertEqual(sample.x, restored.x)
                self.assertEqual(sample.y, restored.y)
                self.assertIsNot(sample, restored)

                # A modified copy must round-trip as well.
                restored.x = 42
                round_tripped = pickle.loads(pickle.dumps(restored, proto))
                self.assertEqual(restored.x, round_tripped.x)
                self.assertEqual(sample.y, round_tripped.y)
1938
Eric V. Smithea8fc522018-01-27 19:07:40 -05001939
class TestFieldNoAnnotation(unittest.TestCase):
    # A class attribute bound to field() but lacking its own annotation is
    # an error, even if a base class annotates the same name.

    def _expect_missing_annotation(self):
        # Context manager asserting the decorator's "no annotation" error.
        return self.assertRaisesRegex(TypeError,
                                      "'f' is a field but has no type annotation")

    def test_field_without_annotation(self):
        with self._expect_missing_annotation():
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # The annotation inherited from the base must not be picked up.
        with self._expect_missing_annotation():
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same situation, but the base class is not a dataclass.
        class B:
            f: int

        # The annotation inherited from the base must not be picked up.
        with self._expect_missing_annotation():
            @dataclass
            class C(B):
                f = field()
1973
1974
class TestDocString(unittest.TestCase):
    # A dataclass without its own docstring gets one built from its
    # __init__ signature.

    def assertDocStrEqual(self, a, b):
        # inspect.signature formatting differs between 3.6 and 3.7 (see
        # bpo #32108), so compare with every space removed.
        squeezed = [text.replace(' ', '') for text in (a, b)]
        self.assertEqual(*squeezed)

    def test_existing_docstring_not_overridden(self):
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        expected = "Lorem ipsum"
        self.assertEqual(C.__doc__, expected)

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        expected = "C()"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        expected = "C(x:int)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        expected = "C(x:int, y:int)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        expected = "C(x:int, y:int, z:str)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        expected = "C(x:int=3)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        expected = "C(x:Optional[int]=None)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        expected = "C(x:List[int])"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_list_field_with_default_factory(self):
        # A default factory is shown as the <factory> placeholder.
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        expected = "C(x:List[int]=<factory>)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        expected = "C(x:collections.deque)"
        self.assertDocStrEqual(C.__doc__, expected)

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        expected = "C(x:collections.deque=<factory>)"
        self.assertDocStrEqual(C.__doc__, expected)
2062
2063
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring a dataclass whose base defines __init__ must not raise:
        # the generated __init__ is added to the derived class, and it does
        # not call the base class's __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so the base
        # __init__ gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # With init=False no __init__ is generated; the field default is
        # still reachable as a class attribute.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # With init=False, a user-defined __init__ is used unchanged.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2128
2129
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The generated __repr__ shows the class __qualname__ followed by
        # every field, inherited ones first, in definition order.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        prefix = 'TestRepr.test_repr.<locals>.'
        self.assertEqual(repr(C(4)), prefix + 'C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), prefix + 'D(x=20, y=10)')

        # Nested dataclasses include the enclosing class in the name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), prefix + 'C.D(i=0)')
        self.assertEqual(repr(C.E()), prefix + 'C.E()')

    def test_no_repr(self):
        # repr=False and no user __repr__: object's default repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_repr)

        # repr=False with a user __repr__: the user's method is kept.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ written in the class body always wins, whatever
        # repr= is set to (or when it is left unspecified).
        for decorator in (dataclass, dataclass(repr=True), dataclass(repr=False)):
            with self.subTest(decorator=decorator):
                @decorator
                class C:
                    x: int
                    def __repr__(self):
                        return 'x'
                self.assertEqual(repr(C(0)), 'x')
2199
2200
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: equality falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ written in the class body always wins, whatever eq= is
        # set to (or when it is left unspecified).
        for decorator, target in ((dataclass, 3),
                                  (dataclass(eq=True), 4),
                                  (dataclass(eq=False), 5)):
            with self.subTest(target=target):
                @decorator
                class C:
                    x: int
                    def __eq__(self, other):
                        return other == target
                self.assertEqual(C(1), target)
                self.assertNotEqual(C(1), 1)
2246
2247
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # With order=False, no ordering methods are added.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ is kept and is what "<" calls, and none of
        # the other ordering methods are added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIn('__lt__', C.__dict__)
        # A field-based comparison would make 0 < 1 true; the user's
        # __lt__ always answers False.
        self.assertFalse(C(0) < C(1))
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2323
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.  C has
                    # no fields, so it hashes like an empty tuple.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                    self.assertEqual(hash(C()), hash(()))

                elif result == '':
                    # dataclass neither added nor replaced __hash__.
                    if with_hash:
                        # The user-defined __hash__ must be left in place.
                        self.assertEqual(hash(C()), 0)
                    else:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # I'm not sure what test to use here.  object's
                    # hash isn't based on id(), so calling hash()
                    # won't tell us much.  So, just check the
                    # function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    self.fail(f'unknown value for expected={expected!r}')
2542
Eric V. Smithea8fc522018-01-27 19:07:40 -05002543
class TestFrozen(unittest.TestCase):
    @staticmethod
    def _maybe_subclass(base, intermediate_class):
        # Return base itself, or a plain (non-dataclass) subclass of it when
        # intermediate_class is true, so inheritance checks can be run both
        # directly and through an ordinary class in between.
        if not intermediate_class:
            return base
        class I(base):
            pass
        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        # The rejected assignment must not have changed the value.
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        # Inherited and newly declared fields are both read-only.
        for attr, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, attr, new_value)
        self.assertEqual((d.i, d.j), (0, 10))

    def test_inherit_nonfrozen_from_frozen(self):
        # Checked with and without a plain class in between.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._maybe_subclass(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        # Checked with and without a plain class in between.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = self._maybe_subclass(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from ordinary classes.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                @dataclass(frozen=True)
                class D(self._maybe_subclass(C, intermediate_class)):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.  A plain subclass of a frozen dataclass can set new
        # attributes, but the inherited fields stay frozen.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        s.cached = True

        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, attr, 5)
        self.assertEqual((s.x, s.y), (3, 10))
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True generates __setattr__ and __delattr__, so a class
        # that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen, a user __setattr__ is kept and used by __init__.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # An immutable field value can be hashed without error.
        hash(C(3))

        # A mutable field value makes hashing fail.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2692
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002693
class TestSlots(unittest.TestCase):
    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: a name listed in __slots__ appears on the class
        # as a member descriptor (types.MemberDescriptorType), which was once
        # mistaken for a default value.  x must still be required.
        missing_x = r"__init__\(\) missing 1 required positional argument: 'x'"
        with self.assertRaisesRegex(TypeError, missing_x):
            C()

        # Instances can be created, and the slot can be reassigned.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # Names outside __slots__ cannot be assigned.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # Derived declares no __slots__, so new attributes can be added.
        d.z = 10
2734 d.z = 10
2735
class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # See bpo-33141: __set_name__ must be called on a descriptor used
        # as a field default.

        class D:
            # Descriptor that records the name it was bound to, plus 'x'.
            def __set_name__(self, owner, name):
                self.name = f'{name}x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Ordinary descriptor use: the class statement itself calls
        # __set_name__; no dataclass code initializes the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Wrapped in field() with init=False.  This is the only case where
        # it really matters: with a generated __init__ the descriptor would
        # be overwritten anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors, so this
        # D defines __set_name__ but no __get__.
        class D:
            def __set_name__(self, owner, name):
                self.name = f'{name}x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.  A __set_name__ attached to the default instance,
        # rather than to its type, must not be called.
        class D:
            pass

        d = D()
        d.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=d, init=False)

        self.assertFalse(d.__set_name__.called)

    def test_lookup_on_class(self):
        # See bpo-33175.  A __set_name__ defined on the default's type is
        # called exactly once.
        class D:
            pass
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        D.__set_name__.assert_called_once()
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002806
Eric V. Smith7389fd92018-03-19 21:07:51 -04002807
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002808class TestStringAnnotations(unittest.TestCase):
2809 def test_classvar(self):
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002810 # These tests assume that both "import typing" and "from
2811 # typing import *" have been run in this file.
2812 for typestr in ('ClassVar[int]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002813 'ClassVar [int]',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002814 ' ClassVar [int]',
2815 'ClassVar',
2816 ' ClassVar ',
2817 'typing.ClassVar[int]',
2818 'typing.ClassVar[str]',
2819 ' typing.ClassVar[str]',
2820 'typing .ClassVar[str]',
2821 'typing. ClassVar[str]',
2822 'typing.ClassVar [str]',
2823 'typing.ClassVar [ str]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002824 # Double stringified
2825 '"typing.ClassVar[int]"',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002826 # Not syntactically valid, but these will
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002827 # be treated as ClassVars.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002828 'typing.ClassVar.[int]',
2829 'typing.ClassVar+',
2830 ):
2831 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002832 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002833 # x is a ClassVar, so C() takes no args.
2834 C()
2835
2836 # And it won't appear in the class's dict because it doesn't
2837 # have a default.
2838 self.assertNotIn('x', C.__dict__)
2839
2840 def test_isnt_classvar(self):
2841 for typestr in ('CV',
2842 't.ClassVar',
2843 't.ClassVar[int]',
2844 'typing..ClassVar[int]',
2845 'Classvar',
2846 'Classvar[int]',
2847 'typing.ClassVarx[int]',
2848 'typong.ClassVar[int]',
2849 'dataclasses.ClassVar[int]',
2850 'typingxClassVar[str]',
2851 ):
2852 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002853 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002854
2855 # x is not a ClassVar, so C() takes one arg.
2856 self.assertEqual(C(10).x, 10)
2857
2858 def test_initvar(self):
2859 # These tests assume that both "import dataclasses" and "from
2860 # dataclasses import *" have been run in this file.
2861 for typestr in ('InitVar[int]',
2862 'InitVar [int]'
2863 ' InitVar [int]',
2864 'InitVar',
2865 ' InitVar ',
2866 'dataclasses.InitVar[int]',
2867 'dataclasses.InitVar[str]',
2868 ' dataclasses.InitVar[str]',
2869 'dataclasses .InitVar[str]',
2870 'dataclasses. InitVar[str]',
2871 'dataclasses.InitVar [str]',
2872 'dataclasses.InitVar [ str]',
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002873 # Double stringified
2874 '"dataclasses.InitVar[int]"',
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002875 # Not syntactically valid, but these will
2876 # be treated as InitVars.
2877 'dataclasses.InitVar.[int]',
2878 'dataclasses.InitVar+',
2879 ):
2880 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002881 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
2882
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002883
2884 # x is an InitVar, so doesn't create a member.
2885 with self.assertRaisesRegex(AttributeError,
2886 "object has no attribute 'x'"):
2887 C(1).x
2888
2889 def test_isnt_initvar(self):
2890 for typestr in ('IV',
2891 'dc.InitVar',
2892 'xdataclasses.xInitVar',
2893 'typing.xInitVar[int]',
2894 ):
2895 with self.subTest(typestr=typestr):
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002896 C = dataclass(type("C", (), {"__annotations__": {"x": typestr}}))
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002897
2898 # x is not an InitVar, so there will be a member x.
2899 self.assertEqual(C(10).x, 10)
2900
2901 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002902 from test import dataclass_module_1
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002903 from test import dataclass_module_2
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002904
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002905 for m in (dataclass_module_1,
2906 dataclass_module_2):
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002907 with self.subTest(m=m):
2908 # There's a difference in how the ClassVars are
2909 # interpreted when using string annotations or
2910 # not. See the imported modules for details.
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002911 c = m.CV(10)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002912 self.assertEqual(c.cv0, 20)
2913
2914
2915 # There's a difference in how the InitVars are
2916 # interpreted when using string annotations or
2917 # not. See the imported modules for details.
2918 c = m.IV(0, 1, 2, 3, 4)
2919
2920 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2921 with self.subTest(field_name=field_name):
2922 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2923 # Since field_name is an InitVar, it's
2924 # not an instance field.
2925 getattr(c, field_name)
2926
Batuhan Taskaya044a1042020-10-06 23:03:02 +03002927 # iv4 is interpreted as a normal field.
2928 self.assertIn('not_iv4', c.__dict__)
2929 self.assertEqual(c.not_iv4, 4)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002930
Yury Selivanovd219cc42019-12-09 09:54:20 -05002931 def test_text_annotations(self):
2932 from test import dataclass_textanno
2933
2934 self.assertEqual(
2935 get_type_hints(dataclass_textanno.Bar),
2936 {'foo': dataclass_textanno.Foo})
2937 self.assertEqual(
2938 get_type_hints(dataclass_textanno.Bar.__init__),
2939 {'foo': dataclass_textanno.Foo,
2940 'return': type(None)})
2941
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002942
Eric V. Smith4e812962018-05-16 11:31:29 -04002943class TestMakeDataclass(unittest.TestCase):
2944 def test_simple(self):
2945 C = make_dataclass('C',
2946 [('x', int),
2947 ('y', int, field(default=5))],
2948 namespace={'add_one': lambda self: self.x + 1})
2949 c = C(10)
2950 self.assertEqual((c.x, c.y), (10, 5))
2951 self.assertEqual(c.add_one(), 11)
2952
2953
2954 def test_no_mutate_namespace(self):
2955 # Make sure a provided namespace isn't mutated.
2956 ns = {}
2957 C = make_dataclass('C',
2958 [('x', int),
2959 ('y', int, field(default=5))],
2960 namespace=ns)
2961 self.assertEqual(ns, {})
2962
2963 def test_base(self):
2964 class Base1:
2965 pass
2966 class Base2:
2967 pass
2968 C = make_dataclass('C',
2969 [('x', int)],
2970 bases=(Base1, Base2))
2971 c = C(2)
2972 self.assertIsInstance(c, C)
2973 self.assertIsInstance(c, Base1)
2974 self.assertIsInstance(c, Base2)
2975
2976 def test_base_dataclass(self):
2977 @dataclass
2978 class Base1:
2979 x: int
2980 class Base2:
2981 pass
2982 C = make_dataclass('C',
2983 [('y', int)],
2984 bases=(Base1, Base2))
2985 with self.assertRaisesRegex(TypeError, 'required positional'):
2986 c = C(2)
2987 c = C(1, 2)
2988 self.assertIsInstance(c, C)
2989 self.assertIsInstance(c, Base1)
2990 self.assertIsInstance(c, Base2)
2991
2992 self.assertEqual((c.x, c.y), (1, 2))
2993
2994 def test_init_var(self):
2995 def post_init(self, y):
2996 self.x *= y
2997
2998 C = make_dataclass('C',
2999 [('x', int),
3000 ('y', InitVar[int]),
3001 ],
3002 namespace={'__post_init__': post_init},
3003 )
3004 c = C(2, 3)
3005 self.assertEqual(vars(c), {'x': 6})
3006 self.assertEqual(len(fields(c)), 1)
3007
3008 def test_class_var(self):
3009 C = make_dataclass('C',
3010 [('x', int),
3011 ('y', ClassVar[int], 10),
3012 ('z', ClassVar[int], field(default=20)),
3013 ])
3014 c = C(1)
3015 self.assertEqual(vars(c), {'x': 1})
3016 self.assertEqual(len(fields(c)), 1)
3017 self.assertEqual(C.y, 10)
3018 self.assertEqual(C.z, 20)
3019
3020 def test_other_params(self):
3021 C = make_dataclass('C',
3022 [('x', int),
3023 ('y', ClassVar[int], 10),
3024 ('z', ClassVar[int], field(default=20)),
3025 ],
3026 init=False)
3027 # Make sure we have a repr, but no init.
3028 self.assertNotIn('__init__', vars(C))
3029 self.assertIn('__repr__', vars(C))
3030
3031 # Make sure random other params don't work.
3032 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
3033 C = make_dataclass('C',
3034 [],
3035 xxinit=False)
3036
3037 def test_no_types(self):
3038 C = make_dataclass('Point', ['x', 'y', 'z'])
3039 c = C(1, 2, 3)
3040 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
3041 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
3042 'y': 'typing.Any',
3043 'z': 'typing.Any'})
3044
3045 C = make_dataclass('Point', ['x', ('y', int), 'z'])
3046 c = C(1, 2, 3)
3047 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
3048 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
3049 'y': int,
3050 'z': 'typing.Any'})
3051
3052 def test_invalid_type_specification(self):
3053 for bad_field in [(),
3054 (1, 2, 3, 4),
3055 ]:
3056 with self.subTest(bad_field=bad_field):
3057 with self.assertRaisesRegex(TypeError, r'Invalid field: '):
3058 make_dataclass('C', ['a', bad_field])
3059
3060 # And test for things with no len().
3061 for bad_field in [float,
3062 lambda x:x,
3063 ]:
3064 with self.subTest(bad_field=bad_field):
3065 with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
3066 make_dataclass('C', ['a', bad_field])
3067
3068 def test_duplicate_field_names(self):
3069 for field in ['a', 'ab']:
3070 with self.subTest(field=field):
3071 with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
3072 make_dataclass('C', [field, 'a', field])
3073
3074 def test_keyword_field_names(self):
3075 for field in ['for', 'async', 'await', 'as']:
3076 with self.subTest(field=field):
3077 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
3078 make_dataclass('C', ['a', field])
3079 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
3080 make_dataclass('C', [field])
3081 with self.assertRaisesRegex(TypeError, 'must not be keywords'):
3082 make_dataclass('C', [field, 'a'])
3083
3084 def test_non_identifier_field_names(self):
3085 for field in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
3086 with self.subTest(field=field):
Min ho Kim96e12d52019-07-22 06:12:33 +10003087 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
Eric V. Smith4e812962018-05-16 11:31:29 -04003088 make_dataclass('C', ['a', field])
Min ho Kim96e12d52019-07-22 06:12:33 +10003089 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
Eric V. Smith4e812962018-05-16 11:31:29 -04003090 make_dataclass('C', [field])
Min ho Kim96e12d52019-07-22 06:12:33 +10003091 with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
Eric V. Smith4e812962018-05-16 11:31:29 -04003092 make_dataclass('C', [field, 'a'])
3093
3094 def test_underscore_field_names(self):
3095 # Unlike namedtuple, it's okay if dataclass field names have
3096 # an underscore.
3097 make_dataclass('C', ['_', '_a', 'a_a', 'a_'])
3098
3099 def test_funny_class_names_names(self):
3100 # No reason to prevent weird class names, since
3101 # types.new_class allows them.
3102 for classname in ['()', 'x,y', '*', '2@3', '']:
3103 with self.subTest(classname=classname):
3104 C = make_dataclass(classname, ['a', 'b'])
3105 self.assertEqual(C.__name__, classname)
3106
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The recursive-repr tests at the end of this class also live here; their
    expected strings embed this class's qualified name ("TestReplace."), so
    moving them elsewhere would require updating those strings.
    """

    def test(self):
        """replace() returns a copy with only the named field changed."""
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        """replace() works on frozen instances and rejects init=False fields."""
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be passed to replace(), whether or not an
        # init field is being replaced in the same call.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist. Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        """An unknown field name surfaces as a TypeError from __init__."""
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        """replace() accepts dataclass instances only, not classes or other objects."""
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        """init=False fields are reset to their default, not copied, by replace()."""
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        """A ClassVar is not an __init__ parameter, so replace() can't set it."""
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        """InitVars aren't stored, so replace() requires them to be passed."""
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # Parentheses are escaped so the pattern really checks for "replace()".
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        """A self-referencing field is shown as '...' in the repr."""
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        """Every self-referencing field is shown as '...'."""
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        """A two-object cycle is cut when it returns to the first object."""
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        """A three-object cycle is cut when it returns to the first object."""
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        """Non-recursive fields still repr normally next to a recursive one."""
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3316
Ben Avrahamibef7d292020-10-06 20:40:50 +03003317class TestAbstract(unittest.TestCase):
3318 def test_abc_implementation(self):
3319 class Ordered(abc.ABC):
3320 @abc.abstractmethod
3321 def __lt__(self, other):
3322 pass
3323
3324 @abc.abstractmethod
3325 def __le__(self, other):
3326 pass
3327
3328 @dataclass(order=True)
3329 class Date(Ordered):
3330 year: int
3331 month: 'Month'
3332 day: 'int'
3333
3334 self.assertFalse(inspect.isabstract(Date))
3335 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3336
3337 def test_maintain_abc(self):
3338 class A(abc.ABC):
3339 @abc.abstractmethod
3340 def foo(self):
3341 pass
3342
3343 @dataclass
3344 class Date(A):
3345 year: int
3346 month: 'Month'
3347 day: 'int'
3348
3349 self.assertTrue(inspect.isabstract(Date))
3350 msg = 'class Date with abstract method foo'
3351 self.assertRaisesRegex(TypeError, msg, Date)
3352
Eric V. Smith4e812962018-05-16 11:31:29 -04003353
if __name__ == "__main__":
    # Only when executed as a script: hand control to the unittest runner.
    unittest.main()