blob: b31a469ec79227b4f2f2595f8c0afe81fee03174 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
Ben Avrahamibef7d292020-10-06 20:40:50 +03007import abc
Eric V. Smithf0db54a2017-12-04 16:58:55 -05008import pickle
9import inspect
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +030010import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050011import unittest
12from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010013from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Yury Selivanovd219cc42019-12-09 09:54:20 -050014from typing import get_type_hints
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050016from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050017
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040018import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
19import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
20
class CustomError(Exception):
    """An arbitrary exception type that tests can raise and then catch."""
23
24class TestCase(unittest.TestCase):
25 def test_no_fields(self):
26 @dataclass
27 class C:
28 pass
29
30 o = C()
31 self.assertEqual(len(fields(C)), 0)
32
Eric V. Smith56970b82018-03-22 16:28:48 -040033 def test_no_fields_but_member_variable(self):
34 @dataclass
35 class C:
36 i = 0
37
38 o = C()
39 self.assertEqual(len(fields(C)), 0)
40
Eric V. Smithf0db54a2017-12-04 16:58:55 -050041 def test_one_field_no_default(self):
42 @dataclass
43 class C:
44 x: int
45
46 o = C(42)
47 self.assertEqual(o.x, 42)
48
Karthikeyan Singaravelaneef1b022020-01-09 19:11:46 +053049 def test_field_default_default_factory_error(self):
50 msg = "cannot specify both default and default_factory"
51 with self.assertRaisesRegex(ValueError, msg):
52 @dataclass
53 class C:
54 x: int = field(default=1, default_factory=int)
55
56 def test_field_repr(self):
57 int_field = field(default=1, init=True, repr=False)
58 int_field.name = "id"
59 repr_output = repr(int_field)
60 expected_output = "Field(name='id',type=None," \
61 f"default=1,default_factory={MISSING!r}," \
62 "init=True,repr=False,hash=None," \
63 "compare=True,metadata=mappingproxy({})," \
64 "_field_type=None)"
65
66 self.assertEqual(repr_output, expected_output)
67
Eric V. Smithf0db54a2017-12-04 16:58:55 -050068 def test_named_init_params(self):
69 @dataclass
70 class C:
71 x: int
72
73 o = C(x=32)
74 self.assertEqual(o.x, 32)
75
76 def test_two_fields_one_default(self):
77 @dataclass
78 class C:
79 x: int
80 y: int = 0
81
82 o = C(3)
83 self.assertEqual((o.x, o.y), (3, 0))
84
85 # Non-defaults following defaults.
86 with self.assertRaisesRegex(TypeError,
87 "non-default argument 'y' follows "
88 "default argument"):
89 @dataclass
90 class C:
91 x: int = 0
92 y: int
93
94 # A derived class adds a non-default field after a default one.
95 with self.assertRaisesRegex(TypeError,
96 "non-default argument 'y' follows "
97 "default argument"):
98 @dataclass
99 class B:
100 x: int = 0
101
102 @dataclass
103 class C(B):
104 y: int
105
106 # Override a base class field and add a default to
107 # a field which didn't use to have a default.
108 with self.assertRaisesRegex(TypeError,
109 "non-default argument 'y' follows "
110 "default argument"):
111 @dataclass
112 class B:
113 x: int
114 y: int
115
116 @dataclass
117 class C(B):
118 x: int = 0
119
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500120 def test_overwrite_hash(self):
121 # Test that declaring this class isn't an error. It should
122 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500123 @dataclass(frozen=True)
124 class C:
125 x: int
126 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500127 return 301
128 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500129
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500130 # Test that declaring this class isn't an error. It should
131 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500132 @dataclass(frozen=True)
133 class C:
134 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500135 def __eq__(self, other):
136 return False
137 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500138
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500139 # But this one should generate an exception, because with
140 # unsafe_hash=True, it's an error to have a __hash__ defined.
141 with self.assertRaisesRegex(TypeError,
142 'Cannot overwrite attribute __hash__'):
143 @dataclass(unsafe_hash=True)
144 class C:
145 def __hash__(self):
146 pass
147
148 # Creating this class should not generate an exception,
149 # because even though __hash__ exists before @dataclass is
150 # called, (due to __eq__ being defined), since it's None
151 # that's okay.
152 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500153 class C:
154 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500155 def __eq__(self):
156 pass
157 # The generated hash function works as we'd expect.
158 self.assertEqual(hash(C(10)), hash((10,)))
159
160 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400161 # __hash__ exists and is not None, which it would be if it
162 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500163 with self.assertRaisesRegex(TypeError,
164 'Cannot overwrite attribute __hash__'):
165 @dataclass(unsafe_hash=True)
166 class C:
167 x: int
168 def __eq__(self):
169 pass
170 def __hash__(self):
171 pass
172
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500173 def test_overwrite_fields_in_derived_class(self):
174 # Note that x from C1 replaces x in Base, but the order remains
175 # the same as defined in Base.
176 @dataclass
177 class Base:
178 x: Any = 15.0
179 y: int = 0
180
181 @dataclass
182 class C1(Base):
183 z: int = 10
184 x: int = 15
185
186 o = Base()
187 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
188
189 o = C1()
190 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
191
192 o = C1(x=5)
193 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
194
195 def test_field_named_self(self):
196 @dataclass
197 class C:
198 self: str
199 c=C('foo')
200 self.assertEqual(c.self, 'foo')
201
202 # Make sure the first parameter is not named 'self'.
203 sig = inspect.signature(C.__init__)
204 first = next(iter(sig.parameters))
205 self.assertNotEqual('self', first)
206
207 # But we do use 'self' if no field named self.
208 @dataclass
209 class C:
210 selfx: str
211
212 # Make sure the first parameter is named 'self'.
213 sig = inspect.signature(C.__init__)
214 first = next(iter(sig.parameters))
215 self.assertEqual('self', first)
216
Vadim Pushtaev4d12e4d2018-08-12 14:46:05 +0300217 def test_field_named_object(self):
218 @dataclass
219 class C:
220 object: str
221 c = C('foo')
222 self.assertEqual(c.object, 'foo')
223
224 def test_field_named_object_frozen(self):
225 @dataclass(frozen=True)
226 class C:
227 object: str
228 c = C('foo')
229 self.assertEqual(c.object, 'foo')
230
231 def test_field_named_like_builtin(self):
232 # Attribute names can shadow built-in names
233 # since code generation is used.
234 # Ensure that this is not happening.
235 exclusions = {'None', 'True', 'False'}
236 builtins_names = sorted(
237 b for b in builtins.__dict__.keys()
238 if not b.startswith('__') and b not in exclusions
239 )
240 attributes = [(name, str) for name in builtins_names]
241 C = make_dataclass('C', attributes)
242
243 c = C(*[name for name in builtins_names])
244
245 for name in builtins_names:
246 self.assertEqual(getattr(c, name), name)
247
248 def test_field_named_like_builtin_frozen(self):
249 # Attribute names can shadow built-in names
250 # since code generation is used.
251 # Ensure that this is not happening
252 # for frozen data classes.
253 exclusions = {'None', 'True', 'False'}
254 builtins_names = sorted(
255 b for b in builtins.__dict__.keys()
256 if not b.startswith('__') and b not in exclusions
257 )
258 attributes = [(name, str) for name in builtins_names]
259 C = make_dataclass('C', attributes, frozen=True)
260
261 c = C(*[name for name in builtins_names])
262
263 for name in builtins_names:
264 self.assertEqual(getattr(c, name), name)
265
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500266 def test_0_field_compare(self):
267 # Ensure that order=False is the default.
268 @dataclass
269 class C0:
270 pass
271
272 @dataclass(order=False)
273 class C1:
274 pass
275
276 for cls in [C0, C1]:
277 with self.subTest(cls=cls):
278 self.assertEqual(cls(), cls())
279 for idx, fn in enumerate([lambda a, b: a < b,
280 lambda a, b: a <= b,
281 lambda a, b: a > b,
282 lambda a, b: a >= b]):
283 with self.subTest(idx=idx):
284 with self.assertRaisesRegex(TypeError,
285 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
286 fn(cls(), cls())
287
288 @dataclass(order=True)
289 class C:
290 pass
291 self.assertLessEqual(C(), C())
292 self.assertGreaterEqual(C(), C())
293
294 def test_1_field_compare(self):
295 # Ensure that order=False is the default.
296 @dataclass
297 class C0:
298 x: int
299
300 @dataclass(order=False)
301 class C1:
302 x: int
303
304 for cls in [C0, C1]:
305 with self.subTest(cls=cls):
306 self.assertEqual(cls(1), cls(1))
307 self.assertNotEqual(cls(0), cls(1))
308 for idx, fn in enumerate([lambda a, b: a < b,
309 lambda a, b: a <= b,
310 lambda a, b: a > b,
311 lambda a, b: a >= b]):
312 with self.subTest(idx=idx):
313 with self.assertRaisesRegex(TypeError,
314 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
315 fn(cls(0), cls(0))
316
317 @dataclass(order=True)
318 class C:
319 x: int
320 self.assertLess(C(0), C(1))
321 self.assertLessEqual(C(0), C(1))
322 self.assertLessEqual(C(1), C(1))
323 self.assertGreater(C(1), C(0))
324 self.assertGreaterEqual(C(1), C(0))
325 self.assertGreaterEqual(C(1), C(1))
326
327 def test_simple_compare(self):
328 # Ensure that order=False is the default.
329 @dataclass
330 class C0:
331 x: int
332 y: int
333
334 @dataclass(order=False)
335 class C1:
336 x: int
337 y: int
338
339 for cls in [C0, C1]:
340 with self.subTest(cls=cls):
341 self.assertEqual(cls(0, 0), cls(0, 0))
342 self.assertEqual(cls(1, 2), cls(1, 2))
343 self.assertNotEqual(cls(1, 0), cls(0, 0))
344 self.assertNotEqual(cls(1, 0), cls(1, 1))
345 for idx, fn in enumerate([lambda a, b: a < b,
346 lambda a, b: a <= b,
347 lambda a, b: a > b,
348 lambda a, b: a >= b]):
349 with self.subTest(idx=idx):
350 with self.assertRaisesRegex(TypeError,
351 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
352 fn(cls(0, 0), cls(0, 0))
353
354 @dataclass(order=True)
355 class C:
356 x: int
357 y: int
358
359 for idx, fn in enumerate([lambda a, b: a == b,
360 lambda a, b: a <= b,
361 lambda a, b: a >= b]):
362 with self.subTest(idx=idx):
363 self.assertTrue(fn(C(0, 0), C(0, 0)))
364
365 for idx, fn in enumerate([lambda a, b: a < b,
366 lambda a, b: a <= b,
367 lambda a, b: a != b]):
368 with self.subTest(idx=idx):
369 self.assertTrue(fn(C(0, 0), C(0, 1)))
370 self.assertTrue(fn(C(0, 1), C(1, 0)))
371 self.assertTrue(fn(C(1, 0), C(1, 1)))
372
373 for idx, fn in enumerate([lambda a, b: a > b,
374 lambda a, b: a >= b,
375 lambda a, b: a != b]):
376 with self.subTest(idx=idx):
377 self.assertTrue(fn(C(0, 1), C(0, 0)))
378 self.assertTrue(fn(C(1, 0), C(0, 1)))
379 self.assertTrue(fn(C(1, 1), C(1, 0)))
380
381 def test_compare_subclasses(self):
382 # Comparisons fail for subclasses, even if no fields
383 # are added.
384 @dataclass
385 class B:
386 i: int
387
388 @dataclass
389 class C(B):
390 pass
391
392 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
393 (lambda a, b: a != b, True)]):
394 with self.subTest(idx=idx):
395 self.assertEqual(fn(B(0), C(0)), expected)
396
397 for idx, fn in enumerate([lambda a, b: a < b,
398 lambda a, b: a <= b,
399 lambda a, b: a > b,
400 lambda a, b: a >= b]):
401 with self.subTest(idx=idx):
402 with self.assertRaisesRegex(TypeError,
403 "not supported between instances of 'B' and 'C'"):
404 fn(B(0), C(0))
405
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500407 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500408 for (eq, order, result ) in [
409 (False, False, 'neither'),
410 (False, True, 'exception'),
411 (True, False, 'eq_only'),
412 (True, True, 'both'),
413 ]:
414 with self.subTest(eq=eq, order=order):
415 if result == 'exception':
416 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
417 @dataclass(eq=eq, order=order)
418 class C:
419 pass
420 else:
421 @dataclass(eq=eq, order=order)
422 class C:
423 pass
424
425 if result == 'neither':
426 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500427 self.assertNotIn('__lt__', C.__dict__)
428 self.assertNotIn('__le__', C.__dict__)
429 self.assertNotIn('__gt__', C.__dict__)
430 self.assertNotIn('__ge__', C.__dict__)
431 elif result == 'both':
432 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500433 self.assertIn('__lt__', C.__dict__)
434 self.assertIn('__le__', C.__dict__)
435 self.assertIn('__gt__', C.__dict__)
436 self.assertIn('__ge__', C.__dict__)
437 elif result == 'eq_only':
438 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500439 self.assertNotIn('__lt__', C.__dict__)
440 self.assertNotIn('__le__', C.__dict__)
441 self.assertNotIn('__gt__', C.__dict__)
442 self.assertNotIn('__ge__', C.__dict__)
443 else:
444 assert False, f'unknown result {result!r}'
445
446 def test_field_no_default(self):
447 @dataclass
448 class C:
449 x: int = field()
450
451 self.assertEqual(C(5).x, 5)
452
453 with self.assertRaisesRegex(TypeError,
454 r"__init__\(\) missing 1 required "
455 "positional argument: 'x'"):
456 C()
457
458 def test_field_default(self):
459 default = object()
460 @dataclass
461 class C:
462 x: object = field(default=default)
463
464 self.assertIs(C.x, default)
465 c = C(10)
466 self.assertEqual(c.x, 10)
467
468 # If we delete the instance attribute, we should then see the
469 # class attribute.
470 del c.x
471 self.assertIs(c.x, default)
472
473 self.assertIs(C().x, default)
474
475 def test_not_in_repr(self):
476 @dataclass
477 class C:
478 x: int = field(repr=False)
479 with self.assertRaises(TypeError):
480 C()
481 c = C(10)
482 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
483
484 @dataclass
485 class C:
486 x: int = field(repr=False)
487 y: int
488 c = C(10, 20)
489 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
490
491 def test_not_in_compare(self):
492 @dataclass
493 class C:
494 x: int = 0
495 y: int = field(compare=False, default=4)
496
497 self.assertEqual(C(), C(0, 20))
498 self.assertEqual(C(1, 10), C(1, 20))
499 self.assertNotEqual(C(3), C(4, 10))
500 self.assertNotEqual(C(3, 10), C(4, 10))
501
502 def test_hash_field_rules(self):
503 # Test all 6 cases of:
504 # hash=True/False/None
505 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500506 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500507 (True, False, 'field' ),
508 (True, True, 'field' ),
509 (False, False, 'absent'),
510 (False, True, 'absent'),
511 (None, False, 'absent'),
512 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500513 ]:
514 with self.subTest(hash=hash_, compare=compare):
515 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500516 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500517 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500518
519 if result == 'field':
520 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500521 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500522 elif result == 'absent':
523 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500524 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500525 else:
526 assert False, f'unknown result {result!r}'
527
528 def test_init_false_no_default(self):
529 # If init=False and no default value, then the field won't be
530 # present in the instance.
531 @dataclass
532 class C:
533 x: int = field(init=False)
534
535 self.assertNotIn('x', C().__dict__)
536
537 @dataclass
538 class C:
539 x: int
540 y: int = 0
541 z: int = field(init=False)
542 t: int = 10
543
544 self.assertNotIn('z', C(0).__dict__)
545 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
546
547 def test_class_marker(self):
548 @dataclass
549 class C:
550 x: int
551 y: str = field(init=False, default=None)
552 z: str = field(repr=False)
553
554 the_fields = fields(C)
555 # the_fields is a tuple of 3 items, each value
556 # is in __annotations__.
557 self.assertIsInstance(the_fields, tuple)
558 for f in the_fields:
559 self.assertIs(type(f), Field)
560 self.assertIn(f.name, C.__annotations__)
561
562 self.assertEqual(len(the_fields), 3)
563
564 self.assertEqual(the_fields[0].name, 'x')
565 self.assertEqual(the_fields[0].type, int)
566 self.assertFalse(hasattr(C, 'x'))
567 self.assertTrue (the_fields[0].init)
568 self.assertTrue (the_fields[0].repr)
569 self.assertEqual(the_fields[1].name, 'y')
570 self.assertEqual(the_fields[1].type, str)
571 self.assertIsNone(getattr(C, 'y'))
572 self.assertFalse(the_fields[1].init)
573 self.assertTrue (the_fields[1].repr)
574 self.assertEqual(the_fields[2].name, 'z')
575 self.assertEqual(the_fields[2].type, str)
576 self.assertFalse(hasattr(C, 'z'))
577 self.assertTrue (the_fields[2].init)
578 self.assertFalse(the_fields[2].repr)
579
580 def test_field_order(self):
581 @dataclass
582 class B:
583 a: str = 'B:a'
584 b: str = 'B:b'
585 c: str = 'B:c'
586
587 @dataclass
588 class C(B):
589 b: str = 'C:b'
590
591 self.assertEqual([(f.name, f.default) for f in fields(C)],
592 [('a', 'B:a'),
593 ('b', 'C:b'),
594 ('c', 'B:c')])
595
596 @dataclass
597 class D(B):
598 c: str = 'D:c'
599
600 self.assertEqual([(f.name, f.default) for f in fields(D)],
601 [('a', 'B:a'),
602 ('b', 'B:b'),
603 ('c', 'D:c')])
604
605 @dataclass
606 class E(D):
607 a: str = 'E:a'
608 d: str = 'E:d'
609
610 self.assertEqual([(f.name, f.default) for f in fields(E)],
611 [('a', 'E:a'),
612 ('b', 'B:b'),
613 ('c', 'D:c'),
614 ('d', 'E:d')])
615
616 def test_class_attrs(self):
617 # We only have a class attribute if a default value is
618 # specified, either directly or via a field with a default.
619 default = object()
620 @dataclass
621 class C:
622 x: int
623 y: int = field(repr=False)
624 z: object = default
625 t: int = field(default=100)
626
627 self.assertFalse(hasattr(C, 'x'))
628 self.assertFalse(hasattr(C, 'y'))
629 self.assertIs (C.z, default)
630 self.assertEqual(C.t, 100)
631
632 def test_disallowed_mutable_defaults(self):
633 # For the known types, don't allow mutable default values.
634 for typ, empty, non_empty in [(list, [], [1]),
635 (dict, {}, {0:1}),
636 (set, set(), set([1])),
637 ]:
638 with self.subTest(typ=typ):
639 # Can't use a zero-length value.
640 with self.assertRaisesRegex(ValueError,
641 f'mutable default {typ} for field '
642 'x is not allowed'):
643 @dataclass
644 class Point:
645 x: typ = empty
646
647
648 # Nor a non-zero-length value
649 with self.assertRaisesRegex(ValueError,
650 f'mutable default {typ} for field '
651 'y is not allowed'):
652 @dataclass
653 class Point:
654 y: typ = non_empty
655
656 # Check subtypes also fail.
657 class Subclass(typ): pass
658
659 with self.assertRaisesRegex(ValueError,
660 f"mutable default .*Subclass'>"
661 ' for field z is not allowed'
662 ):
663 @dataclass
664 class Point:
665 z: typ = Subclass()
666
667 # Because this is a ClassVar, it can be mutable.
668 @dataclass
669 class C:
670 z: ClassVar[typ] = typ()
671
672 # Because this is a ClassVar, it can be mutable.
673 @dataclass
674 class C:
675 x: ClassVar[typ] = Subclass()
676
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500677 def test_deliberately_mutable_defaults(self):
678 # If a mutable default isn't in the known list of
679 # (list, dict, set), then it's okay.
680 class Mutable:
681 def __init__(self):
682 self.l = []
683
684 @dataclass
685 class C:
686 x: Mutable
687
688 # These 2 instances will share this value of x.
689 lst = Mutable()
690 o1 = C(lst)
691 o2 = C(lst)
692 self.assertEqual(o1, o2)
693 o1.x.l.extend([1, 2])
694 self.assertEqual(o1, o2)
695 self.assertEqual(o1.x.l, [1, 2])
696 self.assertIs(o1.x, o2.x)
697
698 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400699 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500700 @dataclass()
701 class C:
702 x: int
703
704 self.assertEqual(C(42).x, 42)
705
706 def test_not_tuple(self):
707 # Make sure we can't be compared to a tuple.
708 @dataclass
709 class Point:
710 x: int
711 y: int
712 self.assertNotEqual(Point(1, 2), (1, 2))
713
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400714 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500715 @dataclass
716 class C:
717 x: int
718 y: int
719 self.assertNotEqual(Point(1, 3), C(1, 3))
720
Windson yangbe372d72019-04-23 02:45:34 +0800721 def test_not_other_dataclass(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500722 # Test that some of the problems with namedtuple don't happen
723 # here.
724 @dataclass
725 class Point3D:
726 x: int
727 y: int
728 z: int
729
730 @dataclass
731 class Date:
732 year: int
733 month: int
734 day: int
735
736 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
737 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
738
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400739 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200740 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500741 x, y, z = Point3D(4, 5, 6)
742
Eric V. Smith7c99e932018-01-28 19:18:55 -0500743 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500744 # equal.
745 @dataclass
746 class Point3Dv1:
747 x: int = 0
748 y: int = 0
749 z: int = 0
750 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
751
752 def test_function_annotations(self):
753 # Some dummy class and instance to use as a default.
754 class F:
755 pass
756 f = F()
757
758 def validate_class(cls):
759 # First, check __annotations__, even though they're not
760 # function annotations.
761 self.assertEqual(cls.__annotations__['i'], int)
762 self.assertEqual(cls.__annotations__['j'], str)
763 self.assertEqual(cls.__annotations__['k'], F)
764 self.assertEqual(cls.__annotations__['l'], float)
765 self.assertEqual(cls.__annotations__['z'], complex)
766
767 # Verify __init__.
768
769 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400770 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertIs(signature.return_annotation, None)
772
773 # Check each parameter.
774 params = iter(signature.parameters.values())
775 param = next(params)
776 # This is testing an internal name, and probably shouldn't be tested.
777 self.assertEqual(param.name, 'self')
778 param = next(params)
779 self.assertEqual(param.name, 'i')
780 self.assertIs (param.annotation, int)
781 self.assertEqual(param.default, inspect.Parameter.empty)
782 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
783 param = next(params)
784 self.assertEqual(param.name, 'j')
785 self.assertIs (param.annotation, str)
786 self.assertEqual(param.default, inspect.Parameter.empty)
787 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
788 param = next(params)
789 self.assertEqual(param.name, 'k')
790 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400791 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
793 param = next(params)
794 self.assertEqual(param.name, 'l')
795 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400796 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500797 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
798 self.assertRaises(StopIteration, next, params)
799
800
801 @dataclass
802 class C:
803 i: int
804 j: str
805 k: F = f
806 l: float=field(default=None)
807 z: complex=field(default=3+4j, init=False)
808
809 validate_class(C)
810
811 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500812 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500813 class C:
814 i: int
815 j: str
816 k: F = f
817 l: float=field(default=None)
818 z: complex=field(default=3+4j, init=False)
819
820 validate_class(C)
821
Eric V. Smith03220fd2017-12-29 13:59:58 -0500822 def test_missing_default(self):
823 # Test that MISSING works the same as a default not being
824 # specified.
825 @dataclass
826 class C:
827 x: int=field(default=MISSING)
828 with self.assertRaisesRegex(TypeError,
829 r'__init__\(\) missing 1 required '
830 'positional argument'):
831 C()
832 self.assertNotIn('x', C.__dict__)
833
834 @dataclass
835 class D:
836 x: int
837 with self.assertRaisesRegex(TypeError,
838 r'__init__\(\) missing 1 required '
839 'positional argument'):
840 D()
841 self.assertNotIn('x', D.__dict__)
842
843 def test_missing_default_factory(self):
844 # Test that MISSING works the same as a default factory not
845 # being specified (which is really the same as a default not
846 # being specified, too).
847 @dataclass
848 class C:
849 x: int=field(default_factory=MISSING)
850 with self.assertRaisesRegex(TypeError,
851 r'__init__\(\) missing 1 required '
852 'positional argument'):
853 C()
854 self.assertNotIn('x', C.__dict__)
855
856 @dataclass
857 class D:
858 x: int=field(default=MISSING, default_factory=MISSING)
859 with self.assertRaisesRegex(TypeError,
860 r'__init__\(\) missing 1 required '
861 'positional argument'):
862 D()
863 self.assertNotIn('x', D.__dict__)
864
865 def test_missing_repr(self):
866 self.assertIn('MISSING_TYPE object', repr(MISSING))
867
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500868 def test_dont_include_other_annotations(self):
869 @dataclass
870 class C:
871 i: int
872 def foo(self) -> int:
873 return 4
874 @property
875 def bar(self) -> int:
876 return 5
877 self.assertEqual(list(C.__annotations__), ['i'])
878 self.assertEqual(C(10).foo(), 4)
879 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400880 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500881
882 def test_post_init(self):
883 # Just make sure it gets called
884 @dataclass
885 class C:
886 def __post_init__(self):
887 raise CustomError()
888 with self.assertRaises(CustomError):
889 C()
890
891 @dataclass
892 class C:
893 i: int = 10
894 def __post_init__(self):
895 if self.i == 10:
896 raise CustomError()
897 with self.assertRaises(CustomError):
898 C()
899 # post-init gets called, but doesn't raise. This is just
900 # checking that self is used correctly.
901 C(5)
902
903 # If there's not an __init__, then post-init won't get called.
904 @dataclass(init=False)
905 class C:
906 def __post_init__(self):
907 raise CustomError()
908 # Creating the class won't raise
909 C()
910
911 @dataclass
912 class C:
913 x: int = 0
914 def __post_init__(self):
915 self.x *= 2
916 self.assertEqual(C().x, 0)
917 self.assertEqual(C(2).x, 4)
918
Mike53f7a7c2017-12-14 14:04:53 +0300919 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500920 # attributes.
921 @dataclass(frozen=True)
922 class C:
923 x: int = 0
924 def __post_init__(self):
925 self.x *= 2
926 with self.assertRaises(FrozenInstanceError):
927 C()
928
929 def test_post_init_super(self):
930 # Make sure super() post-init isn't called by default.
931 class B:
932 def __post_init__(self):
933 raise CustomError()
934
935 @dataclass
936 class C(B):
937 def __post_init__(self):
938 self.x = 5
939
940 self.assertEqual(C().x, 5)
941
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400942 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500943 @dataclass
944 class C(B):
945 def __post_init__(self):
946 super().__post_init__()
947
948 with self.assertRaises(CustomError):
949 C()
950
951 # Make sure post-init is called, even if not defined in our
952 # class.
953 @dataclass
954 class C(B):
955 pass
956
957 with self.assertRaises(CustomError):
958 C()
959
960 def test_post_init_staticmethod(self):
961 flag = False
962 @dataclass
963 class C:
964 x: int
965 y: int
966 @staticmethod
967 def __post_init__():
968 nonlocal flag
969 flag = True
970
971 self.assertFalse(flag)
972 c = C(3, 4)
973 self.assertEqual((c.x, c.y), (3, 4))
974 self.assertTrue(flag)
975
976 def test_post_init_classmethod(self):
977 @dataclass
978 class C:
979 flag = False
980 x: int
981 y: int
982 @classmethod
983 def __post_init__(cls):
984 cls.flag = True
985
986 self.assertFalse(C.flag)
987 c = C(3, 4)
988 self.assertEqual((c.x, c.y), (3, 4))
989 self.assertTrue(C.flag)
990
991 def test_class_var(self):
992 # Make sure ClassVars are ignored in __init__, __repr__, etc.
993 @dataclass
994 class C:
995 x: int
996 y: int = 10
997 z: ClassVar[int] = 1000
998 w: ClassVar[int] = 2000
999 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001000 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001001
1002 c = C(5)
1003 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001004 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001005 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001006 self.assertEqual(c.z, 1000)
1007 self.assertEqual(c.w, 2000)
1008 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001009 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001010 C.z += 1
1011 self.assertEqual(c.z, 1001)
1012 c = C(20)
1013 self.assertEqual((c.x, c.y), (20, 10))
1014 self.assertEqual(c.z, 1001)
1015 self.assertEqual(c.w, 2000)
1016 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04001017 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001018
1019 def test_class_var_no_default(self):
1020 # If a ClassVar has no default value, it should not be set on the class.
1021 @dataclass
1022 class C:
1023 x: ClassVar[int]
1024
1025 self.assertNotIn('x', C.__dict__)
1026
1027 def test_class_var_default_factory(self):
1028 # It makes no sense for a ClassVar to have a default factory. When
1029 # would it be called? Call it yourself, since it's class-wide.
1030 with self.assertRaisesRegex(TypeError,
1031 'cannot have a default factory'):
1032 @dataclass
1033 class C:
1034 x: ClassVar[int] = field(default_factory=int)
1035
1036 self.assertNotIn('x', C.__dict__)
1037
1038 def test_class_var_with_default(self):
1039 # If a ClassVar has a default value, it should be set on the class.
1040 @dataclass
1041 class C:
1042 x: ClassVar[int] = 10
1043 self.assertEqual(C.x, 10)
1044
1045 @dataclass
1046 class C:
1047 x: ClassVar[int] = field(default=10)
1048 self.assertEqual(C.x, 10)
1049
1050 def test_class_var_frozen(self):
1051 # Make sure ClassVars work even if we're frozen.
1052 @dataclass(frozen=True)
1053 class C:
1054 x: int
1055 y: int = 10
1056 z: ClassVar[int] = 1000
1057 w: ClassVar[int] = 2000
1058 t: ClassVar[int] = 3000
1059
1060 c = C(5)
1061 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1062 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1063 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1064 self.assertEqual(c.z, 1000)
1065 self.assertEqual(c.w, 2000)
1066 self.assertEqual(c.t, 3000)
1067 # We can still modify the ClassVar, it's only instances that are
1068 # frozen.
1069 C.z += 1
1070 self.assertEqual(c.z, 1001)
1071 c = C(20)
1072 self.assertEqual((c.x, c.y), (20, 10))
1073 self.assertEqual(c.z, 1001)
1074 self.assertEqual(c.w, 2000)
1075 self.assertEqual(c.t, 3000)
1076
1077 def test_init_var_no_default(self):
1078 # If an InitVar has no default value, it should not be set on the class.
1079 @dataclass
1080 class C:
1081 x: InitVar[int]
1082
1083 self.assertNotIn('x', C.__dict__)
1084
1085 def test_init_var_default_factory(self):
1086 # It makes no sense for an InitVar to have a default factory. When
1087 # would it be called? Call it yourself, since it's class-wide.
1088 with self.assertRaisesRegex(TypeError,
1089 'cannot have a default factory'):
1090 @dataclass
1091 class C:
1092 x: InitVar[int] = field(default_factory=int)
1093
1094 self.assertNotIn('x', C.__dict__)
1095
1096 def test_init_var_with_default(self):
1097 # If an InitVar has a default value, it should be set on the class.
1098 @dataclass
1099 class C:
1100 x: InitVar[int] = 10
1101 self.assertEqual(C.x, 10)
1102
1103 @dataclass
1104 class C:
1105 x: InitVar[int] = field(default=10)
1106 self.assertEqual(C.x, 10)
1107
1108 def test_init_var(self):
1109 @dataclass
1110 class C:
1111 x: int = None
1112 init_param: InitVar[int] = None
1113
1114 def __post_init__(self, init_param):
1115 if self.x is None:
1116 self.x = init_param*2
1117
1118 c = C(init_param=10)
1119 self.assertEqual(c.x, 20)
1120
Augusto Hack01ee12b2019-06-02 23:14:48 -03001121 def test_init_var_preserve_type(self):
1122 self.assertEqual(InitVar[int].type, int)
1123
1124 # Make sure the repr is correct.
1125 self.assertEqual(repr(InitVar[int]), 'dataclasses.InitVar[int]')
Samuel Colvin793cb852019-10-13 12:45:36 +01001126 self.assertEqual(repr(InitVar[List[int]]),
1127 'dataclasses.InitVar[typing.List[int]]')
Augusto Hack01ee12b2019-06-02 23:14:48 -03001128
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001129 def test_init_var_inheritance(self):
1130 # Note that this deliberately tests that a dataclass need not
1131 # have a __post_init__ function if it has an InitVar field.
1132 # It could just be used in a derived class, as shown here.
1133 @dataclass
1134 class Base:
1135 x: int
1136 init_base: InitVar[int]
1137
1138 # We can instantiate by passing the InitVar, even though
1139 # it's not used.
1140 b = Base(0, 10)
1141 self.assertEqual(vars(b), {'x': 0})
1142
1143 @dataclass
1144 class C(Base):
1145 y: int
1146 init_derived: InitVar[int]
1147
1148 def __post_init__(self, init_base, init_derived):
1149 self.x = self.x + init_base
1150 self.y = self.y + init_derived
1151
1152 c = C(10, 11, 50, 51)
1153 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1154
1155 def test_default_factory(self):
1156 # Test a factory that returns a new list.
1157 @dataclass
1158 class C:
1159 x: int
1160 y: list = field(default_factory=list)
1161
1162 c0 = C(3)
1163 c1 = C(3)
1164 self.assertEqual(c0.x, 3)
1165 self.assertEqual(c0.y, [])
1166 self.assertEqual(c0, c1)
1167 self.assertIsNot(c0.y, c1.y)
1168 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1169
1170 # Test a factory that returns a shared list.
1171 l = []
1172 @dataclass
1173 class C:
1174 x: int
1175 y: list = field(default_factory=lambda: l)
1176
1177 c0 = C(3)
1178 c1 = C(3)
1179 self.assertEqual(c0.x, 3)
1180 self.assertEqual(c0.y, [])
1181 self.assertEqual(c0, c1)
1182 self.assertIs(c0.y, c1.y)
1183 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1184
1185 # Test various other field flags.
1186 # repr
1187 @dataclass
1188 class C:
1189 x: list = field(default_factory=list, repr=False)
1190 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1191 self.assertEqual(C().x, [])
1192
1193 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001194 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001195 class C:
1196 x: list = field(default_factory=list, hash=False)
1197 self.assertEqual(astuple(C()), ([],))
1198 self.assertEqual(hash(C()), hash(()))
1199
1200 # init (see also test_default_factory_with_no_init)
1201 @dataclass
1202 class C:
1203 x: list = field(default_factory=list, init=False)
1204 self.assertEqual(astuple(C()), ([],))
1205
1206 # compare
1207 @dataclass
1208 class C:
1209 x: list = field(default_factory=list, compare=False)
1210 self.assertEqual(C(), C([1]))
1211
1212 def test_default_factory_with_no_init(self):
1213 # We need a factory with a side effect.
1214 factory = Mock()
1215
1216 @dataclass
1217 class C:
1218 x: list = field(default_factory=factory, init=False)
1219
1220 # Make sure the default factory is called for each new instance.
1221 C().x
1222 self.assertEqual(factory.call_count, 1)
1223 C().x
1224 self.assertEqual(factory.call_count, 2)
1225
1226 def test_default_factory_not_called_if_value_given(self):
1227 # We need a factory that we can test if it's been called.
1228 factory = Mock()
1229
1230 @dataclass
1231 class C:
1232 x: int = field(default_factory=factory)
1233
1234 # Make sure that if a field has a default factory function,
1235 # it's not called if a value is specified.
1236 C().x
1237 self.assertEqual(factory.call_count, 1)
1238 self.assertEqual(C(10).x, 10)
1239 self.assertEqual(factory.call_count, 1)
1240 C().x
1241 self.assertEqual(factory.call_count, 2)
1242
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001243 def test_default_factory_derived(self):
1244 # See bpo-32896.
1245 @dataclass
1246 class Foo:
1247 x: dict = field(default_factory=dict)
1248
1249 @dataclass
1250 class Bar(Foo):
1251 y: int = 1
1252
1253 self.assertEqual(Foo().x, {})
1254 self.assertEqual(Bar().x, {})
1255 self.assertEqual(Bar().y, 1)
1256
1257 @dataclass
1258 class Baz(Foo):
1259 pass
1260 self.assertEqual(Baz().x, {})
1261
1262 def test_intermediate_non_dataclass(self):
1263 # Test that an intermediate class that defines
1264 # annotations does not define fields.
1265
1266 @dataclass
1267 class A:
1268 x: int
1269
1270 class B(A):
1271 y: int
1272
1273 @dataclass
1274 class C(B):
1275 z: int
1276
1277 c = C(1, 3)
1278 self.assertEqual((c.x, c.z), (1, 3))
1279
1280 # .y was not initialized.
1281 with self.assertRaisesRegex(AttributeError,
1282 'object has no attribute'):
1283 c.y
1284
1285 # And if we again derive a non-dataclass, no fields are added.
1286 class D(C):
1287 t: int
1288 d = D(4, 5)
1289 self.assertEqual((d.x, d.z), (4, 5))
1290
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001291 def test_classvar_default_factory(self):
1292 # It's an error for a ClassVar to have a factory function.
1293 with self.assertRaisesRegex(TypeError,
1294 'cannot have a default factory'):
1295 @dataclass
1296 class C:
1297 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001298
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001299 def test_is_dataclass(self):
1300 class NotDataClass:
1301 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001302
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001303 self.assertFalse(is_dataclass(0))
1304 self.assertFalse(is_dataclass(int))
1305 self.assertFalse(is_dataclass(NotDataClass))
1306 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001307
1308 @dataclass
1309 class C:
1310 x: int
1311
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001312 @dataclass
1313 class D:
1314 d: C
1315 e: int
1316
1317 c = C(10)
1318 d = D(c, 4)
1319
1320 self.assertTrue(is_dataclass(C))
1321 self.assertTrue(is_dataclass(c))
1322 self.assertFalse(is_dataclass(c.x))
1323 self.assertTrue(is_dataclass(d.d))
1324 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001325
Eric V. Smithb0f4dab2019-08-20 01:40:28 -04001326 def test_is_dataclass_when_getattr_always_returns(self):
1327 # See bpo-37868.
1328 class A:
1329 def __getattr__(self, key):
1330 return 0
1331 self.assertFalse(is_dataclass(A))
1332 a = A()
1333
1334 # Also test for an instance attribute.
1335 class B:
1336 pass
1337 b = B()
1338 b.__dataclass_fields__ = []
1339
1340 for obj in a, b:
1341 with self.subTest(obj=obj):
1342 self.assertFalse(is_dataclass(obj))
1343
1344 # Indirect tests for _is_dataclass_instance().
1345 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1346 asdict(obj)
1347 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1348 astuple(obj)
1349 with self.assertRaisesRegex(TypeError, 'should be called on dataclass instances'):
1350 replace(obj, x=0)
1351
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001352 def test_helper_fields_with_class_instance(self):
1353 # Check that we can call fields() on either a class or instance,
1354 # and get back the same thing.
1355 @dataclass
1356 class C:
1357 x: int
1358 y: float
1359
1360 self.assertEqual(fields(C), fields(C(0, 0.0)))
1361
1362 def test_helper_fields_exception(self):
1363 # Check that TypeError is raised if not passed a dataclass or
1364 # instance.
1365 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1366 fields(0)
1367
1368 class C: pass
1369 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1370 fields(C)
1371 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1372 fields(C())
1373
1374 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001375 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001376 @dataclass
1377 class C:
1378 x: int
1379 y: int
1380 c = C(1, 2)
1381
1382 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1383 self.assertEqual(asdict(c), asdict(c))
1384 self.assertIsNot(asdict(c), asdict(c))
1385 c.x = 42
1386 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1387 self.assertIs(type(asdict(c)), dict)
1388
1389 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001390 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001391 @dataclass
1392 class C:
1393 x: int
1394 y: int
1395 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1396 asdict(C)
1397 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1398 asdict(int)
1399
1400 def test_helper_asdict_copy_values(self):
1401 @dataclass
1402 class C:
1403 x: int
1404 y: List[int] = field(default_factory=list)
1405 initial = []
1406 c = C(1, initial)
1407 d = asdict(c)
1408 self.assertEqual(d['y'], initial)
1409 self.assertIsNot(d['y'], initial)
1410 c = C(1)
1411 d = asdict(c)
1412 d['y'].append(1)
1413 self.assertEqual(c.y, [])
1414
1415 def test_helper_asdict_nested(self):
1416 @dataclass
1417 class UserId:
1418 token: int
1419 group: int
1420 @dataclass
1421 class User:
1422 name: str
1423 id: UserId
1424 u = User('Joe', UserId(123, 1))
1425 d = asdict(u)
1426 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1427 self.assertIsNot(asdict(u), asdict(u))
1428 u.id.group = 2
1429 self.assertEqual(asdict(u), {'name': 'Joe',
1430 'id': {'token': 123, 'group': 2}})
1431
1432 def test_helper_asdict_builtin_containers(self):
1433 @dataclass
1434 class User:
1435 name: str
1436 id: int
1437 @dataclass
1438 class GroupList:
1439 id: int
1440 users: List[User]
1441 @dataclass
1442 class GroupTuple:
1443 id: int
1444 users: Tuple[User, ...]
1445 @dataclass
1446 class GroupDict:
1447 id: int
1448 users: Dict[str, User]
1449 a = User('Alice', 1)
1450 b = User('Bob', 2)
1451 gl = GroupList(0, [a, b])
1452 gt = GroupTuple(0, (a, b))
1453 gd = GroupDict(0, {'first': a, 'second': b})
1454 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1455 {'name': 'Bob', 'id': 2}]})
1456 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1457 {'name': 'Bob', 'id': 2})})
1458 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1459 'second': {'name': 'Bob', 'id': 2}}})
1460
Windson yangbe372d72019-04-23 02:45:34 +08001461 def test_helper_asdict_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001462 @dataclass
1463 class Child:
1464 d: object
1465
1466 @dataclass
1467 class Parent:
1468 child: Child
1469
1470 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1471 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1472
1473 def test_helper_asdict_factory(self):
1474 @dataclass
1475 class C:
1476 x: int
1477 y: int
1478 c = C(1, 2)
1479 d = asdict(c, dict_factory=OrderedDict)
1480 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1481 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1482 c.x = 42
1483 d = asdict(c, dict_factory=OrderedDict)
1484 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1485 self.assertIs(type(d), OrderedDict)
1486
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001487 def test_helper_asdict_namedtuple(self):
1488 T = namedtuple('T', 'a b c')
1489 @dataclass
1490 class C:
1491 x: str
1492 y: T
1493 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1494
1495 d = asdict(c)
1496 self.assertEqual(d, {'x': 'outer',
1497 'y': T(1,
1498 {'x': 'inner',
1499 'y': T(11, 12, 13)},
1500 2),
1501 }
1502 )
1503
1504 # Now with a dict_factory. OrderedDict is convenient, but
1505 # since it compares to dicts, we also need to have separate
1506 # assertIs tests.
1507 d = asdict(c, dict_factory=OrderedDict)
1508 self.assertEqual(d, {'x': 'outer',
1509 'y': T(1,
1510 {'x': 'inner',
1511 'y': T(11, 12, 13)},
1512 2),
1513 }
1514 )
1515
penguindustin96466302019-05-06 14:57:17 -04001516 # Make sure that the returned dicts are actually OrderedDicts.
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001517 self.assertIs(type(d), OrderedDict)
1518 self.assertIs(type(d['y'][1]), OrderedDict)
1519
1520 def test_helper_asdict_namedtuple_key(self):
1521 # Ensure that a field that contains a dict which has a
1522 # namedtuple as a key works with asdict().
1523
1524 @dataclass
1525 class C:
1526 f: dict
1527 T = namedtuple('T', 'a')
1528
1529 c = C({T('an a'): 0})
1530
1531 self.assertEqual(asdict(c), {'f': {T(a='an a'): 0}})
1532
1533 def test_helper_asdict_namedtuple_derived(self):
1534 class T(namedtuple('Tbase', 'a')):
1535 def my_a(self):
1536 return self.a
1537
1538 @dataclass
1539 class C:
1540 f: T
1541
1542 t = T(6)
1543 c = C(t)
1544
1545 d = asdict(c)
1546 self.assertEqual(d, {'f': T(a=6)})
1547 # Make sure that t has been copied, not used directly.
1548 self.assertIsNot(d['f'], t)
1549 self.assertEqual(d['f'].my_a(), 6)
1550
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001551 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001552 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001553 @dataclass
1554 class C:
1555 x: int
1556 y: int = 0
1557 c = C(1)
1558
1559 self.assertEqual(astuple(c), (1, 0))
1560 self.assertEqual(astuple(c), astuple(c))
1561 self.assertIsNot(astuple(c), astuple(c))
1562 c.y = 42
1563 self.assertEqual(astuple(c), (1, 42))
1564 self.assertIs(type(astuple(c)), tuple)
1565
1566 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001567 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001568 @dataclass
1569 class C:
1570 x: int
1571 y: int
1572 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1573 astuple(C)
1574 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1575 astuple(int)
1576
1577 def test_helper_astuple_copy_values(self):
1578 @dataclass
1579 class C:
1580 x: int
1581 y: List[int] = field(default_factory=list)
1582 initial = []
1583 c = C(1, initial)
1584 t = astuple(c)
1585 self.assertEqual(t[1], initial)
1586 self.assertIsNot(t[1], initial)
1587 c = C(1)
1588 t = astuple(c)
1589 t[1].append(1)
1590 self.assertEqual(c.y, [])
1591
1592 def test_helper_astuple_nested(self):
1593 @dataclass
1594 class UserId:
1595 token: int
1596 group: int
1597 @dataclass
1598 class User:
1599 name: str
1600 id: UserId
1601 u = User('Joe', UserId(123, 1))
1602 t = astuple(u)
1603 self.assertEqual(t, ('Joe', (123, 1)))
1604 self.assertIsNot(astuple(u), astuple(u))
1605 u.id.group = 2
1606 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1607
1608 def test_helper_astuple_builtin_containers(self):
1609 @dataclass
1610 class User:
1611 name: str
1612 id: int
1613 @dataclass
1614 class GroupList:
1615 id: int
1616 users: List[User]
1617 @dataclass
1618 class GroupTuple:
1619 id: int
1620 users: Tuple[User, ...]
1621 @dataclass
1622 class GroupDict:
1623 id: int
1624 users: Dict[str, User]
1625 a = User('Alice', 1)
1626 b = User('Bob', 2)
1627 gl = GroupList(0, [a, b])
1628 gt = GroupTuple(0, (a, b))
1629 gd = GroupDict(0, {'first': a, 'second': b})
1630 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1631 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1632 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1633
Windson yangbe372d72019-04-23 02:45:34 +08001634 def test_helper_astuple_builtin_object_containers(self):
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001635 @dataclass
1636 class Child:
1637 d: object
1638
1639 @dataclass
1640 class Parent:
1641 child: Child
1642
1643 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1644 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1645
1646 def test_helper_astuple_factory(self):
1647 @dataclass
1648 class C:
1649 x: int
1650 y: int
1651 NT = namedtuple('NT', 'x y')
1652 def nt(lst):
1653 return NT(*lst)
1654 c = C(1, 2)
1655 t = astuple(c, tuple_factory=nt)
1656 self.assertEqual(t, NT(1, 2))
1657 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1658 c.x = 42
1659 t = astuple(c, tuple_factory=nt)
1660 self.assertEqual(t, NT(42, 2))
1661 self.assertIs(type(t), NT)
1662
Eric V. Smith9b9d97d2018-09-14 11:32:16 -04001663 def test_helper_astuple_namedtuple(self):
1664 T = namedtuple('T', 'a b c')
1665 @dataclass
1666 class C:
1667 x: str
1668 y: T
1669 c = C('outer', T(1, C('inner', T(11, 12, 13)), 2))
1670
1671 t = astuple(c)
1672 self.assertEqual(t, ('outer', T(1, ('inner', (11, 12, 13)), 2)))
1673
1674 # Now, using a tuple_factory. list is convenient here.
1675 t = astuple(c, tuple_factory=list)
1676 self.assertEqual(t, ['outer', T(1, ['inner', T(11, 12, 13)], 2)])
1677
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001678 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001679 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001680 }
1681
1682 # Create the class.
1683 cls = type('C', (), cls_dict)
1684
1685 # Make it a dataclass.
1686 cls1 = dataclass(cls)
1687
1688 self.assertEqual(cls1, cls)
1689 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1690
1691 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001692 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001693 'y': field(default=5),
1694 }
1695
1696 # Create the class.
1697 cls = type('C', (), cls_dict)
1698
1699 # Make it a dataclass.
1700 cls1 = dataclass(cls)
1701
1702 self.assertEqual(cls1, cls)
1703 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1704
1705 def test_init_in_order(self):
1706 @dataclass
1707 class C:
1708 a: int
1709 b: int = field()
1710 c: list = field(default_factory=list, init=False)
1711 d: list = field(default_factory=list)
1712 e: int = field(default=4, init=False)
1713 f: int = 4
1714
1715 calls = []
1716 def setattr(self, name, value):
1717 calls.append((name, value))
1718
1719 C.__setattr__ = setattr
1720 c = C(0, 1)
1721 self.assertEqual(('a', 0), calls[0])
1722 self.assertEqual(('b', 1), calls[1])
1723 self.assertEqual(('c', []), calls[2])
1724 self.assertEqual(('d', []), calls[3])
1725 self.assertNotIn(('e', 4), calls)
1726 self.assertEqual(('f', 4), calls[4])
1727
1728 def test_items_in_dicts(self):
1729 @dataclass
1730 class C:
1731 a: int
1732 b: list = field(default_factory=list, init=False)
1733 c: list = field(default_factory=list)
1734 d: int = field(default=4, init=False)
1735 e: int = 0
1736
1737 c = C(0)
1738 # Class dict
1739 self.assertNotIn('a', C.__dict__)
1740 self.assertNotIn('b', C.__dict__)
1741 self.assertNotIn('c', C.__dict__)
1742 self.assertIn('d', C.__dict__)
1743 self.assertEqual(C.d, 4)
1744 self.assertIn('e', C.__dict__)
1745 self.assertEqual(C.e, 0)
1746 # Instance dict
1747 self.assertIn('a', c.__dict__)
1748 self.assertEqual(c.a, 0)
1749 self.assertIn('b', c.__dict__)
1750 self.assertEqual(c.b, [])
1751 self.assertIn('c', c.__dict__)
1752 self.assertEqual(c.c, [])
1753 self.assertNotIn('d', c.__dict__)
1754 self.assertIn('e', c.__dict__)
1755 self.assertEqual(c.e, 0)
1756
1757 def test_alternate_classmethod_constructor(self):
1758 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001759 # alternate constructor. This is mostly an example to show
1760 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001761 @dataclass
1762 class C:
1763 x: int
1764 @classmethod
1765 def from_file(cls, filename):
1766 # In a real example, create a new instance
1767 # and populate 'x' from contents of a file.
1768 value_in_file = 20
1769 return cls(value_in_file)
1770
1771 self.assertEqual(C.from_file('filename').x, 20)
1772
1773 def test_field_metadata_default(self):
1774 # Make sure the default metadata is read-only and of
1775 # zero length.
1776 @dataclass
1777 class C:
1778 i: int
1779
1780 self.assertFalse(fields(C)[0].metadata)
1781 self.assertEqual(len(fields(C)[0].metadata), 0)
1782 with self.assertRaisesRegex(TypeError,
1783 'does not support item assignment'):
1784 fields(C)[0].metadata['test'] = 3
1785
1786 def test_field_metadata_mapping(self):
1787 # Make sure only a mapping can be passed as metadata
1788 # zero length.
1789 with self.assertRaises(TypeError):
1790 @dataclass
1791 class C:
1792 i: int = field(metadata=0)
1793
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001794 # Make sure an empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001795 d = {}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001796 @dataclass
1797 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001798 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001799 self.assertFalse(fields(C)[0].metadata)
1800 self.assertEqual(len(fields(C)[0].metadata), 0)
Christopher Huntb01786c2019-02-12 06:50:49 -05001801 # Update should work (see bpo-35960).
1802 d['foo'] = 1
1803 self.assertEqual(len(fields(C)[0].metadata), 1)
1804 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001805 with self.assertRaisesRegex(TypeError,
1806 'does not support item assignment'):
1807 fields(C)[0].metadata['test'] = 3
1808
1809 # Make sure a non-empty dict works.
Christopher Huntb01786c2019-02-12 06:50:49 -05001810 d = {'test': 10, 'bar': '42', 3: 'three'}
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001811 @dataclass
1812 class C:
Christopher Huntb01786c2019-02-12 06:50:49 -05001813 i: int = field(metadata=d)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001814 self.assertEqual(len(fields(C)[0].metadata), 3)
1815 self.assertEqual(fields(C)[0].metadata['test'], 10)
1816 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1817 self.assertEqual(fields(C)[0].metadata[3], 'three')
Christopher Huntb01786c2019-02-12 06:50:49 -05001818 # Update should work.
1819 d['foo'] = 1
1820 self.assertEqual(len(fields(C)[0].metadata), 4)
1821 self.assertEqual(fields(C)[0].metadata['foo'], 1)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001822 with self.assertRaises(KeyError):
1823 # Non-existent key.
1824 fields(C)[0].metadata['baz']
1825 with self.assertRaisesRegex(TypeError,
1826 'does not support item assignment'):
1827 fields(C)[0].metadata['test'] = 3
1828
1829 def test_field_metadata_custom_mapping(self):
1830 # Try a custom mapping.
1831 class SimpleNameSpace:
1832 def __init__(self, **kw):
1833 self.__dict__.update(kw)
1834
1835 def __getitem__(self, item):
1836 if item == 'xyzzy':
1837 return 'plugh'
1838 return getattr(self, item)
1839
1840 def __len__(self):
1841 return self.__dict__.__len__()
1842
1843 @dataclass
1844 class C:
1845 i: int = field(metadata=SimpleNameSpace(a=10))
1846
1847 self.assertEqual(len(fields(C)[0].metadata), 1)
1848 self.assertEqual(fields(C)[0].metadata['a'], 10)
1849 with self.assertRaises(AttributeError):
1850 fields(C)[0].metadata['b']
1851 # Make sure we're still talking to our custom mapping.
1852 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1853
1854 def test_generic_dataclasses(self):
1855 T = TypeVar('T')
1856
1857 @dataclass
1858 class LabeledBox(Generic[T]):
1859 content: T
1860 label: str = '<unknown>'
1861
1862 box = LabeledBox(42)
1863 self.assertEqual(box.content, 42)
1864 self.assertEqual(box.label, '<unknown>')
1865
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001866 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001867 Alias = List[LabeledBox[int]]
1868
1869 def test_generic_extending(self):
1870 S = TypeVar('S')
1871 T = TypeVar('T')
1872
1873 @dataclass
1874 class Base(Generic[T, S]):
1875 x: T
1876 y: S
1877
1878 @dataclass
1879 class DataDerived(Base[int, T]):
1880 new_field: str
1881 Alias = DataDerived[str]
1882 c = Alias(0, 'test1', 'test2')
1883 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1884
1885 class NonDataDerived(Base[int, T]):
1886 def new_method(self):
1887 return self.y
1888 Alias = NonDataDerived[float]
1889 c = Alias(10, 1.0)
1890 self.assertEqual(c.new_method(), 1.0)
1891
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001892 def test_generic_dynamic(self):
1893 T = TypeVar('T')
1894
1895 @dataclass
1896 class Parent(Generic[T]):
1897 x: T
1898 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1899 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1900 self.assertIs(Child[int](1, 2).z, None)
1901 self.assertEqual(Child[int](1, 2, 3).z, 3)
1902 self.assertEqual(Child[int](1, 2, 3).other, 42)
1903 # Check that type aliases work correctly.
1904 Alias = Child[T]
1905 self.assertEqual(Alias[int](1, 2).x, 1)
1906 # Check MRO resolution.
1907 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1908
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001909 def test_dataclassses_pickleable(self):
1910 global P, Q, R
1911 @dataclass
1912 class P:
1913 x: int
1914 y: int = 0
1915 @dataclass
1916 class Q:
1917 x: int
1918 y: int = field(default=0, init=False)
1919 @dataclass
1920 class R:
1921 x: int
1922 y: List[int] = field(default_factory=list)
1923 q = Q(1)
1924 q.y = 2
1925 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1926 for sample in samples:
1927 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1928 with self.subTest(sample=sample, proto=proto):
1929 new_sample = pickle.loads(pickle.dumps(sample, proto))
1930 self.assertEqual(sample.x, new_sample.x)
1931 self.assertEqual(sample.y, new_sample.y)
1932 self.assertIsNot(sample, new_sample)
1933 new_sample.x = 42
1934 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1935 self.assertEqual(new_sample.x, another_new_sample.x)
1936 self.assertEqual(sample.y, another_new_sample.y)
1937
Eric V. Smithea8fc522018-01-27 19:07:40 -05001938
class TestFieldNoAnnotation(unittest.TestCase):
    # A field() value must be annotated in the same class body.  An
    # annotation for that name in a base class does not count.

    def test_field_without_annotation(self):
        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        expected = "'f' is a field but has no type annotation"
        # Still an error: the base dataclass's annotation must not be
        # picked up for C's unannotated field().
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same scenario, but the annotated base is an ordinary class.
        class B:
            f: int

        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(B):
                f = field()
