blob: d9556c7ff9cecca5f87a08b5b7e318692f3877e7 [file] [log] [blame]
Eric V. Smith8e4560a2018-03-21 17:10:22 -04001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +010011from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050012from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
Eric V. Smith2a7bacb2018-05-15 22:44:27 -040015import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
16import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
17
class CustomError(Exception):
    """Throwaway exception type; tests raise it to prove a hook actually ran."""
20
21class TestCase(unittest.TestCase):
22 def test_no_fields(self):
23 @dataclass
24 class C:
25 pass
26
27 o = C()
28 self.assertEqual(len(fields(C)), 0)
29
Eric V. Smith56970b82018-03-22 16:28:48 -040030 def test_no_fields_but_member_variable(self):
31 @dataclass
32 class C:
33 i = 0
34
35 o = C()
36 self.assertEqual(len(fields(C)), 0)
37
Eric V. Smithf0db54a2017-12-04 16:58:55 -050038 def test_one_field_no_default(self):
39 @dataclass
40 class C:
41 x: int
42
43 o = C(42)
44 self.assertEqual(o.x, 42)
45
46 def test_named_init_params(self):
47 @dataclass
48 class C:
49 x: int
50
51 o = C(x=32)
52 self.assertEqual(o.x, 32)
53
54 def test_two_fields_one_default(self):
55 @dataclass
56 class C:
57 x: int
58 y: int = 0
59
60 o = C(3)
61 self.assertEqual((o.x, o.y), (3, 0))
62
63 # Non-defaults following defaults.
64 with self.assertRaisesRegex(TypeError,
65 "non-default argument 'y' follows "
66 "default argument"):
67 @dataclass
68 class C:
69 x: int = 0
70 y: int
71
72 # A derived class adds a non-default field after a default one.
73 with self.assertRaisesRegex(TypeError,
74 "non-default argument 'y' follows "
75 "default argument"):
76 @dataclass
77 class B:
78 x: int = 0
79
80 @dataclass
81 class C(B):
82 y: int
83
84 # Override a base class field and add a default to
85 # a field which didn't use to have a default.
86 with self.assertRaisesRegex(TypeError,
87 "non-default argument 'y' follows "
88 "default argument"):
89 @dataclass
90 class B:
91 x: int
92 y: int
93
94 @dataclass
95 class C(B):
96 x: int = 0
97
Eric V. Smithdbf9cff2018-02-25 21:30:17 -050098 def test_overwrite_hash(self):
99 # Test that declaring this class isn't an error. It should
100 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500101 @dataclass(frozen=True)
102 class C:
103 x: int
104 def __hash__(self):
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500105 return 301
106 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500107
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500108 # Test that declaring this class isn't an error. It should
109 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500110 @dataclass(frozen=True)
111 class C:
112 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500113 def __eq__(self, other):
114 return False
115 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500116
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500117 # But this one should generate an exception, because with
118 # unsafe_hash=True, it's an error to have a __hash__ defined.
119 with self.assertRaisesRegex(TypeError,
120 'Cannot overwrite attribute __hash__'):
121 @dataclass(unsafe_hash=True)
122 class C:
123 def __hash__(self):
124 pass
125
126 # Creating this class should not generate an exception,
127 # because even though __hash__ exists before @dataclass is
128 # called, (due to __eq__ being defined), since it's None
129 # that's okay.
130 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500131 class C:
132 x: int
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500133 def __eq__(self):
134 pass
135 # The generated hash function works as we'd expect.
136 self.assertEqual(hash(C(10)), hash((10,)))
137
138 # Creating this class should generate an exception, because
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400139 # __hash__ exists and is not None, which it would be if it
140 # had been auto-generated due to __eq__ being defined.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500141 with self.assertRaisesRegex(TypeError,
142 'Cannot overwrite attribute __hash__'):
143 @dataclass(unsafe_hash=True)
144 class C:
145 x: int
146 def __eq__(self):
147 pass
148 def __hash__(self):
149 pass
150
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500151 def test_overwrite_fields_in_derived_class(self):
152 # Note that x from C1 replaces x in Base, but the order remains
153 # the same as defined in Base.
154 @dataclass
155 class Base:
156 x: Any = 15.0
157 y: int = 0
158
159 @dataclass
160 class C1(Base):
161 z: int = 10
162 x: int = 15
163
164 o = Base()
165 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
166
167 o = C1()
168 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
169
170 o = C1(x=5)
171 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
172
173 def test_field_named_self(self):
174 @dataclass
175 class C:
176 self: str
177 c=C('foo')
178 self.assertEqual(c.self, 'foo')
179
180 # Make sure the first parameter is not named 'self'.
181 sig = inspect.signature(C.__init__)
182 first = next(iter(sig.parameters))
183 self.assertNotEqual('self', first)
184
185 # But we do use 'self' if no field named self.
186 @dataclass
187 class C:
188 selfx: str
189
190 # Make sure the first parameter is named 'self'.
191 sig = inspect.signature(C.__init__)
192 first = next(iter(sig.parameters))
193 self.assertEqual('self', first)
194
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500195 def test_0_field_compare(self):
196 # Ensure that order=False is the default.
197 @dataclass
198 class C0:
199 pass
200
201 @dataclass(order=False)
202 class C1:
203 pass
204
205 for cls in [C0, C1]:
206 with self.subTest(cls=cls):
207 self.assertEqual(cls(), cls())
208 for idx, fn in enumerate([lambda a, b: a < b,
209 lambda a, b: a <= b,
210 lambda a, b: a > b,
211 lambda a, b: a >= b]):
212 with self.subTest(idx=idx):
213 with self.assertRaisesRegex(TypeError,
214 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
215 fn(cls(), cls())
216
217 @dataclass(order=True)
218 class C:
219 pass
220 self.assertLessEqual(C(), C())
221 self.assertGreaterEqual(C(), C())
222
223 def test_1_field_compare(self):
224 # Ensure that order=False is the default.
225 @dataclass
226 class C0:
227 x: int
228
229 @dataclass(order=False)
230 class C1:
231 x: int
232
233 for cls in [C0, C1]:
234 with self.subTest(cls=cls):
235 self.assertEqual(cls(1), cls(1))
236 self.assertNotEqual(cls(0), cls(1))
237 for idx, fn in enumerate([lambda a, b: a < b,
238 lambda a, b: a <= b,
239 lambda a, b: a > b,
240 lambda a, b: a >= b]):
241 with self.subTest(idx=idx):
242 with self.assertRaisesRegex(TypeError,
243 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
244 fn(cls(0), cls(0))
245
246 @dataclass(order=True)
247 class C:
248 x: int
249 self.assertLess(C(0), C(1))
250 self.assertLessEqual(C(0), C(1))
251 self.assertLessEqual(C(1), C(1))
252 self.assertGreater(C(1), C(0))
253 self.assertGreaterEqual(C(1), C(0))
254 self.assertGreaterEqual(C(1), C(1))
255
256 def test_simple_compare(self):
257 # Ensure that order=False is the default.
258 @dataclass
259 class C0:
260 x: int
261 y: int
262
263 @dataclass(order=False)
264 class C1:
265 x: int
266 y: int
267
268 for cls in [C0, C1]:
269 with self.subTest(cls=cls):
270 self.assertEqual(cls(0, 0), cls(0, 0))
271 self.assertEqual(cls(1, 2), cls(1, 2))
272 self.assertNotEqual(cls(1, 0), cls(0, 0))
273 self.assertNotEqual(cls(1, 0), cls(1, 1))
274 for idx, fn in enumerate([lambda a, b: a < b,
275 lambda a, b: a <= b,
276 lambda a, b: a > b,
277 lambda a, b: a >= b]):
278 with self.subTest(idx=idx):
279 with self.assertRaisesRegex(TypeError,
280 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
281 fn(cls(0, 0), cls(0, 0))
282
283 @dataclass(order=True)
284 class C:
285 x: int
286 y: int
287
288 for idx, fn in enumerate([lambda a, b: a == b,
289 lambda a, b: a <= b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 self.assertTrue(fn(C(0, 0), C(0, 0)))
293
294 for idx, fn in enumerate([lambda a, b: a < b,
295 lambda a, b: a <= b,
296 lambda a, b: a != b]):
297 with self.subTest(idx=idx):
298 self.assertTrue(fn(C(0, 0), C(0, 1)))
299 self.assertTrue(fn(C(0, 1), C(1, 0)))
300 self.assertTrue(fn(C(1, 0), C(1, 1)))
301
302 for idx, fn in enumerate([lambda a, b: a > b,
303 lambda a, b: a >= b,
304 lambda a, b: a != b]):
305 with self.subTest(idx=idx):
306 self.assertTrue(fn(C(0, 1), C(0, 0)))
307 self.assertTrue(fn(C(1, 0), C(0, 1)))
308 self.assertTrue(fn(C(1, 1), C(1, 0)))
309
310 def test_compare_subclasses(self):
311 # Comparisons fail for subclasses, even if no fields
312 # are added.
313 @dataclass
314 class B:
315 i: int
316
317 @dataclass
318 class C(B):
319 pass
320
321 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
322 (lambda a, b: a != b, True)]):
323 with self.subTest(idx=idx):
324 self.assertEqual(fn(B(0), C(0)), expected)
325
326 for idx, fn in enumerate([lambda a, b: a < b,
327 lambda a, b: a <= b,
328 lambda a, b: a > b,
329 lambda a, b: a >= b]):
330 with self.subTest(idx=idx):
331 with self.assertRaisesRegex(TypeError,
332 "not supported between instances of 'B' and 'C'"):
333 fn(B(0), C(0))
334
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500335 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500336 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500337 for (eq, order, result ) in [
338 (False, False, 'neither'),
339 (False, True, 'exception'),
340 (True, False, 'eq_only'),
341 (True, True, 'both'),
342 ]:
343 with self.subTest(eq=eq, order=order):
344 if result == 'exception':
345 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
346 @dataclass(eq=eq, order=order)
347 class C:
348 pass
349 else:
350 @dataclass(eq=eq, order=order)
351 class C:
352 pass
353
354 if result == 'neither':
355 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500356 self.assertNotIn('__lt__', C.__dict__)
357 self.assertNotIn('__le__', C.__dict__)
358 self.assertNotIn('__gt__', C.__dict__)
359 self.assertNotIn('__ge__', C.__dict__)
360 elif result == 'both':
361 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500362 self.assertIn('__lt__', C.__dict__)
363 self.assertIn('__le__', C.__dict__)
364 self.assertIn('__gt__', C.__dict__)
365 self.assertIn('__ge__', C.__dict__)
366 elif result == 'eq_only':
367 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500368 self.assertNotIn('__lt__', C.__dict__)
369 self.assertNotIn('__le__', C.__dict__)
370 self.assertNotIn('__gt__', C.__dict__)
371 self.assertNotIn('__ge__', C.__dict__)
372 else:
373 assert False, f'unknown result {result!r}'
374
375 def test_field_no_default(self):
376 @dataclass
377 class C:
378 x: int = field()
379
380 self.assertEqual(C(5).x, 5)
381
382 with self.assertRaisesRegex(TypeError,
383 r"__init__\(\) missing 1 required "
384 "positional argument: 'x'"):
385 C()
386
387 def test_field_default(self):
388 default = object()
389 @dataclass
390 class C:
391 x: object = field(default=default)
392
393 self.assertIs(C.x, default)
394 c = C(10)
395 self.assertEqual(c.x, 10)
396
397 # If we delete the instance attribute, we should then see the
398 # class attribute.
399 del c.x
400 self.assertIs(c.x, default)
401
402 self.assertIs(C().x, default)
403
404 def test_not_in_repr(self):
405 @dataclass
406 class C:
407 x: int = field(repr=False)
408 with self.assertRaises(TypeError):
409 C()
410 c = C(10)
411 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
412
413 @dataclass
414 class C:
415 x: int = field(repr=False)
416 y: int
417 c = C(10, 20)
418 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
419
420 def test_not_in_compare(self):
421 @dataclass
422 class C:
423 x: int = 0
424 y: int = field(compare=False, default=4)
425
426 self.assertEqual(C(), C(0, 20))
427 self.assertEqual(C(1, 10), C(1, 20))
428 self.assertNotEqual(C(3), C(4, 10))
429 self.assertNotEqual(C(3, 10), C(4, 10))
430
431 def test_hash_field_rules(self):
432 # Test all 6 cases of:
433 # hash=True/False/None
434 # compare=True/False
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500435 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500436 (True, False, 'field' ),
437 (True, True, 'field' ),
438 (False, False, 'absent'),
439 (False, True, 'absent'),
440 (None, False, 'absent'),
441 (None, True, 'field' ),
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500442 ]:
443 with self.subTest(hash=hash_, compare=compare):
444 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500445 class C:
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500446 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500447
448 if result == 'field':
449 # __hash__ contains the field.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500450 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500451 elif result == 'absent':
452 # The field is not present in the hash.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500453 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500454 else:
455 assert False, f'unknown result {result!r}'
456
457 def test_init_false_no_default(self):
458 # If init=False and no default value, then the field won't be
459 # present in the instance.
460 @dataclass
461 class C:
462 x: int = field(init=False)
463
464 self.assertNotIn('x', C().__dict__)
465
466 @dataclass
467 class C:
468 x: int
469 y: int = 0
470 z: int = field(init=False)
471 t: int = 10
472
473 self.assertNotIn('z', C(0).__dict__)
474 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
475
476 def test_class_marker(self):
477 @dataclass
478 class C:
479 x: int
480 y: str = field(init=False, default=None)
481 z: str = field(repr=False)
482
483 the_fields = fields(C)
484 # the_fields is a tuple of 3 items, each value
485 # is in __annotations__.
486 self.assertIsInstance(the_fields, tuple)
487 for f in the_fields:
488 self.assertIs(type(f), Field)
489 self.assertIn(f.name, C.__annotations__)
490
491 self.assertEqual(len(the_fields), 3)
492
493 self.assertEqual(the_fields[0].name, 'x')
494 self.assertEqual(the_fields[0].type, int)
495 self.assertFalse(hasattr(C, 'x'))
496 self.assertTrue (the_fields[0].init)
497 self.assertTrue (the_fields[0].repr)
498 self.assertEqual(the_fields[1].name, 'y')
499 self.assertEqual(the_fields[1].type, str)
500 self.assertIsNone(getattr(C, 'y'))
501 self.assertFalse(the_fields[1].init)
502 self.assertTrue (the_fields[1].repr)
503 self.assertEqual(the_fields[2].name, 'z')
504 self.assertEqual(the_fields[2].type, str)
505 self.assertFalse(hasattr(C, 'z'))
506 self.assertTrue (the_fields[2].init)
507 self.assertFalse(the_fields[2].repr)
508
509 def test_field_order(self):
510 @dataclass
511 class B:
512 a: str = 'B:a'
513 b: str = 'B:b'
514 c: str = 'B:c'
515
516 @dataclass
517 class C(B):
518 b: str = 'C:b'
519
520 self.assertEqual([(f.name, f.default) for f in fields(C)],
521 [('a', 'B:a'),
522 ('b', 'C:b'),
523 ('c', 'B:c')])
524
525 @dataclass
526 class D(B):
527 c: str = 'D:c'
528
529 self.assertEqual([(f.name, f.default) for f in fields(D)],
530 [('a', 'B:a'),
531 ('b', 'B:b'),
532 ('c', 'D:c')])
533
534 @dataclass
535 class E(D):
536 a: str = 'E:a'
537 d: str = 'E:d'
538
539 self.assertEqual([(f.name, f.default) for f in fields(E)],
540 [('a', 'E:a'),
541 ('b', 'B:b'),
542 ('c', 'D:c'),
543 ('d', 'E:d')])
544
545 def test_class_attrs(self):
546 # We only have a class attribute if a default value is
547 # specified, either directly or via a field with a default.
548 default = object()
549 @dataclass
550 class C:
551 x: int
552 y: int = field(repr=False)
553 z: object = default
554 t: int = field(default=100)
555
556 self.assertFalse(hasattr(C, 'x'))
557 self.assertFalse(hasattr(C, 'y'))
558 self.assertIs (C.z, default)
559 self.assertEqual(C.t, 100)
560
561 def test_disallowed_mutable_defaults(self):
562 # For the known types, don't allow mutable default values.
563 for typ, empty, non_empty in [(list, [], [1]),
564 (dict, {}, {0:1}),
565 (set, set(), set([1])),
566 ]:
567 with self.subTest(typ=typ):
568 # Can't use a zero-length value.
569 with self.assertRaisesRegex(ValueError,
570 f'mutable default {typ} for field '
571 'x is not allowed'):
572 @dataclass
573 class Point:
574 x: typ = empty
575
576
577 # Nor a non-zero-length value
578 with self.assertRaisesRegex(ValueError,
579 f'mutable default {typ} for field '
580 'y is not allowed'):
581 @dataclass
582 class Point:
583 y: typ = non_empty
584
585 # Check subtypes also fail.
586 class Subclass(typ): pass
587
588 with self.assertRaisesRegex(ValueError,
589 f"mutable default .*Subclass'>"
590 ' for field z is not allowed'
591 ):
592 @dataclass
593 class Point:
594 z: typ = Subclass()
595
596 # Because this is a ClassVar, it can be mutable.
597 @dataclass
598 class C:
599 z: ClassVar[typ] = typ()
600
601 # Because this is a ClassVar, it can be mutable.
602 @dataclass
603 class C:
604 x: ClassVar[typ] = Subclass()
605
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500606 def test_deliberately_mutable_defaults(self):
607 # If a mutable default isn't in the known list of
608 # (list, dict, set), then it's okay.
609 class Mutable:
610 def __init__(self):
611 self.l = []
612
613 @dataclass
614 class C:
615 x: Mutable
616
617 # These 2 instances will share this value of x.
618 lst = Mutable()
619 o1 = C(lst)
620 o2 = C(lst)
621 self.assertEqual(o1, o2)
622 o1.x.l.extend([1, 2])
623 self.assertEqual(o1, o2)
624 self.assertEqual(o1.x.l, [1, 2])
625 self.assertIs(o1.x, o2.x)
626
627 def test_no_options(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400628 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500629 @dataclass()
630 class C:
631 x: int
632
633 self.assertEqual(C(42).x, 42)
634
635 def test_not_tuple(self):
636 # Make sure we can't be compared to a tuple.
637 @dataclass
638 class Point:
639 x: int
640 y: int
641 self.assertNotEqual(Point(1, 2), (1, 2))
642
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400643 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500644 @dataclass
645 class C:
646 x: int
647 y: int
648 self.assertNotEqual(Point(1, 3), C(1, 3))
649
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500650 def test_not_tuple(self):
651 # Test that some of the problems with namedtuple don't happen
652 # here.
653 @dataclass
654 class Point3D:
655 x: int
656 y: int
657 z: int
658
659 @dataclass
660 class Date:
661 year: int
662 month: int
663 day: int
664
665 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
666 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
667
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400668 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200669 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500670 x, y, z = Point3D(4, 5, 6)
671
Eric V. Smith7c99e932018-01-28 19:18:55 -0500672 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500673 # equal.
674 @dataclass
675 class Point3Dv1:
676 x: int = 0
677 y: int = 0
678 z: int = 0
679 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
680
681 def test_function_annotations(self):
682 # Some dummy class and instance to use as a default.
683 class F:
684 pass
685 f = F()
686
687 def validate_class(cls):
688 # First, check __annotations__, even though they're not
689 # function annotations.
690 self.assertEqual(cls.__annotations__['i'], int)
691 self.assertEqual(cls.__annotations__['j'], str)
692 self.assertEqual(cls.__annotations__['k'], F)
693 self.assertEqual(cls.__annotations__['l'], float)
694 self.assertEqual(cls.__annotations__['z'], complex)
695
696 # Verify __init__.
697
698 signature = inspect.signature(cls.__init__)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400699 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500700 self.assertIs(signature.return_annotation, None)
701
702 # Check each parameter.
703 params = iter(signature.parameters.values())
704 param = next(params)
705 # This is testing an internal name, and probably shouldn't be tested.
706 self.assertEqual(param.name, 'self')
707 param = next(params)
708 self.assertEqual(param.name, 'i')
709 self.assertIs (param.annotation, int)
710 self.assertEqual(param.default, inspect.Parameter.empty)
711 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
712 param = next(params)
713 self.assertEqual(param.name, 'j')
714 self.assertIs (param.annotation, str)
715 self.assertEqual(param.default, inspect.Parameter.empty)
716 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
717 param = next(params)
718 self.assertEqual(param.name, 'k')
719 self.assertIs (param.annotation, F)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400720 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500721 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
722 param = next(params)
723 self.assertEqual(param.name, 'l')
724 self.assertIs (param.annotation, float)
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400725 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500726 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
727 self.assertRaises(StopIteration, next, params)
728
729
730 @dataclass
731 class C:
732 i: int
733 j: str
734 k: F = f
735 l: float=field(default=None)
736 z: complex=field(default=3+4j, init=False)
737
738 validate_class(C)
739
740 # Now repeat with __hash__.
Eric V. Smithdbf9cff2018-02-25 21:30:17 -0500741 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500742 class C:
743 i: int
744 j: str
745 k: F = f
746 l: float=field(default=None)
747 z: complex=field(default=3+4j, init=False)
748
749 validate_class(C)
750
Eric V. Smith03220fd2017-12-29 13:59:58 -0500751 def test_missing_default(self):
752 # Test that MISSING works the same as a default not being
753 # specified.
754 @dataclass
755 class C:
756 x: int=field(default=MISSING)
757 with self.assertRaisesRegex(TypeError,
758 r'__init__\(\) missing 1 required '
759 'positional argument'):
760 C()
761 self.assertNotIn('x', C.__dict__)
762
763 @dataclass
764 class D:
765 x: int
766 with self.assertRaisesRegex(TypeError,
767 r'__init__\(\) missing 1 required '
768 'positional argument'):
769 D()
770 self.assertNotIn('x', D.__dict__)
771
772 def test_missing_default_factory(self):
773 # Test that MISSING works the same as a default factory not
774 # being specified (which is really the same as a default not
775 # being specified, too).
776 @dataclass
777 class C:
778 x: int=field(default_factory=MISSING)
779 with self.assertRaisesRegex(TypeError,
780 r'__init__\(\) missing 1 required '
781 'positional argument'):
782 C()
783 self.assertNotIn('x', C.__dict__)
784
785 @dataclass
786 class D:
787 x: int=field(default=MISSING, default_factory=MISSING)
788 with self.assertRaisesRegex(TypeError,
789 r'__init__\(\) missing 1 required '
790 'positional argument'):
791 D()
792 self.assertNotIn('x', D.__dict__)
793
794 def test_missing_repr(self):
795 self.assertIn('MISSING_TYPE object', repr(MISSING))
796
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500797 def test_dont_include_other_annotations(self):
798 @dataclass
799 class C:
800 i: int
801 def foo(self) -> int:
802 return 4
803 @property
804 def bar(self) -> int:
805 return 5
806 self.assertEqual(list(C.__annotations__), ['i'])
807 self.assertEqual(C(10).foo(), 4)
808 self.assertEqual(C(10).bar, 5)
Eric V. Smith51c9ab42018-03-25 09:04:32 -0400809 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500810
def test_post_init(self):
    # The generated __init__ invokes __post_init__.
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    # The hook runs after fields are assigned, so it can read them.
    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    C(5)  # the hook runs here but does not raise

    # With init=False no __init__ is generated, so the hook is never called.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    C()

    # The hook may rewrite fields on a regular dataclass...
    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # ...but not on a frozen one.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
857
def test_post_init_super(self):
    class B:
        def __post_init__(self):
            raise CustomError()

    # An overriding __post_init__ does not chain to the base automatically.
    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5
    self.assertEqual(C().x, 5)

    # Delegating explicitly through super() reaches the base hook.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()
    with self.assertRaises(CustomError):
        C()

    # An inherited __post_init__ is called even if the dataclass defines none.
    @dataclass
    class C(B):
        pass
    with self.assertRaises(CustomError):
        C()
888
889 def test_post_init_staticmethod(self):
890 flag = False
891 @dataclass
892 class C:
893 x: int
894 y: int
895 @staticmethod
896 def __post_init__():
897 nonlocal flag
898 flag = True
899
900 self.assertFalse(flag)
901 c = C(3, 4)
902 self.assertEqual((c.x, c.y), (3, 4))
903 self.assertTrue(flag)
904
905 def test_post_init_classmethod(self):
906 @dataclass
907 class C:
908 flag = False
909 x: int
910 y: int
911 @classmethod
912 def __post_init__(cls):
913 cls.flag = True
914
915 self.assertFalse(C.flag)
916 c = C(3, 4)
917 self.assertEqual((c.x, c.y), (3, 4))
918 self.assertTrue(C.flag)
919
920 def test_class_var(self):
921 # Make sure ClassVars are ignored in __init__, __repr__, etc.
922 @dataclass
923 class C:
924 x: int
925 y: int = 10
926 z: ClassVar[int] = 1000
927 w: ClassVar[int] = 2000
928 t: ClassVar[int] = 3000
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400929 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500930
931 c = C(5)
932 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Eric V. Smith2b75fc22018-03-25 20:37:33 -0400933 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400934 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500935 self.assertEqual(c.z, 1000)
936 self.assertEqual(c.w, 2000)
937 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400938 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500939 C.z += 1
940 self.assertEqual(c.z, 1001)
941 c = C(20)
942 self.assertEqual((c.x, c.y), (20, 10))
943 self.assertEqual(c.z, 1001)
944 self.assertEqual(c.w, 2000)
945 self.assertEqual(c.t, 3000)
Eric V. Smith2a7bacb2018-05-15 22:44:27 -0400946 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500947
948 def test_class_var_no_default(self):
949 # If a ClassVar has no default value, it should not be set on the class.
950 @dataclass
951 class C:
952 x: ClassVar[int]
953
954 self.assertNotIn('x', C.__dict__)
955
956 def test_class_var_default_factory(self):
957 # It makes no sense for a ClassVar to have a default factory. When
958 # would it be called? Call it yourself, since it's class-wide.
959 with self.assertRaisesRegex(TypeError,
960 'cannot have a default factory'):
961 @dataclass
962 class C:
963 x: ClassVar[int] = field(default_factory=int)
964
965 self.assertNotIn('x', C.__dict__)
966
967 def test_class_var_with_default(self):
968 # If a ClassVar has a default value, it should be set on the class.
969 @dataclass
970 class C:
971 x: ClassVar[int] = 10
972 self.assertEqual(C.x, 10)
973
974 @dataclass
975 class C:
976 x: ClassVar[int] = field(default=10)
977 self.assertEqual(C.x, 10)
978
979 def test_class_var_frozen(self):
980 # Make sure ClassVars work even if we're frozen.
981 @dataclass(frozen=True)
982 class C:
983 x: int
984 y: int = 10
985 z: ClassVar[int] = 1000
986 w: ClassVar[int] = 2000
987 t: ClassVar[int] = 3000
988
989 c = C(5)
990 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
991 self.assertEqual(len(fields(C)), 2) # We have 2 fields
992 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
993 self.assertEqual(c.z, 1000)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
996 # We can still modify the ClassVar, it's only instances that are
997 # frozen.
998 C.z += 1
999 self.assertEqual(c.z, 1001)
1000 c = C(20)
1001 self.assertEqual((c.x, c.y), (20, 10))
1002 self.assertEqual(c.z, 1001)
1003 self.assertEqual(c.w, 2000)
1004 self.assertEqual(c.t, 3000)
1005
1006 def test_init_var_no_default(self):
1007 # If an InitVar has no default value, it should not be set on the class.
1008 @dataclass
1009 class C:
1010 x: InitVar[int]
1011
1012 self.assertNotIn('x', C.__dict__)
1013
1014 def test_init_var_default_factory(self):
1015 # It makes no sense for an InitVar to have a default factory. When
1016 # would it be called? Call it yourself, since it's class-wide.
1017 with self.assertRaisesRegex(TypeError,
1018 'cannot have a default factory'):
1019 @dataclass
1020 class C:
1021 x: InitVar[int] = field(default_factory=int)
1022
1023 self.assertNotIn('x', C.__dict__)
1024
1025 def test_init_var_with_default(self):
1026 # If an InitVar has a default value, it should be set on the class.
1027 @dataclass
1028 class C:
1029 x: InitVar[int] = 10
1030 self.assertEqual(C.x, 10)
1031
1032 @dataclass
1033 class C:
1034 x: InitVar[int] = field(default=10)
1035 self.assertEqual(C.x, 10)
1036
1037 def test_init_var(self):
1038 @dataclass
1039 class C:
1040 x: int = None
1041 init_param: InitVar[int] = None
1042
1043 def __post_init__(self, init_param):
1044 if self.x is None:
1045 self.x = init_param*2
1046
1047 c = C(init_param=10)
1048 self.assertEqual(c.x, 20)
1049
1050 def test_init_var_inheritance(self):
1051 # Note that this deliberately tests that a dataclass need not
1052 # have a __post_init__ function if it has an InitVar field.
1053 # It could just be used in a derived class, as shown here.
1054 @dataclass
1055 class Base:
1056 x: int
1057 init_base: InitVar[int]
1058
1059 # We can instantiate by passing the InitVar, even though
1060 # it's not used.
1061 b = Base(0, 10)
1062 self.assertEqual(vars(b), {'x': 0})
1063
1064 @dataclass
1065 class C(Base):
1066 y: int
1067 init_derived: InitVar[int]
1068
1069 def __post_init__(self, init_base, init_derived):
1070 self.x = self.x + init_base
1071 self.y = self.y + init_derived
1072
1073 c = C(10, 11, 50, 51)
1074 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1075
1076 def test_default_factory(self):
1077 # Test a factory that returns a new list.
1078 @dataclass
1079 class C:
1080 x: int
1081 y: list = field(default_factory=list)
1082
1083 c0 = C(3)
1084 c1 = C(3)
1085 self.assertEqual(c0.x, 3)
1086 self.assertEqual(c0.y, [])
1087 self.assertEqual(c0, c1)
1088 self.assertIsNot(c0.y, c1.y)
1089 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1090
1091 # Test a factory that returns a shared list.
1092 l = []
1093 @dataclass
1094 class C:
1095 x: int
1096 y: list = field(default_factory=lambda: l)
1097
1098 c0 = C(3)
1099 c1 = C(3)
1100 self.assertEqual(c0.x, 3)
1101 self.assertEqual(c0.y, [])
1102 self.assertEqual(c0, c1)
1103 self.assertIs(c0.y, c1.y)
1104 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1105
1106 # Test various other field flags.
1107 # repr
1108 @dataclass
1109 class C:
1110 x: list = field(default_factory=list, repr=False)
1111 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1112 self.assertEqual(C().x, [])
1113
1114 # hash
Eric V. Smithdbf9cff2018-02-25 21:30:17 -05001115 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001116 class C:
1117 x: list = field(default_factory=list, hash=False)
1118 self.assertEqual(astuple(C()), ([],))
1119 self.assertEqual(hash(C()), hash(()))
1120
1121 # init (see also test_default_factory_with_no_init)
1122 @dataclass
1123 class C:
1124 x: list = field(default_factory=list, init=False)
1125 self.assertEqual(astuple(C()), ([],))
1126
1127 # compare
1128 @dataclass
1129 class C:
1130 x: list = field(default_factory=list, compare=False)
1131 self.assertEqual(C(), C([1]))
1132
1133 def test_default_factory_with_no_init(self):
1134 # We need a factory with a side effect.
1135 factory = Mock()
1136
1137 @dataclass
1138 class C:
1139 x: list = field(default_factory=factory, init=False)
1140
1141 # Make sure the default factory is called for each new instance.
1142 C().x
1143 self.assertEqual(factory.call_count, 1)
1144 C().x
1145 self.assertEqual(factory.call_count, 2)
1146
1147 def test_default_factory_not_called_if_value_given(self):
1148 # We need a factory that we can test if it's been called.
1149 factory = Mock()
1150
1151 @dataclass
1152 class C:
1153 x: int = field(default_factory=factory)
1154
1155 # Make sure that if a field has a default factory function,
1156 # it's not called if a value is specified.
1157 C().x
1158 self.assertEqual(factory.call_count, 1)
1159 self.assertEqual(C(10).x, 10)
1160 self.assertEqual(factory.call_count, 1)
1161 C().x
1162 self.assertEqual(factory.call_count, 2)
1163
Eric V. Smith8f6eccd2018-03-20 22:00:23 -04001164 def test_default_factory_derived(self):
1165 # See bpo-32896.
1166 @dataclass
1167 class Foo:
1168 x: dict = field(default_factory=dict)
1169
1170 @dataclass
1171 class Bar(Foo):
1172 y: int = 1
1173
1174 self.assertEqual(Foo().x, {})
1175 self.assertEqual(Bar().x, {})
1176 self.assertEqual(Bar().y, 1)
1177
1178 @dataclass
1179 class Baz(Foo):
1180 pass
1181 self.assertEqual(Baz().x, {})
1182
1183 def test_intermediate_non_dataclass(self):
1184 # Test that an intermediate class that defines
1185 # annotations does not define fields.
1186
1187 @dataclass
1188 class A:
1189 x: int
1190
1191 class B(A):
1192 y: int
1193
1194 @dataclass
1195 class C(B):
1196 z: int
1197
1198 c = C(1, 3)
1199 self.assertEqual((c.x, c.z), (1, 3))
1200
1201 # .y was not initialized.
1202 with self.assertRaisesRegex(AttributeError,
1203 'object has no attribute'):
1204 c.y
1205
1206 # And if we again derive a non-dataclass, no fields are added.
1207 class D(C):
1208 t: int
1209 d = D(4, 5)
1210 self.assertEqual((d.x, d.z), (4, 5))
1211
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001212 def test_classvar_default_factory(self):
1213 # It's an error for a ClassVar to have a factory function.
1214 with self.assertRaisesRegex(TypeError,
1215 'cannot have a default factory'):
1216 @dataclass
1217 class C:
1218 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001219
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001220 def test_is_dataclass(self):
1221 class NotDataClass:
1222 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001223
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001224 self.assertFalse(is_dataclass(0))
1225 self.assertFalse(is_dataclass(int))
1226 self.assertFalse(is_dataclass(NotDataClass))
1227 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001228
1229 @dataclass
1230 class C:
1231 x: int
1232
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001233 @dataclass
1234 class D:
1235 d: C
1236 e: int
1237
1238 c = C(10)
1239 d = D(c, 4)
1240
1241 self.assertTrue(is_dataclass(C))
1242 self.assertTrue(is_dataclass(c))
1243 self.assertFalse(is_dataclass(c.x))
1244 self.assertTrue(is_dataclass(d.d))
1245 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001246
1247 def test_helper_fields_with_class_instance(self):
1248 # Check that we can call fields() on either a class or instance,
1249 # and get back the same thing.
1250 @dataclass
1251 class C:
1252 x: int
1253 y: float
1254
1255 self.assertEqual(fields(C), fields(C(0, 0.0)))
1256
1257 def test_helper_fields_exception(self):
1258 # Check that TypeError is raised if not passed a dataclass or
1259 # instance.
1260 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1261 fields(0)
1262
1263 class C: pass
1264 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1265 fields(C)
1266 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1267 fields(C())
1268
1269 def test_helper_asdict(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001270 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001271 @dataclass
1272 class C:
1273 x: int
1274 y: int
1275 c = C(1, 2)
1276
1277 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1278 self.assertEqual(asdict(c), asdict(c))
1279 self.assertIsNot(asdict(c), asdict(c))
1280 c.x = 42
1281 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1282 self.assertIs(type(asdict(c)), dict)
1283
1284 def test_helper_asdict_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001285 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001286 @dataclass
1287 class C:
1288 x: int
1289 y: int
1290 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1291 asdict(C)
1292 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1293 asdict(int)
1294
1295 def test_helper_asdict_copy_values(self):
1296 @dataclass
1297 class C:
1298 x: int
1299 y: List[int] = field(default_factory=list)
1300 initial = []
1301 c = C(1, initial)
1302 d = asdict(c)
1303 self.assertEqual(d['y'], initial)
1304 self.assertIsNot(d['y'], initial)
1305 c = C(1)
1306 d = asdict(c)
1307 d['y'].append(1)
1308 self.assertEqual(c.y, [])
1309
1310 def test_helper_asdict_nested(self):
1311 @dataclass
1312 class UserId:
1313 token: int
1314 group: int
1315 @dataclass
1316 class User:
1317 name: str
1318 id: UserId
1319 u = User('Joe', UserId(123, 1))
1320 d = asdict(u)
1321 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1322 self.assertIsNot(asdict(u), asdict(u))
1323 u.id.group = 2
1324 self.assertEqual(asdict(u), {'name': 'Joe',
1325 'id': {'token': 123, 'group': 2}})
1326
1327 def test_helper_asdict_builtin_containers(self):
1328 @dataclass
1329 class User:
1330 name: str
1331 id: int
1332 @dataclass
1333 class GroupList:
1334 id: int
1335 users: List[User]
1336 @dataclass
1337 class GroupTuple:
1338 id: int
1339 users: Tuple[User, ...]
1340 @dataclass
1341 class GroupDict:
1342 id: int
1343 users: Dict[str, User]
1344 a = User('Alice', 1)
1345 b = User('Bob', 2)
1346 gl = GroupList(0, [a, b])
1347 gt = GroupTuple(0, (a, b))
1348 gd = GroupDict(0, {'first': a, 'second': b})
1349 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1350 {'name': 'Bob', 'id': 2}]})
1351 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1352 {'name': 'Bob', 'id': 2})})
1353 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1354 'second': {'name': 'Bob', 'id': 2}}})
1355
1356 def test_helper_asdict_builtin_containers(self):
1357 @dataclass
1358 class Child:
1359 d: object
1360
1361 @dataclass
1362 class Parent:
1363 child: Child
1364
1365 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1366 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1367
1368 def test_helper_asdict_factory(self):
1369 @dataclass
1370 class C:
1371 x: int
1372 y: int
1373 c = C(1, 2)
1374 d = asdict(c, dict_factory=OrderedDict)
1375 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1376 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1377 c.x = 42
1378 d = asdict(c, dict_factory=OrderedDict)
1379 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1380 self.assertIs(type(d), OrderedDict)
1381
1382 def test_helper_astuple(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001383 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001384 @dataclass
1385 class C:
1386 x: int
1387 y: int = 0
1388 c = C(1)
1389
1390 self.assertEqual(astuple(c), (1, 0))
1391 self.assertEqual(astuple(c), astuple(c))
1392 self.assertIsNot(astuple(c), astuple(c))
1393 c.y = 42
1394 self.assertEqual(astuple(c), (1, 42))
1395 self.assertIs(type(astuple(c)), tuple)
1396
1397 def test_helper_astuple_raises_on_classes(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001398 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001399 @dataclass
1400 class C:
1401 x: int
1402 y: int
1403 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1404 astuple(C)
1405 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1406 astuple(int)
1407
1408 def test_helper_astuple_copy_values(self):
1409 @dataclass
1410 class C:
1411 x: int
1412 y: List[int] = field(default_factory=list)
1413 initial = []
1414 c = C(1, initial)
1415 t = astuple(c)
1416 self.assertEqual(t[1], initial)
1417 self.assertIsNot(t[1], initial)
1418 c = C(1)
1419 t = astuple(c)
1420 t[1].append(1)
1421 self.assertEqual(c.y, [])
1422
1423 def test_helper_astuple_nested(self):
1424 @dataclass
1425 class UserId:
1426 token: int
1427 group: int
1428 @dataclass
1429 class User:
1430 name: str
1431 id: UserId
1432 u = User('Joe', UserId(123, 1))
1433 t = astuple(u)
1434 self.assertEqual(t, ('Joe', (123, 1)))
1435 self.assertIsNot(astuple(u), astuple(u))
1436 u.id.group = 2
1437 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1438
1439 def test_helper_astuple_builtin_containers(self):
1440 @dataclass
1441 class User:
1442 name: str
1443 id: int
1444 @dataclass
1445 class GroupList:
1446 id: int
1447 users: List[User]
1448 @dataclass
1449 class GroupTuple:
1450 id: int
1451 users: Tuple[User, ...]
1452 @dataclass
1453 class GroupDict:
1454 id: int
1455 users: Dict[str, User]
1456 a = User('Alice', 1)
1457 b = User('Bob', 2)
1458 gl = GroupList(0, [a, b])
1459 gt = GroupTuple(0, (a, b))
1460 gd = GroupDict(0, {'first': a, 'second': b})
1461 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1462 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1463 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1464
1465 def test_helper_astuple_builtin_containers(self):
1466 @dataclass
1467 class Child:
1468 d: object
1469
1470 @dataclass
1471 class Parent:
1472 child: Child
1473
1474 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1475 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1476
1477 def test_helper_astuple_factory(self):
1478 @dataclass
1479 class C:
1480 x: int
1481 y: int
1482 NT = namedtuple('NT', 'x y')
1483 def nt(lst):
1484 return NT(*lst)
1485 c = C(1, 2)
1486 t = astuple(c, tuple_factory=nt)
1487 self.assertEqual(t, NT(1, 2))
1488 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1489 c.x = 42
1490 t = astuple(c, tuple_factory=nt)
1491 self.assertEqual(t, NT(42, 2))
1492 self.assertIs(type(t), NT)
1493
1494 def test_dynamic_class_creation(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001495 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001496 }
1497
1498 # Create the class.
1499 cls = type('C', (), cls_dict)
1500
1501 # Make it a dataclass.
1502 cls1 = dataclass(cls)
1503
1504 self.assertEqual(cls1, cls)
1505 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1506
1507 def test_dynamic_class_creation_using_field(self):
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001508 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001509 'y': field(default=5),
1510 }
1511
1512 # Create the class.
1513 cls = type('C', (), cls_dict)
1514
1515 # Make it a dataclass.
1516 cls1 = dataclass(cls)
1517
1518 self.assertEqual(cls1, cls)
1519 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1520
1521 def test_init_in_order(self):
1522 @dataclass
1523 class C:
1524 a: int
1525 b: int = field()
1526 c: list = field(default_factory=list, init=False)
1527 d: list = field(default_factory=list)
1528 e: int = field(default=4, init=False)
1529 f: int = 4
1530
1531 calls = []
1532 def setattr(self, name, value):
1533 calls.append((name, value))
1534
1535 C.__setattr__ = setattr
1536 c = C(0, 1)
1537 self.assertEqual(('a', 0), calls[0])
1538 self.assertEqual(('b', 1), calls[1])
1539 self.assertEqual(('c', []), calls[2])
1540 self.assertEqual(('d', []), calls[3])
1541 self.assertNotIn(('e', 4), calls)
1542 self.assertEqual(('f', 4), calls[4])
1543
1544 def test_items_in_dicts(self):
1545 @dataclass
1546 class C:
1547 a: int
1548 b: list = field(default_factory=list, init=False)
1549 c: list = field(default_factory=list)
1550 d: int = field(default=4, init=False)
1551 e: int = 0
1552
1553 c = C(0)
1554 # Class dict
1555 self.assertNotIn('a', C.__dict__)
1556 self.assertNotIn('b', C.__dict__)
1557 self.assertNotIn('c', C.__dict__)
1558 self.assertIn('d', C.__dict__)
1559 self.assertEqual(C.d, 4)
1560 self.assertIn('e', C.__dict__)
1561 self.assertEqual(C.e, 0)
1562 # Instance dict
1563 self.assertIn('a', c.__dict__)
1564 self.assertEqual(c.a, 0)
1565 self.assertIn('b', c.__dict__)
1566 self.assertEqual(c.b, [])
1567 self.assertIn('c', c.__dict__)
1568 self.assertEqual(c.c, [])
1569 self.assertNotIn('d', c.__dict__)
1570 self.assertIn('e', c.__dict__)
1571 self.assertEqual(c.e, 0)
1572
1573 def test_alternate_classmethod_constructor(self):
1574 # Since __post_init__ can't take params, use a classmethod
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001575 # alternate constructor. This is mostly an example to show
1576 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001577 @dataclass
1578 class C:
1579 x: int
1580 @classmethod
1581 def from_file(cls, filename):
1582 # In a real example, create a new instance
1583 # and populate 'x' from contents of a file.
1584 value_in_file = 20
1585 return cls(value_in_file)
1586
1587 self.assertEqual(C.from_file('filename').x, 20)
1588
1589 def test_field_metadata_default(self):
1590 # Make sure the default metadata is read-only and of
1591 # zero length.
1592 @dataclass
1593 class C:
1594 i: int
1595
1596 self.assertFalse(fields(C)[0].metadata)
1597 self.assertEqual(len(fields(C)[0].metadata), 0)
1598 with self.assertRaisesRegex(TypeError,
1599 'does not support item assignment'):
1600 fields(C)[0].metadata['test'] = 3
1601
1602 def test_field_metadata_mapping(self):
1603 # Make sure only a mapping can be passed as metadata
1604 # zero length.
1605 with self.assertRaises(TypeError):
1606 @dataclass
1607 class C:
1608 i: int = field(metadata=0)
1609
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001610 # Make sure an empty dict works.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001611 @dataclass
1612 class C:
1613 i: int = field(metadata={})
1614 self.assertFalse(fields(C)[0].metadata)
1615 self.assertEqual(len(fields(C)[0].metadata), 0)
1616 with self.assertRaisesRegex(TypeError,
1617 'does not support item assignment'):
1618 fields(C)[0].metadata['test'] = 3
1619
1620 # Make sure a non-empty dict works.
1621 @dataclass
1622 class C:
1623 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1624 self.assertEqual(len(fields(C)[0].metadata), 3)
1625 self.assertEqual(fields(C)[0].metadata['test'], 10)
1626 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1627 self.assertEqual(fields(C)[0].metadata[3], 'three')
1628 with self.assertRaises(KeyError):
1629 # Non-existent key.
1630 fields(C)[0].metadata['baz']
1631 with self.assertRaisesRegex(TypeError,
1632 'does not support item assignment'):
1633 fields(C)[0].metadata['test'] = 3
1634
1635 def test_field_metadata_custom_mapping(self):
1636 # Try a custom mapping.
1637 class SimpleNameSpace:
1638 def __init__(self, **kw):
1639 self.__dict__.update(kw)
1640
1641 def __getitem__(self, item):
1642 if item == 'xyzzy':
1643 return 'plugh'
1644 return getattr(self, item)
1645
1646 def __len__(self):
1647 return self.__dict__.__len__()
1648
1649 @dataclass
1650 class C:
1651 i: int = field(metadata=SimpleNameSpace(a=10))
1652
1653 self.assertEqual(len(fields(C)[0].metadata), 1)
1654 self.assertEqual(fields(C)[0].metadata['a'], 10)
1655 with self.assertRaises(AttributeError):
1656 fields(C)[0].metadata['b']
1657 # Make sure we're still talking to our custom mapping.
1658 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1659
1660 def test_generic_dataclasses(self):
1661 T = TypeVar('T')
1662
1663 @dataclass
1664 class LabeledBox(Generic[T]):
1665 content: T
1666 label: str = '<unknown>'
1667
1668 box = LabeledBox(42)
1669 self.assertEqual(box.content, 42)
1670 self.assertEqual(box.label, '<unknown>')
1671
Eric V. Smith2b75fc22018-03-25 20:37:33 -04001672 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001673 Alias = List[LabeledBox[int]]
1674
1675 def test_generic_extending(self):
1676 S = TypeVar('S')
1677 T = TypeVar('T')
1678
1679 @dataclass
1680 class Base(Generic[T, S]):
1681 x: T
1682 y: S
1683
1684 @dataclass
1685 class DataDerived(Base[int, T]):
1686 new_field: str
1687 Alias = DataDerived[str]
1688 c = Alias(0, 'test1', 'test2')
1689 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1690
1691 class NonDataDerived(Base[int, T]):
1692 def new_method(self):
1693 return self.y
1694 Alias = NonDataDerived[float]
1695 c = Alias(10, 1.0)
1696 self.assertEqual(c.new_method(), 1.0)
1697
Ivan Levkivskyi5a7092d2018-03-31 13:41:17 +01001698 def test_generic_dynamic(self):
1699 T = TypeVar('T')
1700
1701 @dataclass
1702 class Parent(Generic[T]):
1703 x: T
1704 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1705 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1706 self.assertIs(Child[int](1, 2).z, None)
1707 self.assertEqual(Child[int](1, 2, 3).z, 3)
1708 self.assertEqual(Child[int](1, 2, 3).other, 42)
1709 # Check that type aliases work correctly.
1710 Alias = Child[T]
1711 self.assertEqual(Alias[int](1, 2).x, 1)
1712 # Check MRO resolution.
1713 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1714
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001715 def test_dataclassses_pickleable(self):
1716 global P, Q, R
1717 @dataclass
1718 class P:
1719 x: int
1720 y: int = 0
1721 @dataclass
1722 class Q:
1723 x: int
1724 y: int = field(default=0, init=False)
1725 @dataclass
1726 class R:
1727 x: int
1728 y: List[int] = field(default_factory=list)
1729 q = Q(1)
1730 q.y = 2
1731 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1732 for sample in samples:
1733 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1734 with self.subTest(sample=sample, proto=proto):
1735 new_sample = pickle.loads(pickle.dumps(sample, proto))
1736 self.assertEqual(sample.x, new_sample.x)
1737 self.assertEqual(sample.y, new_sample.y)
1738 self.assertIsNot(sample, new_sample)
1739 new_sample.x = 42
1740 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1741 self.assertEqual(new_sample.x, another_new_sample.x)
1742 self.assertEqual(sample.y, another_new_sample.y)
1743
Eric V. Smithea8fc522018-01-27 19:07:40 -05001744
class TestFieldNoAnnotation(unittest.TestCase):
    # `f = field()` with no annotation in the class's own body is an error,
    # even if a base class annotates `f`.

    def _assert_rejects_unannotated_field(self, *bases):
        # Only the class's own annotations count, so defining this dataclass
        # must fail whatever the bases declare.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(*bases):
                f = field()

    def test_field_without_annotation(self):
        self._assert_rejects_unannotated_field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        self._assert_rejects_unannotated_field(B)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # The same check when the annotated base is a plain class.
        class B:
            f: int

        self._assert_rejects_unannotated_field(B)
1778
1779
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001780class TestDocString(unittest.TestCase):
1781 def assertDocStrEqual(self, a, b):
1782 # Because 3.6 and 3.7 differ in how inspect.signature work
1783 # (see bpo #32108), for the time being just compare them with
1784 # whitespace stripped.
1785 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1786
1787 def test_existing_docstring_not_overridden(self):
1788 @dataclass
1789 class C:
1790 """Lorem ipsum"""
1791 x: int
1792
1793 self.assertEqual(C.__doc__, "Lorem ipsum")
1794
1795 def test_docstring_no_fields(self):
1796 @dataclass
1797 class C:
1798 pass
1799
1800 self.assertDocStrEqual(C.__doc__, "C()")
1801
1802 def test_docstring_one_field(self):
1803 @dataclass
1804 class C:
1805 x: int
1806
1807 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1808
1809 def test_docstring_two_fields(self):
1810 @dataclass
1811 class C:
1812 x: int
1813 y: int
1814
1815 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1816
1817 def test_docstring_three_fields(self):
1818 @dataclass
1819 class C:
1820 x: int
1821 y: int
1822 z: str
1823
1824 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1825
1826 def test_docstring_one_field_with_default(self):
1827 @dataclass
1828 class C:
1829 x: int = 3
1830
1831 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1832
1833 def test_docstring_one_field_with_default_none(self):
1834 @dataclass
1835 class C:
1836 x: Union[int, type(None)] = None
1837
1838 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1839
1840 def test_docstring_list_field(self):
1841 @dataclass
1842 class C:
1843 x: List[int]
1844
1845 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1846
1847 def test_docstring_list_field_with_default_factory(self):
1848 @dataclass
1849 class C:
1850 x: List[int] = field(default_factory=list)
1851
1852 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1853
1854 def test_docstring_deque_field(self):
1855 @dataclass
1856 class C:
1857 x: deque
1858
1859 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1860
1861 def test_docstring_deque_field_with_default_factory(self):
1862 @dataclass
1863 class C:
1864 x: deque = field(default_factory=deque)
1865
1866 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1867
1868
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this dataclass must not raise. A dataclass may not
        # replace an __init__ written in its own body, but adding one to a
        # class whose *base* defines __init__ is fine.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call the base __init__.
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so the base one runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # init=False: no __init__ is generated, and the default is only
        # reachable as a class attribute.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # init=False plus a user-written __init__: the user's method is used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # A user-defined __init__ is always kept, whatever init= says.
        cases = (
            (dataclass, 3, 6),
            (dataclass(init=True), 4, 8),
            (dataclass(init=False), 5, 10),
        )
        for decorator, arg, expected in cases:
            @decorator
            class C:
                x: int
                def __init__(self, x):
                    self.x = 2 * x
            self.assertEqual(C(arg).x, expected)
1933
1934
1935class TestRepr(unittest.TestCase):
1936 def test_repr(self):
1937 @dataclass
1938 class B:
1939 x: int
1940
1941 @dataclass
1942 class C(B):
1943 y: int = 10
1944
1945 o = C(4)
1946 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
1947
1948 @dataclass
1949 class D(C):
1950 x: int = 20
1951 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
1952
1953 @dataclass
1954 class C:
1955 @dataclass
1956 class D:
1957 i: int
1958 @dataclass
1959 class E:
1960 pass
1961 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
1962 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
1963
1964 def test_no_repr(self):
1965 # Test a class with no __repr__ and repr=False.
1966 @dataclass(repr=False)
1967 class C:
1968 x: int
1969 self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
1970 repr(C(3)))
1971
1972 # Test a class with a __repr__ and repr=False.
1973 @dataclass(repr=False)
1974 class C:
1975 x: int
1976 def __repr__(self):
1977 return 'C-class'
1978 self.assertEqual(repr(C(3)), 'C-class')
1979
1980 def test_overwriting_repr(self):
1981 # If the class has __repr__, use it no matter the value of
1982 # repr=.
1983
1984 @dataclass
1985 class C:
1986 x: int
1987 def __repr__(self):
1988 return 'x'
1989 self.assertEqual(repr(C(0)), 'x')
1990
1991 @dataclass(repr=True)
1992 class C:
1993 x: int
1994 def __repr__(self):
1995 return 'x'
1996 self.assertEqual(repr(C(0)), 'x')
1997
1998 @dataclass(repr=False)
1999 class C:
2000 x: int
2001 def __repr__(self):
2002 return 'x'
2003 self.assertEqual(repr(C(0)), 'x')
2004
2005
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: instances compare by identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ is always kept, whatever eq= says.
        cases = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorator, target in cases:
            @decorator
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2051
2052
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering can derive the remaining comparisons from
        # a hand-written __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the checks below only pass if
                # this method is the one actually being used.
                return self.x >= other

        checks = (
            (self.assertLess, -1),
            (self.assertLessEqual, -1),
            (self.assertGreater, 1),
            (self.assertGreaterEqual, 1),
        )
        for check, other in checks:
            check(C(0), other)

    def test_no_order(self):
        ordering_methods = ('__le__', '__lt__', '__ge__', '__gt__')

        # order=False: no ordering methods are generated.
        @dataclass(order=False)
        class C:
            x: int
        for name in ordering_methods:
            self.assertNotIn(name, C.__dict__)

        # With a user-written __lt__, the other ordering methods are still
        # not generated.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ordering_methods:
            if name != '__lt__':
                self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace a user-written ordering method and
        # suggests functools.total_ordering instead.
        def user_method(self):
            pass

        for name in ('__lt__', '__le__', '__gt__', '__ge__'):
            class C:
                x: int
            setattr(C, name, user_method)
            with self.assertRaisesRegex(TypeError,
                                        f'Cannot overwrite attribute {name}'
                                        '.*using functools.total_ordering'):
                dataclass(order=True)(C)