1972
1973
class TestDocString(unittest.TestCase):
    def assertDocStrEqual(self, a, b):
        # inspect.signature() formatting differs between 3.6 and 3.7
        # (bpo-32108), so compare the two strings with every space removed.
        no_spaces = {ord(' '): None}
        self.assertEqual(a.translate(no_spaces), b.translate(no_spaces))

    def test_existing_docstring_not_overridden(self):
        # A user-written docstring wins over the generated one.
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:Optional[int]=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        # A default_factory shows up as the placeholder "<factory>".
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        doc = C.__doc__
        self.assertDocStrEqual(doc, "C(x:collections.deque=<factory>)")
2061
2062
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise.  We can't override an
        # __init__ defined in our own class, but adding __init__ is fine
        # when only the base defines one.  The generated __init__ does not
        # call the base's, so 'z' is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # When no __init__ is generated, the base __init__ is the one
        # that runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # With init=False and no user-defined __init__, C() takes no
        # arguments and the field reads back its class-level default.
        # (The decorator must be applied with '@'; a bare
        # 'dataclass(init=False)' expression would leave C a plain class
        # and make this test vacuous.)
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # With init=False, a user-written __init__ is used as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class defines __init__, it is kept regardless of the
        # value of init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2127
2128
class TestRepr(unittest.TestCase):
    """Tests for the generated __repr__ and the repr= parameter."""

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses use their full __qualname__ in the repr.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
        # Nothing is generated, so object's default repr is used.
        self.assertNotIn('__repr__', C.__dict__)
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2198
2199
class TestEq(unittest.TestCase):
    """Tests for the generated __eq__ and the eq= parameter."""

    def test_no_eq(self):
        # Test a class with no __eq__ and eq=False.
        @dataclass(eq=False)
        class C:
            x: int
        # Nothing is generated, so equality falls back to identity.
        self.assertNotIn('__eq__', C.__dict__)
        self.assertNotEqual(C(0), C(0))
        c = C(3)
        self.assertEqual(c, c)

        # Test a class with an __eq__ and eq=False.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # If the class has __eq__, use it no matter the value of
        # eq=.

        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)
2245
2246
class TestOrdering(unittest.TestCase):
    """Tests for the generated ordering methods and the order= parameter."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that a user-defined __lt__ is kept and still called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # Field-based ordering would make C(0) < C(1) true; getting False
        # proves the user's __lt__ is the one that ran.
        self.assertFalse(C(0) < C(1))
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # With order=True, defining any ordering method yourself is an
        # error; the message points users at functools.total_ordering.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2322
class TestHash(unittest.TestCase):
    """Tests for __hash__ generation under unsafe_hash=, eq= and frozen=."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.  Use a
                    # unittest assertion rather than a bare assert so the
                    # table check survives python -O.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        # unsafe_hash=True/False
        # eq=True/False
        # frozen=True/False
        # And for each of these, a different result if
        # __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo=32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
                (None,  None,  object, 'unhashable'),
                (None,  None,  Base,   'unhashable'),
                (None,  False, object, 'object'),
                (None,  False, Base,   'base'),
                (None,  True,  object, 'unhashable'),
                (None,  True,  Base,   'unhashable'),
                (False, None,  object, 'unhashable'),
                (False, None,  Base,   'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base,   'base'),
                (False, True,  object, 'unhashable'),
                (False, True,  Base,   'unhashable'),
                (True,  None,  object, 'tuple'),
                (True,  None,  Base,   'tuple'),
                (True,  False, object, 'object'),
                (True,  False, Base,   'base'),
                (True,  True,  object, 'tuple'),
                (True,  True,  Base,   'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash is derived from the instance's id(),
                    # so its value can't be predicted here.  Just check
                    # that the function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    self.fail(f'unknown value for expected={expected!r}')
2541
Eric V. Smithea8fc522018-01-27 19:07:40 -05002542
class TestFrozen(unittest.TestCase):
    """Tests for frozen=True instances and frozen/non-frozen inheritance."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    # Test both ways: with an intermediate normal (non-dataclass)
    # class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

                # Checked inside the subTest so a failure is attributed
                # to the right intermediate_class value.
                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        # A normal (non-dataclass) subclass may set new attributes...
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2691
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002692
class TestSlots(unittest.TestCase):
    """Tests for dataclasses that declare __slots__."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed to
        # also have a default value (of type
        # types.MemberDescriptorType).
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.
        d.z = 10
2734
class TestDescriptors(unittest.TestCase):
    """Tests for __set_name__ handling of field defaults."""

    def test_set_name(self):
        # See bpo-33141.

        # Create a descriptor.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is not None:
                    return 1
                return self

        # This is the case of just normal descriptor behavior, no
        # dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # Now test with a default value and init=False, which is the
        # only time this is really meaningful.  If not using
        # init=False, then the descriptor will be overwritten, anyway.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # Create a class that defines only __set_name__ (no __get__,
        # __set__ or __delete__), so it is not a descriptor.

        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        d = D()
        # Create an attribute on the instance, not type.
        d.__set_name__ = Mock()

        # Make sure d.__set_name__ is not called.
        @dataclass
        class C:
            i: int = field(default=d, init=False)

        self.assertEqual(d.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        D.__set_name__ = Mock()

        # Make sure D.__set_name__ is called.
        @dataclass
        class C:
            i: int = field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002805
Eric V. Smith7389fd92018-03-19 21:07:51 -04002806
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002807class TestStringAnnotations(unittest.TestCase):
2808 def test_classvar(self):
2809 # Some expressions recognized as ClassVar really aren't. But
2810 # if you're using string annotations, it's not an exact
2811 # science.
2812 # These tests assume that both "import typing" and "from
2813 # typing import *" have been run in this file.
2814 for typestr in ('ClassVar[int]',
2815 'ClassVar [int]'
2816 ' ClassVar [int]',
2817 'ClassVar',
2818 ' ClassVar ',
2819 'typing.ClassVar[int]',
2820 'typing.ClassVar[str]',
2821 ' typing.ClassVar[str]',
2822 'typing .ClassVar[str]',
2823 'typing. ClassVar[str]',
2824 'typing.ClassVar [str]',
2825 'typing.ClassVar [ str]',
2826
2827 # Not syntactically valid, but these will
2828 # be treated as ClassVars.
2829 'typing.ClassVar.[int]',
2830 'typing.ClassVar+',
2831 ):
2832 with self.subTest(typestr=typestr):
2833 @dataclass
2834 class C:
2835 x: typestr
2836
2837 # x is a ClassVar, so C() takes no args.
2838 C()
2839
2840 # And it won't appear in the class's dict because it doesn't
2841 # have a default.
2842 self.assertNotIn('x', C.__dict__)
2843
2844 def test_isnt_classvar(self):
2845 for typestr in ('CV',
2846 't.ClassVar',
2847 't.ClassVar[int]',
2848 'typing..ClassVar[int]',
2849 'Classvar',
2850 'Classvar[int]',
2851 'typing.ClassVarx[int]',
2852 'typong.ClassVar[int]',
2853 'dataclasses.ClassVar[int]',
2854 'typingxClassVar[str]',
2855 ):
2856 with self.subTest(typestr=typestr):
2857 @dataclass
2858 class C:
2859 x: typestr
2860
2861 # x is not a ClassVar, so C() takes one arg.
2862 self.assertEqual(C(10).x, 10)
2863
2864 def test_initvar(self):
2865 # These tests assume that both "import dataclasses" and "from
2866 # dataclasses import *" have been run in this file.
2867 for typestr in ('InitVar[int]',
2868 'InitVar [int]'
2869 ' InitVar [int]',
2870 'InitVar',
2871 ' InitVar ',
2872 'dataclasses.InitVar[int]',
2873 'dataclasses.InitVar[str]',
2874 ' dataclasses.InitVar[str]',
2875 'dataclasses .InitVar[str]',
2876 'dataclasses. InitVar[str]',
2877 'dataclasses.InitVar [str]',
2878 'dataclasses.InitVar [ str]',
2879
2880 # Not syntactically valid, but these will
2881 # be treated as InitVars.
2882 'dataclasses.InitVar.[int]',
2883 'dataclasses.InitVar+',
2884 ):
2885 with self.subTest(typestr=typestr):
2886 @dataclass
2887 class C:
2888 x: typestr
2889
2890 # x is an InitVar, so doesn't create a member.
2891 with self.assertRaisesRegex(AttributeError,
2892 "object has no attribute 'x'"):
2893 C(1).x
2894
2895 def test_isnt_initvar(self):
2896 for typestr in ('IV',
2897 'dc.InitVar',
2898 'xdataclasses.xInitVar',
2899 'typing.xInitVar[int]',
2900 ):
2901 with self.subTest(typestr=typestr):
2902 @dataclass
2903 class C:
2904 x: typestr
2905
2906 # x is not an InitVar, so there will be a member x.
2907 self.assertEqual(C(10).x, 10)
2908
2909 def test_classvar_module_level_import(self):
Serhiy Storchaka3fe5ccc2018-07-23 23:37:55 +03002910 from test import dataclass_module_1
2911 from test import dataclass_module_1_str
2912 from test import dataclass_module_2
2913 from test import dataclass_module_2_str
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002914
2915 for m in (dataclass_module_1, dataclass_module_1_str,
2916 dataclass_module_2, dataclass_module_2_str,
2917 ):
2918 with self.subTest(m=m):
2919 # There's a difference in how the ClassVars are
2920 # interpreted when using string annotations or
2921 # not. See the imported modules for details.
2922 if m.USING_STRINGS:
2923 c = m.CV(10)
2924 else:
2925 c = m.CV()
2926 self.assertEqual(c.cv0, 20)
2927
2928
2929 # There's a difference in how the InitVars are
2930 # interpreted when using string annotations or
2931 # not. See the imported modules for details.
2932 c = m.IV(0, 1, 2, 3, 4)
2933
2934 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2935 with self.subTest(field_name=field_name):
2936 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2937 # Since field_name is an InitVar, it's
2938 # not an instance field.
2939 getattr(c, field_name)
2940
2941 if m.USING_STRINGS:
2942 # iv4 is interpreted as a normal field.
2943 self.assertIn('not_iv4', c.__dict__)
2944 self.assertEqual(c.not_iv4, 4)
2945 else:
2946 # iv4 is interpreted as an InitVar, so it
2947 # won't exist on the instance.
2948 self.assertNotIn('not_iv4', c.__dict__)
2949
Yury Selivanovd219cc42019-12-09 09:54:20 -05002950 def test_text_annotations(self):
2951 from test import dataclass_textanno
2952
2953 self.assertEqual(
2954 get_type_hints(dataclass_textanno.Bar),
2955 {'foo': dataclass_textanno.Foo})
2956 self.assertEqual(
2957 get_type_hints(dataclass_textanno.Bar.__init__),
2958 {'foo': dataclass_textanno.Foo,
2959 'return': type(None)})
2960
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002961
class TestMakeDataclass(unittest.TestCase):
    """Tests for dataclasses.make_dataclass()."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # Base1's field x comes first, so both x and y are required.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # The loops below use field_name rather than field so the imported
    # dataclasses.field function is not shadowed inside these tests.
    def test_duplicate_field_names(self):
        for field_name in ['a', 'ab']:
            with self.subTest(field_name=field_name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [field_name, 'a', field_name])

    def test_keyword_field_names(self):
        for field_name in ['for', 'async', 'await', 'as']:
            with self.subTest(field_name=field_name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [field_name, 'a'])

    def test_non_identifier_field_names(self):
        for field_name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field_name=field_name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', ['a', field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifiers'):
                    make_dataclass('C', [field_name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3125
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace().

    The test_recursive_repr* methods check the generated __repr__ on
    self-referential instances rather than replace() itself.  Their
    expected strings embed this class's name (the qualified name of the
    locally defined dataclass), so they depend on living in this class.
    """

    def test(self):
        # Basic case: replace one field and keep the other.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be passed to replace(), either together
        # with ordinary fields or on their own.  Each raising call gets its
        # own context manager, because statements after it would never run.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        # An unknown name is rejected with __init__'s own TypeError.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # replace() needs a dataclass *instance*: neither the dataclass
        # itself nor an unrelated object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # y is not copied from c: the new instance gets y's default (10),
        # not the value (20) that was assigned on c.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # replace() requires an InitVar without a default to be passed
        # explicitly.  The parentheses are escaped: an unescaped "()" is an
        # empty regex group and would not check for them at all.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    def test_recursive_repr(self):
        @dataclass
        class C:
            f: "C"

        c = C(None)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr.<locals>.C(f=...)")

    def test_recursive_repr_two_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: "C"

        c = C(None, None)
        c.f = c
        c.g = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_two_attrs"
                                  ".<locals>.C(f=..., g=...)")

    def test_recursive_repr_indirection(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "C"

        c = C(None)
        d = D(None)
        c.f = d
        d.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection"
                                  ".<locals>.D(f=...))")

    def test_recursive_repr_indirection_two(self):
        @dataclass
        class C:
            f: "D"

        @dataclass
        class D:
            f: "E"

        @dataclass
        class E:
            f: "C"

        c = C(None)
        d = D(None)
        e = E(None)
        c.f = d
        d.f = e
        e.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.C(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.D(f=TestReplace.test_recursive_repr_indirection_two"
                                  ".<locals>.E(f=...)))")

    def test_recursive_repr_misc_attrs(self):
        @dataclass
        class C:
            f: "C"
            g: int

        c = C(None, 1)
        c.f = c
        self.assertEqual(repr(c), "TestReplace.test_recursive_repr_misc_attrs"
                                  ".<locals>.C(f=..., g=1)")
3335
Ben Avrahamibef7d292020-10-06 20:40:50 +03003336class TestAbstract(unittest.TestCase):
3337 def test_abc_implementation(self):
3338 class Ordered(abc.ABC):
3339 @abc.abstractmethod
3340 def __lt__(self, other):
3341 pass
3342
3343 @abc.abstractmethod
3344 def __le__(self, other):
3345 pass
3346
3347 @dataclass(order=True)
3348 class Date(Ordered):
3349 year: int
3350 month: 'Month'
3351 day: 'int'
3352
3353 self.assertFalse(inspect.isabstract(Date))
3354 self.assertGreater(Date(2020,12,25), Date(2020,8,31))
3355
3356 def test_maintain_abc(self):
3357 class A(abc.ABC):
3358 @abc.abstractmethod
3359 def foo(self):
3360 pass
3361
3362 @dataclass
3363 class Date(A):
3364 year: int
3365 month: 'Month'
3366 day: 'int'
3367
3368 self.assertTrue(inspect.isabstract(Date))
3369 msg = 'class Date with abstract method foo'
3370 self.assertRaisesRegex(TypeError, msg, Date)
3371
Eric V. Smith4e812962018-05-16 11:31:29 -04003372
# Run this module's test cases when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()