2128
class TestHash(unittest.TestCase):
    """Tests for how @dataclass decides whether, and how, to set __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.  C has
                    # no fields, so that function hashes the empty tuple.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                    self.assertEqual(hash(C()), hash(()))

                elif result == '':
                    if with_hash:
                        # The user-supplied __hash__ (which returns 0)
                        # must be left untouched.
                        self.assertEqual(hash(C()), 0)
                    else:
                        # __hash__ is not present in our class.
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object.__hash__ is derived from the instance's
                    # identity, so there is no fixed value to compare
                    # hash() against.  Just check the function used is
                    # object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2347
Eric V. Smithea8fc522018-01-27 19:07:40 -05002348
class TestFrozen(unittest.TestCase):
    """frozen=True: read-only fields, inheritance restrictions, hashing."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        instance = C(10)
        self.assertEqual(instance.i, 10)
        # A rejected assignment must leave the stored value untouched.
        with self.assertRaises(FrozenInstanceError):
            instance.i = 5
        self.assertEqual(instance.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        child = D(0, 10)
        # Neither the inherited field nor the subclass's own field may be
        # assigned.
        for attr, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(child, attr, value)
        self.assertEqual(child.i, 0)
        self.assertEqual(child.j, 10)

    # The inheritance tests below each run twice: once deriving straight
    # from the base class, and once going through an intermediate normal
    # (non-dataclass) class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                Parent = C
                if intermediate_class:
                    class Parent(C):
                        pass

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(Parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                Parent = C
                if intermediate_class:
                    class Parent(C):
                        pass

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(Parent):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                Parent = C
                if intermediate_class:
                    class Parent(C):
                        pass

                @dataclass(frozen=True)
                class D(Parent):
                    i: int

                obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.  A plain (non-dataclass) subclass of a frozen
        # dataclass may gain new attributes, but the inherited fields stay
        # read-only.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        sub = S(3)
        self.assertEqual(sub.x, 3)
        self.assertEqual(sub.y, 10)
        sub.cached = True

        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(sub, attr, 5)
        self.assertEqual(sub.x, 3)
        self.assertEqual(sub.y, 10)
        self.assertEqual(sub.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True generates __setattr__ and __delattr__, so a class
        # that already defines either one must be rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen, a user-defined __setattr__ is kept and used.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing succeeds when the field value is itself hashable ...
        hash(C(3))

        # ... and fails when it is not.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2497
Eric V. Smith2fa6b9e2018-02-26 20:38:33 -05002498
class TestSlots(unittest.TestCase):
    """Dataclasses whose classes declare __slots__ explicitly."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: the member descriptor (of type
        # types.MemberDescriptorType) that __slots__ creates for x was
        # once mistaken for a default value, which made x optional.
        missing_x = r"__init__\(\) missing 1 required positional argument: 'x'"
        with self.assertRaisesRegex(TypeError, missing_x):
            C()

        # x can be passed in and reassigned afterwards ...
        obj = C(10)
        self.assertEqual(obj.x, 10)
        obj.x = 5
        self.assertEqual(obj.x, 5)

        # ... but no other attribute may be created.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            obj.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__ of its own, so its instances
        # accept brand-new attributes.
        derived.z = 10
2540
class TestDescriptors(unittest.TestCase):
    """How dataclass field defaults interact with __set_name__ (PEP 487)."""

    def test_set_name(self):
        # See bpo-33141.

        # A descriptor that remembers the attribute name it was bound to,
        # with an 'x' appended.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Ordinary descriptor behavior: the default is a plain class
        # attribute, so no dataclass code is involved in calling
        # __set_name__.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # The descriptor wrapped in field(default=..., init=False).  This
        # is the only case where it really matters; without init=False
        # __init__ would overwrite the descriptor anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # This class defines __set_name__ but no __get__, so it is not a
        # descriptor.

        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175: __set_name__ must be looked up on the type, so
        # one stored only on the instance is never called.
        class D:
            pass

        default_value = D()
        hook = Mock()
        default_value.__set_name__ = hook

        @dataclass
        class C:
            i: int=field(default=default_value, init=False)

        self.assertEqual(hook.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175: a __set_name__ found on the default's type is
        # called exactly once.
        class D:
            pass

        hook = Mock()
        D.__set_name__ = hook

        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(hook.call_count, 1)
Eric V. Smithde7a2f02018-03-26 13:29:16 -04002611
Eric V. Smith7389fd92018-03-19 21:07:51 -04002612
Eric V. Smith2a7bacb2018-05-15 22:44:27 -04002613class TestStringAnnotations(unittest.TestCase):
2614 def test_classvar(self):
2615 # Some expressions recognized as ClassVar really aren't. But
2616 # if you're using string annotations, it's not an exact
2617 # science.
2618 # These tests assume that both "import typing" and "from
2619 # typing import *" have been run in this file.
2620 for typestr in ('ClassVar[int]',
2621 'ClassVar [int]'
2622 ' ClassVar [int]',
2623 'ClassVar',
2624 ' ClassVar ',
2625 'typing.ClassVar[int]',
2626 'typing.ClassVar[str]',
2627 ' typing.ClassVar[str]',
2628 'typing .ClassVar[str]',
2629 'typing. ClassVar[str]',
2630 'typing.ClassVar [str]',
2631 'typing.ClassVar [ str]',
2632
2633 # Not syntactically valid, but these will
2634 # be treated as ClassVars.
2635 'typing.ClassVar.[int]',
2636 'typing.ClassVar+',
2637 ):
2638 with self.subTest(typestr=typestr):
2639 @dataclass
2640 class C:
2641 x: typestr
2642
2643 # x is a ClassVar, so C() takes no args.
2644 C()
2645
2646 # And it won't appear in the class's dict because it doesn't
2647 # have a default.
2648 self.assertNotIn('x', C.__dict__)
2649
2650 def test_isnt_classvar(self):
2651 for typestr in ('CV',
2652 't.ClassVar',
2653 't.ClassVar[int]',
2654 'typing..ClassVar[int]',
2655 'Classvar',
2656 'Classvar[int]',
2657 'typing.ClassVarx[int]',
2658 'typong.ClassVar[int]',
2659 'dataclasses.ClassVar[int]',
2660 'typingxClassVar[str]',
2661 ):
2662 with self.subTest(typestr=typestr):
2663 @dataclass
2664 class C:
2665 x: typestr
2666
2667 # x is not a ClassVar, so C() takes one arg.
2668 self.assertEqual(C(10).x, 10)
2669
2670 def test_initvar(self):
2671 # These tests assume that both "import dataclasses" and "from
2672 # dataclasses import *" have been run in this file.
2673 for typestr in ('InitVar[int]',
2674 'InitVar [int]'
2675 ' InitVar [int]',
2676 'InitVar',
2677 ' InitVar ',
2678 'dataclasses.InitVar[int]',
2679 'dataclasses.InitVar[str]',
2680 ' dataclasses.InitVar[str]',
2681 'dataclasses .InitVar[str]',
2682 'dataclasses. InitVar[str]',
2683 'dataclasses.InitVar [str]',
2684 'dataclasses.InitVar [ str]',
2685
2686 # Not syntactically valid, but these will
2687 # be treated as InitVars.
2688 'dataclasses.InitVar.[int]',
2689 'dataclasses.InitVar+',
2690 ):
2691 with self.subTest(typestr=typestr):
2692 @dataclass
2693 class C:
2694 x: typestr
2695
2696 # x is an InitVar, so doesn't create a member.
2697 with self.assertRaisesRegex(AttributeError,
2698 "object has no attribute 'x'"):
2699 C(1).x
2700
2701 def test_isnt_initvar(self):
2702 for typestr in ('IV',
2703 'dc.InitVar',
2704 'xdataclasses.xInitVar',
2705 'typing.xInitVar[int]',
2706 ):
2707 with self.subTest(typestr=typestr):
2708 @dataclass
2709 class C:
2710 x: typestr
2711
2712 # x is not an InitVar, so there will be a member x.
2713 self.assertEqual(C(10).x, 10)
2714
2715 def test_classvar_module_level_import(self):
2716 from . import dataclass_module_1
2717 from . import dataclass_module_1_str
2718 from . import dataclass_module_2
2719 from . import dataclass_module_2_str
2720
2721 for m in (dataclass_module_1, dataclass_module_1_str,
2722 dataclass_module_2, dataclass_module_2_str,
2723 ):
2724 with self.subTest(m=m):
2725 # There's a difference in how the ClassVars are
2726 # interpreted when using string annotations or
2727 # not. See the imported modules for details.
2728 if m.USING_STRINGS:
2729 c = m.CV(10)
2730 else:
2731 c = m.CV()
2732 self.assertEqual(c.cv0, 20)
2733
2734
2735 # There's a difference in how the InitVars are
2736 # interpreted when using string annotations or
2737 # not. See the imported modules for details.
2738 c = m.IV(0, 1, 2, 3, 4)
2739
2740 for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
2741 with self.subTest(field_name=field_name):
2742 with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
2743 # Since field_name is an InitVar, it's
2744 # not an instance field.
2745 getattr(c, field_name)
2746
2747 if m.USING_STRINGS:
2748 # iv4 is interpreted as a normal field.
2749 self.assertIn('not_iv4', c.__dict__)
2750 self.assertEqual(c.not_iv4, 4)
2751 else:
2752 # iv4 is interpreted as an InitVar, so it
2753 # won't exist on the instance.
2754 self.assertNotIn('not_iv4', c.__dict__)
2755
2756
class TestMakeDataclass(unittest.TestCase):
    """Tests for make_dataclass(), the functional form of @dataclass."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # Base1's field x comes first, so one argument leaves y missing.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # In the field-name tests below the loop variable is called "name",
    # not "field", so that dataclasses.field is not shadowed.

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        # The error message has been spelled "identifers" (sic); accept
        # both that and the correct "identifiers".
        not_identifier = 'must be valid identifi?ers'
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, not_identifier):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, not_identifier):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, not_identifier):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
2920
class TestReplace(unittest.TestCase):
    """Tests for dataclasses.replace()."""

    def test(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))

        # init=False fields can't be given to replace(), whether alone or
        # together with init=True fields.  Each call gets its own
        # assertRaises block: anything after the raising call inside the
        # same block would never run.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            replace(c, z=3)

    def test_invalid_object(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        # Neither a dataclass type nor a non-dataclass object is accepted.
        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # An InitVar is not stored on the instance, so replace() cannot
        # copy it over and it must be passed explicitly.  The parentheses
        # are escaped so they match literally instead of forming an
        # empty regex group.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)
3056
Eric V. Smith4e812962018-05-16 11:31:29 -04003057
if __name__ == "__main__":
    # Executing this file directly runs every TestCase defined in it.
    unittest.main()