blob: df53b040c0e1ee6e3fbcdd7a47cbc9c137adc131 [file] [log] [blame]
Miss Islington (bot)4ddc99d2018-03-21 14:44:23 -07001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
11from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
12from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
class CustomError(Exception):
    """Throwaway exception type that tests raise and then catch."""
17
18class TestCase(unittest.TestCase):
19 def test_no_fields(self):
20 @dataclass
21 class C:
22 pass
23
24 o = C()
25 self.assertEqual(len(fields(C)), 0)
26
Miss Islington (bot)3b4c6b12018-03-22 13:58:59 -070027 def test_no_fields_but_member_variable(self):
28 @dataclass
29 class C:
30 i = 0
31
32 o = C()
33 self.assertEqual(len(fields(C)), 0)
34
Eric V. Smithf0db54a2017-12-04 16:58:55 -050035 def test_one_field_no_default(self):
36 @dataclass
37 class C:
38 x: int
39
40 o = C(42)
41 self.assertEqual(o.x, 42)
42
43 def test_named_init_params(self):
44 @dataclass
45 class C:
46 x: int
47
48 o = C(x=32)
49 self.assertEqual(o.x, 32)
50
51 def test_two_fields_one_default(self):
52 @dataclass
53 class C:
54 x: int
55 y: int = 0
56
57 o = C(3)
58 self.assertEqual((o.x, o.y), (3, 0))
59
60 # Non-defaults following defaults.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class C:
66 x: int = 0
67 y: int
68
69 # A derived class adds a non-default field after a default one.
70 with self.assertRaisesRegex(TypeError,
71 "non-default argument 'y' follows "
72 "default argument"):
73 @dataclass
74 class B:
75 x: int = 0
76
77 @dataclass
78 class C(B):
79 y: int
80
81 # Override a base class field and add a default to
82 # a field which didn't use to have a default.
83 with self.assertRaisesRegex(TypeError,
84 "non-default argument 'y' follows "
85 "default argument"):
86 @dataclass
87 class B:
88 x: int
89 y: int
90
91 @dataclass
92 class C(B):
93 x: int = 0
94
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080095 def test_overwrite_hash(self):
96 # Test that declaring this class isn't an error. It should
97 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
101 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800102 return 301
103 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # Test that declaring this class isn't an error. It should
106 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500107 @dataclass(frozen=True)
108 class C:
109 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800110 def __eq__(self, other):
111 return False
112 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500113
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800114 # But this one should generate an exception, because with
115 # unsafe_hash=True, it's an error to have a __hash__ defined.
116 with self.assertRaisesRegex(TypeError,
117 'Cannot overwrite attribute __hash__'):
118 @dataclass(unsafe_hash=True)
119 class C:
120 def __hash__(self):
121 pass
122
123 # Creating this class should not generate an exception,
124 # because even though __hash__ exists before @dataclass is
125 # called, (due to __eq__ being defined), since it's None
126 # that's okay.
127 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500128 class C:
129 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800130 def __eq__(self):
131 pass
132 # The generated hash function works as we'd expect.
133 self.assertEqual(hash(C(10)), hash((10,)))
134
135 # Creating this class should generate an exception, because
136 # __hash__ exists and is not None, which it would be if it had
137 # been auto-generated do due __eq__ being defined.
138 with self.assertRaisesRegex(TypeError,
139 'Cannot overwrite attribute __hash__'):
140 @dataclass(unsafe_hash=True)
141 class C:
142 x: int
143 def __eq__(self):
144 pass
145 def __hash__(self):
146 pass
147
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500148
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500149 def test_overwrite_fields_in_derived_class(self):
150 # Note that x from C1 replaces x in Base, but the order remains
151 # the same as defined in Base.
152 @dataclass
153 class Base:
154 x: Any = 15.0
155 y: int = 0
156
157 @dataclass
158 class C1(Base):
159 z: int = 10
160 x: int = 15
161
162 o = Base()
163 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
164
165 o = C1()
166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
167
168 o = C1(x=5)
169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
170
171 def test_field_named_self(self):
172 @dataclass
173 class C:
174 self: str
175 c=C('foo')
176 self.assertEqual(c.self, 'foo')
177
178 # Make sure the first parameter is not named 'self'.
179 sig = inspect.signature(C.__init__)
180 first = next(iter(sig.parameters))
181 self.assertNotEqual('self', first)
182
183 # But we do use 'self' if no field named self.
184 @dataclass
185 class C:
186 selfx: str
187
188 # Make sure the first parameter is named 'self'.
189 sig = inspect.signature(C.__init__)
190 first = next(iter(sig.parameters))
191 self.assertEqual('self', first)
192
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500193 def test_0_field_compare(self):
194 # Ensure that order=False is the default.
195 @dataclass
196 class C0:
197 pass
198
199 @dataclass(order=False)
200 class C1:
201 pass
202
203 for cls in [C0, C1]:
204 with self.subTest(cls=cls):
205 self.assertEqual(cls(), cls())
206 for idx, fn in enumerate([lambda a, b: a < b,
207 lambda a, b: a <= b,
208 lambda a, b: a > b,
209 lambda a, b: a >= b]):
210 with self.subTest(idx=idx):
211 with self.assertRaisesRegex(TypeError,
212 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
213 fn(cls(), cls())
214
215 @dataclass(order=True)
216 class C:
217 pass
218 self.assertLessEqual(C(), C())
219 self.assertGreaterEqual(C(), C())
220
221 def test_1_field_compare(self):
222 # Ensure that order=False is the default.
223 @dataclass
224 class C0:
225 x: int
226
227 @dataclass(order=False)
228 class C1:
229 x: int
230
231 for cls in [C0, C1]:
232 with self.subTest(cls=cls):
233 self.assertEqual(cls(1), cls(1))
234 self.assertNotEqual(cls(0), cls(1))
235 for idx, fn in enumerate([lambda a, b: a < b,
236 lambda a, b: a <= b,
237 lambda a, b: a > b,
238 lambda a, b: a >= b]):
239 with self.subTest(idx=idx):
240 with self.assertRaisesRegex(TypeError,
241 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
242 fn(cls(0), cls(0))
243
244 @dataclass(order=True)
245 class C:
246 x: int
247 self.assertLess(C(0), C(1))
248 self.assertLessEqual(C(0), C(1))
249 self.assertLessEqual(C(1), C(1))
250 self.assertGreater(C(1), C(0))
251 self.assertGreaterEqual(C(1), C(0))
252 self.assertGreaterEqual(C(1), C(1))
253
254 def test_simple_compare(self):
255 # Ensure that order=False is the default.
256 @dataclass
257 class C0:
258 x: int
259 y: int
260
261 @dataclass(order=False)
262 class C1:
263 x: int
264 y: int
265
266 for cls in [C0, C1]:
267 with self.subTest(cls=cls):
268 self.assertEqual(cls(0, 0), cls(0, 0))
269 self.assertEqual(cls(1, 2), cls(1, 2))
270 self.assertNotEqual(cls(1, 0), cls(0, 0))
271 self.assertNotEqual(cls(1, 0), cls(1, 1))
272 for idx, fn in enumerate([lambda a, b: a < b,
273 lambda a, b: a <= b,
274 lambda a, b: a > b,
275 lambda a, b: a >= b]):
276 with self.subTest(idx=idx):
277 with self.assertRaisesRegex(TypeError,
278 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
279 fn(cls(0, 0), cls(0, 0))
280
281 @dataclass(order=True)
282 class C:
283 x: int
284 y: int
285
286 for idx, fn in enumerate([lambda a, b: a == b,
287 lambda a, b: a <= b,
288 lambda a, b: a >= b]):
289 with self.subTest(idx=idx):
290 self.assertTrue(fn(C(0, 0), C(0, 0)))
291
292 for idx, fn in enumerate([lambda a, b: a < b,
293 lambda a, b: a <= b,
294 lambda a, b: a != b]):
295 with self.subTest(idx=idx):
296 self.assertTrue(fn(C(0, 0), C(0, 1)))
297 self.assertTrue(fn(C(0, 1), C(1, 0)))
298 self.assertTrue(fn(C(1, 0), C(1, 1)))
299
300 for idx, fn in enumerate([lambda a, b: a > b,
301 lambda a, b: a >= b,
302 lambda a, b: a != b]):
303 with self.subTest(idx=idx):
304 self.assertTrue(fn(C(0, 1), C(0, 0)))
305 self.assertTrue(fn(C(1, 0), C(0, 1)))
306 self.assertTrue(fn(C(1, 1), C(1, 0)))
307
308 def test_compare_subclasses(self):
309 # Comparisons fail for subclasses, even if no fields
310 # are added.
311 @dataclass
312 class B:
313 i: int
314
315 @dataclass
316 class C(B):
317 pass
318
319 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
320 (lambda a, b: a != b, True)]):
321 with self.subTest(idx=idx):
322 self.assertEqual(fn(B(0), C(0)), expected)
323
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 "not supported between instances of 'B' and 'C'"):
331 fn(B(0), C(0))
332
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500333 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500334 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500335 for (eq, order, result ) in [
336 (False, False, 'neither'),
337 (False, True, 'exception'),
338 (True, False, 'eq_only'),
339 (True, True, 'both'),
340 ]:
341 with self.subTest(eq=eq, order=order):
342 if result == 'exception':
343 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
344 @dataclass(eq=eq, order=order)
345 class C:
346 pass
347 else:
348 @dataclass(eq=eq, order=order)
349 class C:
350 pass
351
352 if result == 'neither':
353 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500354 self.assertNotIn('__lt__', C.__dict__)
355 self.assertNotIn('__le__', C.__dict__)
356 self.assertNotIn('__gt__', C.__dict__)
357 self.assertNotIn('__ge__', C.__dict__)
358 elif result == 'both':
359 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500360 self.assertIn('__lt__', C.__dict__)
361 self.assertIn('__le__', C.__dict__)
362 self.assertIn('__gt__', C.__dict__)
363 self.assertIn('__ge__', C.__dict__)
364 elif result == 'eq_only':
365 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500366 self.assertNotIn('__lt__', C.__dict__)
367 self.assertNotIn('__le__', C.__dict__)
368 self.assertNotIn('__gt__', C.__dict__)
369 self.assertNotIn('__ge__', C.__dict__)
370 else:
371 assert False, f'unknown result {result!r}'
372
373 def test_field_no_default(self):
374 @dataclass
375 class C:
376 x: int = field()
377
378 self.assertEqual(C(5).x, 5)
379
380 with self.assertRaisesRegex(TypeError,
381 r"__init__\(\) missing 1 required "
382 "positional argument: 'x'"):
383 C()
384
385 def test_field_default(self):
386 default = object()
387 @dataclass
388 class C:
389 x: object = field(default=default)
390
391 self.assertIs(C.x, default)
392 c = C(10)
393 self.assertEqual(c.x, 10)
394
395 # If we delete the instance attribute, we should then see the
396 # class attribute.
397 del c.x
398 self.assertIs(c.x, default)
399
400 self.assertIs(C().x, default)
401
402 def test_not_in_repr(self):
403 @dataclass
404 class C:
405 x: int = field(repr=False)
406 with self.assertRaises(TypeError):
407 C()
408 c = C(10)
409 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
410
411 @dataclass
412 class C:
413 x: int = field(repr=False)
414 y: int
415 c = C(10, 20)
416 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
417
418 def test_not_in_compare(self):
419 @dataclass
420 class C:
421 x: int = 0
422 y: int = field(compare=False, default=4)
423
424 self.assertEqual(C(), C(0, 20))
425 self.assertEqual(C(1, 10), C(1, 20))
426 self.assertNotEqual(C(3), C(4, 10))
427 self.assertNotEqual(C(3, 10), C(4, 10))
428
429 def test_hash_field_rules(self):
430 # Test all 6 cases of:
431 # hash=True/False/None
432 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800433 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 (True, False, 'field' ),
435 (True, True, 'field' ),
436 (False, False, 'absent'),
437 (False, True, 'absent'),
438 (None, False, 'absent'),
439 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800440 ]:
441 with self.subTest(hash=hash_, compare=compare):
442 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500443 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800444 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500445
446 if result == 'field':
447 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800448 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500449 elif result == 'absent':
450 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800451 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500452 else:
453 assert False, f'unknown result {result!r}'
454
455 def test_init_false_no_default(self):
456 # If init=False and no default value, then the field won't be
457 # present in the instance.
458 @dataclass
459 class C:
460 x: int = field(init=False)
461
462 self.assertNotIn('x', C().__dict__)
463
464 @dataclass
465 class C:
466 x: int
467 y: int = 0
468 z: int = field(init=False)
469 t: int = 10
470
471 self.assertNotIn('z', C(0).__dict__)
472 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
473
474 def test_class_marker(self):
475 @dataclass
476 class C:
477 x: int
478 y: str = field(init=False, default=None)
479 z: str = field(repr=False)
480
481 the_fields = fields(C)
482 # the_fields is a tuple of 3 items, each value
483 # is in __annotations__.
484 self.assertIsInstance(the_fields, tuple)
485 for f in the_fields:
486 self.assertIs(type(f), Field)
487 self.assertIn(f.name, C.__annotations__)
488
489 self.assertEqual(len(the_fields), 3)
490
491 self.assertEqual(the_fields[0].name, 'x')
492 self.assertEqual(the_fields[0].type, int)
493 self.assertFalse(hasattr(C, 'x'))
494 self.assertTrue (the_fields[0].init)
495 self.assertTrue (the_fields[0].repr)
496 self.assertEqual(the_fields[1].name, 'y')
497 self.assertEqual(the_fields[1].type, str)
498 self.assertIsNone(getattr(C, 'y'))
499 self.assertFalse(the_fields[1].init)
500 self.assertTrue (the_fields[1].repr)
501 self.assertEqual(the_fields[2].name, 'z')
502 self.assertEqual(the_fields[2].type, str)
503 self.assertFalse(hasattr(C, 'z'))
504 self.assertTrue (the_fields[2].init)
505 self.assertFalse(the_fields[2].repr)
506
507 def test_field_order(self):
508 @dataclass
509 class B:
510 a: str = 'B:a'
511 b: str = 'B:b'
512 c: str = 'B:c'
513
514 @dataclass
515 class C(B):
516 b: str = 'C:b'
517
518 self.assertEqual([(f.name, f.default) for f in fields(C)],
519 [('a', 'B:a'),
520 ('b', 'C:b'),
521 ('c', 'B:c')])
522
523 @dataclass
524 class D(B):
525 c: str = 'D:c'
526
527 self.assertEqual([(f.name, f.default) for f in fields(D)],
528 [('a', 'B:a'),
529 ('b', 'B:b'),
530 ('c', 'D:c')])
531
532 @dataclass
533 class E(D):
534 a: str = 'E:a'
535 d: str = 'E:d'
536
537 self.assertEqual([(f.name, f.default) for f in fields(E)],
538 [('a', 'E:a'),
539 ('b', 'B:b'),
540 ('c', 'D:c'),
541 ('d', 'E:d')])
542
543 def test_class_attrs(self):
544 # We only have a class attribute if a default value is
545 # specified, either directly or via a field with a default.
546 default = object()
547 @dataclass
548 class C:
549 x: int
550 y: int = field(repr=False)
551 z: object = default
552 t: int = field(default=100)
553
554 self.assertFalse(hasattr(C, 'x'))
555 self.assertFalse(hasattr(C, 'y'))
556 self.assertIs (C.z, default)
557 self.assertEqual(C.t, 100)
558
559 def test_disallowed_mutable_defaults(self):
560 # For the known types, don't allow mutable default values.
561 for typ, empty, non_empty in [(list, [], [1]),
562 (dict, {}, {0:1}),
563 (set, set(), set([1])),
564 ]:
565 with self.subTest(typ=typ):
566 # Can't use a zero-length value.
567 with self.assertRaisesRegex(ValueError,
568 f'mutable default {typ} for field '
569 'x is not allowed'):
570 @dataclass
571 class Point:
572 x: typ = empty
573
574
575 # Nor a non-zero-length value
576 with self.assertRaisesRegex(ValueError,
577 f'mutable default {typ} for field '
578 'y is not allowed'):
579 @dataclass
580 class Point:
581 y: typ = non_empty
582
583 # Check subtypes also fail.
584 class Subclass(typ): pass
585
586 with self.assertRaisesRegex(ValueError,
587 f"mutable default .*Subclass'>"
588 ' for field z is not allowed'
589 ):
590 @dataclass
591 class Point:
592 z: typ = Subclass()
593
594 # Because this is a ClassVar, it can be mutable.
595 @dataclass
596 class C:
597 z: ClassVar[typ] = typ()
598
599 # Because this is a ClassVar, it can be mutable.
600 @dataclass
601 class C:
602 x: ClassVar[typ] = Subclass()
603
604
605 def test_deliberately_mutable_defaults(self):
606 # If a mutable default isn't in the known list of
607 # (list, dict, set), then it's okay.
608 class Mutable:
609 def __init__(self):
610 self.l = []
611
612 @dataclass
613 class C:
614 x: Mutable
615
616 # These 2 instances will share this value of x.
617 lst = Mutable()
618 o1 = C(lst)
619 o2 = C(lst)
620 self.assertEqual(o1, o2)
621 o1.x.l.extend([1, 2])
622 self.assertEqual(o1, o2)
623 self.assertEqual(o1.x.l, [1, 2])
624 self.assertIs(o1.x, o2.x)
625
626 def test_no_options(self):
627 # call with dataclass()
628 @dataclass()
629 class C:
630 x: int
631
632 self.assertEqual(C(42).x, 42)
633
634 def test_not_tuple(self):
635 # Make sure we can't be compared to a tuple.
636 @dataclass
637 class Point:
638 x: int
639 y: int
640 self.assertNotEqual(Point(1, 2), (1, 2))
641
642 # And that we can't compare to another unrelated dataclass
643 @dataclass
644 class C:
645 x: int
646 y: int
647 self.assertNotEqual(Point(1, 3), C(1, 3))
648
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500649 def test_not_tuple(self):
650 # Test that some of the problems with namedtuple don't happen
651 # here.
652 @dataclass
653 class Point3D:
654 x: int
655 y: int
656 z: int
657
658 @dataclass
659 class Date:
660 year: int
661 month: int
662 day: int
663
664 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
665 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
666
667 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200668 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500669 x, y, z = Point3D(4, 5, 6)
670
Eric V. Smith7c99e932018-01-28 19:18:55 -0500671 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500672 # equal.
673 @dataclass
674 class Point3Dv1:
675 x: int = 0
676 y: int = 0
677 z: int = 0
678 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
679
680 def test_function_annotations(self):
681 # Some dummy class and instance to use as a default.
682 class F:
683 pass
684 f = F()
685
686 def validate_class(cls):
687 # First, check __annotations__, even though they're not
688 # function annotations.
689 self.assertEqual(cls.__annotations__['i'], int)
690 self.assertEqual(cls.__annotations__['j'], str)
691 self.assertEqual(cls.__annotations__['k'], F)
692 self.assertEqual(cls.__annotations__['l'], float)
693 self.assertEqual(cls.__annotations__['z'], complex)
694
695 # Verify __init__.
696
697 signature = inspect.signature(cls.__init__)
698 # Check the return type, should be None
699 self.assertIs(signature.return_annotation, None)
700
701 # Check each parameter.
702 params = iter(signature.parameters.values())
703 param = next(params)
704 # This is testing an internal name, and probably shouldn't be tested.
705 self.assertEqual(param.name, 'self')
706 param = next(params)
707 self.assertEqual(param.name, 'i')
708 self.assertIs (param.annotation, int)
709 self.assertEqual(param.default, inspect.Parameter.empty)
710 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
711 param = next(params)
712 self.assertEqual(param.name, 'j')
713 self.assertIs (param.annotation, str)
714 self.assertEqual(param.default, inspect.Parameter.empty)
715 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
716 param = next(params)
717 self.assertEqual(param.name, 'k')
718 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500719 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
721 param = next(params)
722 self.assertEqual(param.name, 'l')
723 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500724 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500725 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
726 self.assertRaises(StopIteration, next, params)
727
728
729 @dataclass
730 class C:
731 i: int
732 j: str
733 k: F = f
734 l: float=field(default=None)
735 z: complex=field(default=3+4j, init=False)
736
737 validate_class(C)
738
739 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800740 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500741 class C:
742 i: int
743 j: str
744 k: F = f
745 l: float=field(default=None)
746 z: complex=field(default=3+4j, init=False)
747
748 validate_class(C)
749
Eric V. Smith03220fd2017-12-29 13:59:58 -0500750 def test_missing_default(self):
751 # Test that MISSING works the same as a default not being
752 # specified.
753 @dataclass
754 class C:
755 x: int=field(default=MISSING)
756 with self.assertRaisesRegex(TypeError,
757 r'__init__\(\) missing 1 required '
758 'positional argument'):
759 C()
760 self.assertNotIn('x', C.__dict__)
761
762 @dataclass
763 class D:
764 x: int
765 with self.assertRaisesRegex(TypeError,
766 r'__init__\(\) missing 1 required '
767 'positional argument'):
768 D()
769 self.assertNotIn('x', D.__dict__)
770
771 def test_missing_default_factory(self):
772 # Test that MISSING works the same as a default factory not
773 # being specified (which is really the same as a default not
774 # being specified, too).
775 @dataclass
776 class C:
777 x: int=field(default_factory=MISSING)
778 with self.assertRaisesRegex(TypeError,
779 r'__init__\(\) missing 1 required '
780 'positional argument'):
781 C()
782 self.assertNotIn('x', C.__dict__)
783
784 @dataclass
785 class D:
786 x: int=field(default=MISSING, default_factory=MISSING)
787 with self.assertRaisesRegex(TypeError,
788 r'__init__\(\) missing 1 required '
789 'positional argument'):
790 D()
791 self.assertNotIn('x', D.__dict__)
792
793 def test_missing_repr(self):
794 self.assertIn('MISSING_TYPE object', repr(MISSING))
795
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500796 def test_dont_include_other_annotations(self):
797 @dataclass
798 class C:
799 i: int
800 def foo(self) -> int:
801 return 4
802 @property
803 def bar(self) -> int:
804 return 5
805 self.assertEqual(list(C.__annotations__), ['i'])
806 self.assertEqual(C(10).foo(), 4)
807 self.assertEqual(C(10).bar, 5)
Miss Islington (bot)5666a552018-03-25 06:27:50 -0700808 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500809
810 def test_post_init(self):
811 # Just make sure it gets called
812 @dataclass
813 class C:
814 def __post_init__(self):
815 raise CustomError()
816 with self.assertRaises(CustomError):
817 C()
818
819 @dataclass
820 class C:
821 i: int = 10
822 def __post_init__(self):
823 if self.i == 10:
824 raise CustomError()
825 with self.assertRaises(CustomError):
826 C()
827 # post-init gets called, but doesn't raise. This is just
828 # checking that self is used correctly.
829 C(5)
830
831 # If there's not an __init__, then post-init won't get called.
832 @dataclass(init=False)
833 class C:
834 def __post_init__(self):
835 raise CustomError()
836 # Creating the class won't raise
837 C()
838
839 @dataclass
840 class C:
841 x: int = 0
842 def __post_init__(self):
843 self.x *= 2
844 self.assertEqual(C().x, 0)
845 self.assertEqual(C(2).x, 4)
846
Mike53f7a7c2017-12-14 14:04:53 +0300847 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500848 # attributes.
849 @dataclass(frozen=True)
850 class C:
851 x: int = 0
852 def __post_init__(self):
853 self.x *= 2
854 with self.assertRaises(FrozenInstanceError):
855 C()
856
857 def test_post_init_super(self):
858 # Make sure super() post-init isn't called by default.
859 class B:
860 def __post_init__(self):
861 raise CustomError()
862
863 @dataclass
864 class C(B):
865 def __post_init__(self):
866 self.x = 5
867
868 self.assertEqual(C().x, 5)
869
870 # Now call super(), and it will raise
871 @dataclass
872 class C(B):
873 def __post_init__(self):
874 super().__post_init__()
875
876 with self.assertRaises(CustomError):
877 C()
878
879 # Make sure post-init is called, even if not defined in our
880 # class.
881 @dataclass
882 class C(B):
883 pass
884
885 with self.assertRaises(CustomError):
886 C()
887
888 def test_post_init_staticmethod(self):
889 flag = False
890 @dataclass
891 class C:
892 x: int
893 y: int
894 @staticmethod
895 def __post_init__():
896 nonlocal flag
897 flag = True
898
899 self.assertFalse(flag)
900 c = C(3, 4)
901 self.assertEqual((c.x, c.y), (3, 4))
902 self.assertTrue(flag)
903
904 def test_post_init_classmethod(self):
905 @dataclass
906 class C:
907 flag = False
908 x: int
909 y: int
910 @classmethod
911 def __post_init__(cls):
912 cls.flag = True
913
914 self.assertFalse(C.flag)
915 c = C(3, 4)
916 self.assertEqual((c.x, c.y), (3, 4))
917 self.assertTrue(C.flag)
918
919 def test_class_var(self):
920 # Make sure ClassVars are ignored in __init__, __repr__, etc.
921 @dataclass
922 class C:
923 x: int
924 y: int = 10
925 z: ClassVar[int] = 1000
926 w: ClassVar[int] = 2000
927 t: ClassVar[int] = 3000
928
929 c = C(5)
930 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
931 self.assertEqual(len(fields(C)), 2) # We have 2 fields
932 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
933 self.assertEqual(c.z, 1000)
934 self.assertEqual(c.w, 2000)
935 self.assertEqual(c.t, 3000)
936 C.z += 1
937 self.assertEqual(c.z, 1001)
938 c = C(20)
939 self.assertEqual((c.x, c.y), (20, 10))
940 self.assertEqual(c.z, 1001)
941 self.assertEqual(c.w, 2000)
942 self.assertEqual(c.t, 3000)
943
944 def test_class_var_no_default(self):
945 # If a ClassVar has no default value, it should not be set on the class.
946 @dataclass
947 class C:
948 x: ClassVar[int]
949
950 self.assertNotIn('x', C.__dict__)
951
952 def test_class_var_default_factory(self):
953 # It makes no sense for a ClassVar to have a default factory. When
954 # would it be called? Call it yourself, since it's class-wide.
955 with self.assertRaisesRegex(TypeError,
956 'cannot have a default factory'):
957 @dataclass
958 class C:
959 x: ClassVar[int] = field(default_factory=int)
960
961 self.assertNotIn('x', C.__dict__)
962
963 def test_class_var_with_default(self):
964 # If a ClassVar has a default value, it should be set on the class.
965 @dataclass
966 class C:
967 x: ClassVar[int] = 10
968 self.assertEqual(C.x, 10)
969
970 @dataclass
971 class C:
972 x: ClassVar[int] = field(default=10)
973 self.assertEqual(C.x, 10)
974
975 def test_class_var_frozen(self):
976 # Make sure ClassVars work even if we're frozen.
977 @dataclass(frozen=True)
978 class C:
979 x: int
980 y: int = 10
981 z: ClassVar[int] = 1000
982 w: ClassVar[int] = 2000
983 t: ClassVar[int] = 3000
984
985 c = C(5)
986 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
987 self.assertEqual(len(fields(C)), 2) # We have 2 fields
988 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
989 self.assertEqual(c.z, 1000)
990 self.assertEqual(c.w, 2000)
991 self.assertEqual(c.t, 3000)
992 # We can still modify the ClassVar, it's only instances that are
993 # frozen.
994 C.z += 1
995 self.assertEqual(c.z, 1001)
996 c = C(20)
997 self.assertEqual((c.x, c.y), (20, 10))
998 self.assertEqual(c.z, 1001)
999 self.assertEqual(c.w, 2000)
1000 self.assertEqual(c.t, 3000)
1001
1002 def test_init_var_no_default(self):
1003 # If an InitVar has no default value, it should not be set on the class.
1004 @dataclass
1005 class C:
1006 x: InitVar[int]
1007
1008 self.assertNotIn('x', C.__dict__)
1009
1010 def test_init_var_default_factory(self):
1011 # It makes no sense for an InitVar to have a default factory. When
1012 # would it be called? Call it yourself, since it's class-wide.
1013 with self.assertRaisesRegex(TypeError,
1014 'cannot have a default factory'):
1015 @dataclass
1016 class C:
1017 x: InitVar[int] = field(default_factory=int)
1018
1019 self.assertNotIn('x', C.__dict__)
1020
1021 def test_init_var_with_default(self):
1022 # If an InitVar has a default value, it should be set on the class.
1023 @dataclass
1024 class C:
1025 x: InitVar[int] = 10
1026 self.assertEqual(C.x, 10)
1027
1028 @dataclass
1029 class C:
1030 x: InitVar[int] = field(default=10)
1031 self.assertEqual(C.x, 10)
1032
1033 def test_init_var(self):
1034 @dataclass
1035 class C:
1036 x: int = None
1037 init_param: InitVar[int] = None
1038
1039 def __post_init__(self, init_param):
1040 if self.x is None:
1041 self.x = init_param*2
1042
1043 c = C(init_param=10)
1044 self.assertEqual(c.x, 20)
1045
1046 def test_init_var_inheritance(self):
1047 # Note that this deliberately tests that a dataclass need not
1048 # have a __post_init__ function if it has an InitVar field.
1049 # It could just be used in a derived class, as shown here.
1050 @dataclass
1051 class Base:
1052 x: int
1053 init_base: InitVar[int]
1054
1055 # We can instantiate by passing the InitVar, even though
1056 # it's not used.
1057 b = Base(0, 10)
1058 self.assertEqual(vars(b), {'x': 0})
1059
1060 @dataclass
1061 class C(Base):
1062 y: int
1063 init_derived: InitVar[int]
1064
1065 def __post_init__(self, init_base, init_derived):
1066 self.x = self.x + init_base
1067 self.y = self.y + init_derived
1068
1069 c = C(10, 11, 50, 51)
1070 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1071
def test_default_factory(self):
    # default_factory producing a fresh list for every instance.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIsNot(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # default_factory handing back one shared list every time.
    shared = []

    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: shared)

    first, second = C(3), C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIs(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # The remaining checks pair default_factory with one other field()
    # flag at a time.

    # repr=False: the field is left out of the generated repr.
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash=False: the field does not feed into the generated hash.
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init=False: the factory still populates the field
    # (see also test_default_factory_with_no_init).
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare=False: the field is ignored by __eq__.
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1128
1129 def test_default_factory_with_no_init(self):
1130 # We need a factory with a side effect.
1131 factory = Mock()
1132
1133 @dataclass
1134 class C:
1135 x: list = field(default_factory=factory, init=False)
1136
1137 # Make sure the default factory is called for each new instance.
1138 C().x
1139 self.assertEqual(factory.call_count, 1)
1140 C().x
1141 self.assertEqual(factory.call_count, 2)
1142
1143 def test_default_factory_not_called_if_value_given(self):
1144 # We need a factory that we can test if it's been called.
1145 factory = Mock()
1146
1147 @dataclass
1148 class C:
1149 x: int = field(default_factory=factory)
1150
1151 # Make sure that if a field has a default factory function,
1152 # it's not called if a value is specified.
1153 C().x
1154 self.assertEqual(factory.call_count, 1)
1155 self.assertEqual(C(10).x, 10)
1156 self.assertEqual(factory.call_count, 1)
1157 C().x
1158 self.assertEqual(factory.call_count, 2)
1159
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001160 def test_default_factory_derived(self):
1161 # See bpo-32896.
1162 @dataclass
1163 class Foo:
1164 x: dict = field(default_factory=dict)
1165
1166 @dataclass
1167 class Bar(Foo):
1168 y: int = 1
1169
1170 self.assertEqual(Foo().x, {})
1171 self.assertEqual(Bar().x, {})
1172 self.assertEqual(Bar().y, 1)
1173
1174 @dataclass
1175 class Baz(Foo):
1176 pass
1177 self.assertEqual(Baz().x, {})
1178
1179 def test_intermediate_non_dataclass(self):
1180 # Test that an intermediate class that defines
1181 # annotations does not define fields.
1182
1183 @dataclass
1184 class A:
1185 x: int
1186
1187 class B(A):
1188 y: int
1189
1190 @dataclass
1191 class C(B):
1192 z: int
1193
1194 c = C(1, 3)
1195 self.assertEqual((c.x, c.z), (1, 3))
1196
1197 # .y was not initialized.
1198 with self.assertRaisesRegex(AttributeError,
1199 'object has no attribute'):
1200 c.y
1201
1202 # And if we again derive a non-dataclass, no fields are added.
1203 class D(C):
1204 t: int
1205 d = D(4, 5)
1206 self.assertEqual((d.x, d.z), (4, 5))
1207
1208
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001209 def x_test_classvar_default_factory(self):
1210 # XXX: it's an error for a ClassVar to have a factory function
1211 @dataclass
1212 class C:
1213 x: ClassVar[int] = field(default_factory=int)
1214
1215 self.assertIs(C().x, int)
1216
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001217 def test_is_dataclass(self):
1218 class NotDataClass:
1219 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001220
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001221 self.assertFalse(is_dataclass(0))
1222 self.assertFalse(is_dataclass(int))
1223 self.assertFalse(is_dataclass(NotDataClass))
1224 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001225
1226 @dataclass
1227 class C:
1228 x: int
1229
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001230 @dataclass
1231 class D:
1232 d: C
1233 e: int
1234
1235 c = C(10)
1236 d = D(c, 4)
1237
1238 self.assertTrue(is_dataclass(C))
1239 self.assertTrue(is_dataclass(c))
1240 self.assertFalse(is_dataclass(c.x))
1241 self.assertTrue(is_dataclass(d.d))
1242 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001243
1244 def test_helper_fields_with_class_instance(self):
1245 # Check that we can call fields() on either a class or instance,
1246 # and get back the same thing.
1247 @dataclass
1248 class C:
1249 x: int
1250 y: float
1251
1252 self.assertEqual(fields(C), fields(C(0, 0.0)))
1253
1254 def test_helper_fields_exception(self):
1255 # Check that TypeError is raised if not passed a dataclass or
1256 # instance.
1257 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1258 fields(0)
1259
1260 class C: pass
1261 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1262 fields(C)
1263 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1264 fields(C())
1265
1266 def test_helper_asdict(self):
1267 # Basic tests for asdict(), it should return a new dictionary
1268 @dataclass
1269 class C:
1270 x: int
1271 y: int
1272 c = C(1, 2)
1273
1274 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1275 self.assertEqual(asdict(c), asdict(c))
1276 self.assertIsNot(asdict(c), asdict(c))
1277 c.x = 42
1278 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1279 self.assertIs(type(asdict(c)), dict)
1280
1281 def test_helper_asdict_raises_on_classes(self):
1282 # asdict() should raise on a class object
1283 @dataclass
1284 class C:
1285 x: int
1286 y: int
1287 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1288 asdict(C)
1289 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1290 asdict(int)
1291
1292 def test_helper_asdict_copy_values(self):
1293 @dataclass
1294 class C:
1295 x: int
1296 y: List[int] = field(default_factory=list)
1297 initial = []
1298 c = C(1, initial)
1299 d = asdict(c)
1300 self.assertEqual(d['y'], initial)
1301 self.assertIsNot(d['y'], initial)
1302 c = C(1)
1303 d = asdict(c)
1304 d['y'].append(1)
1305 self.assertEqual(c.y, [])
1306
1307 def test_helper_asdict_nested(self):
1308 @dataclass
1309 class UserId:
1310 token: int
1311 group: int
1312 @dataclass
1313 class User:
1314 name: str
1315 id: UserId
1316 u = User('Joe', UserId(123, 1))
1317 d = asdict(u)
1318 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1319 self.assertIsNot(asdict(u), asdict(u))
1320 u.id.group = 2
1321 self.assertEqual(asdict(u), {'name': 'Joe',
1322 'id': {'token': 123, 'group': 2}})
1323
1324 def test_helper_asdict_builtin_containers(self):
1325 @dataclass
1326 class User:
1327 name: str
1328 id: int
1329 @dataclass
1330 class GroupList:
1331 id: int
1332 users: List[User]
1333 @dataclass
1334 class GroupTuple:
1335 id: int
1336 users: Tuple[User, ...]
1337 @dataclass
1338 class GroupDict:
1339 id: int
1340 users: Dict[str, User]
1341 a = User('Alice', 1)
1342 b = User('Bob', 2)
1343 gl = GroupList(0, [a, b])
1344 gt = GroupTuple(0, (a, b))
1345 gd = GroupDict(0, {'first': a, 'second': b})
1346 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1347 {'name': 'Bob', 'id': 2}]})
1348 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1349 {'name': 'Bob', 'id': 2})})
1350 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1351 'second': {'name': 'Bob', 'id': 2}}})
1352
1353 def test_helper_asdict_builtin_containers(self):
1354 @dataclass
1355 class Child:
1356 d: object
1357
1358 @dataclass
1359 class Parent:
1360 child: Child
1361
1362 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1363 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1364
1365 def test_helper_asdict_factory(self):
1366 @dataclass
1367 class C:
1368 x: int
1369 y: int
1370 c = C(1, 2)
1371 d = asdict(c, dict_factory=OrderedDict)
1372 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1373 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1374 c.x = 42
1375 d = asdict(c, dict_factory=OrderedDict)
1376 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1377 self.assertIs(type(d), OrderedDict)
1378
1379 def test_helper_astuple(self):
1380 # Basic tests for astuple(), it should return a new tuple
1381 @dataclass
1382 class C:
1383 x: int
1384 y: int = 0
1385 c = C(1)
1386
1387 self.assertEqual(astuple(c), (1, 0))
1388 self.assertEqual(astuple(c), astuple(c))
1389 self.assertIsNot(astuple(c), astuple(c))
1390 c.y = 42
1391 self.assertEqual(astuple(c), (1, 42))
1392 self.assertIs(type(astuple(c)), tuple)
1393
1394 def test_helper_astuple_raises_on_classes(self):
1395 # astuple() should raise on a class object
1396 @dataclass
1397 class C:
1398 x: int
1399 y: int
1400 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1401 astuple(C)
1402 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1403 astuple(int)
1404
1405 def test_helper_astuple_copy_values(self):
1406 @dataclass
1407 class C:
1408 x: int
1409 y: List[int] = field(default_factory=list)
1410 initial = []
1411 c = C(1, initial)
1412 t = astuple(c)
1413 self.assertEqual(t[1], initial)
1414 self.assertIsNot(t[1], initial)
1415 c = C(1)
1416 t = astuple(c)
1417 t[1].append(1)
1418 self.assertEqual(c.y, [])
1419
1420 def test_helper_astuple_nested(self):
1421 @dataclass
1422 class UserId:
1423 token: int
1424 group: int
1425 @dataclass
1426 class User:
1427 name: str
1428 id: UserId
1429 u = User('Joe', UserId(123, 1))
1430 t = astuple(u)
1431 self.assertEqual(t, ('Joe', (123, 1)))
1432 self.assertIsNot(astuple(u), astuple(u))
1433 u.id.group = 2
1434 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1435
1436 def test_helper_astuple_builtin_containers(self):
1437 @dataclass
1438 class User:
1439 name: str
1440 id: int
1441 @dataclass
1442 class GroupList:
1443 id: int
1444 users: List[User]
1445 @dataclass
1446 class GroupTuple:
1447 id: int
1448 users: Tuple[User, ...]
1449 @dataclass
1450 class GroupDict:
1451 id: int
1452 users: Dict[str, User]
1453 a = User('Alice', 1)
1454 b = User('Bob', 2)
1455 gl = GroupList(0, [a, b])
1456 gt = GroupTuple(0, (a, b))
1457 gd = GroupDict(0, {'first': a, 'second': b})
1458 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1459 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1460 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1461
1462 def test_helper_astuple_builtin_containers(self):
1463 @dataclass
1464 class Child:
1465 d: object
1466
1467 @dataclass
1468 class Parent:
1469 child: Child
1470
1471 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1472 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1473
1474 def test_helper_astuple_factory(self):
1475 @dataclass
1476 class C:
1477 x: int
1478 y: int
1479 NT = namedtuple('NT', 'x y')
1480 def nt(lst):
1481 return NT(*lst)
1482 c = C(1, 2)
1483 t = astuple(c, tuple_factory=nt)
1484 self.assertEqual(t, NT(1, 2))
1485 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1486 c.x = 42
1487 t = astuple(c, tuple_factory=nt)
1488 self.assertEqual(t, NT(42, 2))
1489 self.assertIs(type(t), NT)
1490
1491 def test_dynamic_class_creation(self):
Miss Islington (bot)5666a552018-03-25 06:27:50 -07001492 cls_dict = {'__annotations__': {'x':int, 'y':int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001493 }
1494
1495 # Create the class.
1496 cls = type('C', (), cls_dict)
1497
1498 # Make it a dataclass.
1499 cls1 = dataclass(cls)
1500
1501 self.assertEqual(cls1, cls)
1502 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1503
1504 def test_dynamic_class_creation_using_field(self):
Miss Islington (bot)5666a552018-03-25 06:27:50 -07001505 cls_dict = {'__annotations__': {'x':int, 'y':int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001506 'y': field(default=5),
1507 }
1508
1509 # Create the class.
1510 cls = type('C', (), cls_dict)
1511
1512 # Make it a dataclass.
1513 cls1 = dataclass(cls)
1514
1515 self.assertEqual(cls1, cls)
1516 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1517
1518 def test_init_in_order(self):
1519 @dataclass
1520 class C:
1521 a: int
1522 b: int = field()
1523 c: list = field(default_factory=list, init=False)
1524 d: list = field(default_factory=list)
1525 e: int = field(default=4, init=False)
1526 f: int = 4
1527
1528 calls = []
1529 def setattr(self, name, value):
1530 calls.append((name, value))
1531
1532 C.__setattr__ = setattr
1533 c = C(0, 1)
1534 self.assertEqual(('a', 0), calls[0])
1535 self.assertEqual(('b', 1), calls[1])
1536 self.assertEqual(('c', []), calls[2])
1537 self.assertEqual(('d', []), calls[3])
1538 self.assertNotIn(('e', 4), calls)
1539 self.assertEqual(('f', 4), calls[4])
1540
1541 def test_items_in_dicts(self):
1542 @dataclass
1543 class C:
1544 a: int
1545 b: list = field(default_factory=list, init=False)
1546 c: list = field(default_factory=list)
1547 d: int = field(default=4, init=False)
1548 e: int = 0
1549
1550 c = C(0)
1551 # Class dict
1552 self.assertNotIn('a', C.__dict__)
1553 self.assertNotIn('b', C.__dict__)
1554 self.assertNotIn('c', C.__dict__)
1555 self.assertIn('d', C.__dict__)
1556 self.assertEqual(C.d, 4)
1557 self.assertIn('e', C.__dict__)
1558 self.assertEqual(C.e, 0)
1559 # Instance dict
1560 self.assertIn('a', c.__dict__)
1561 self.assertEqual(c.a, 0)
1562 self.assertIn('b', c.__dict__)
1563 self.assertEqual(c.b, [])
1564 self.assertIn('c', c.__dict__)
1565 self.assertEqual(c.c, [])
1566 self.assertNotIn('d', c.__dict__)
1567 self.assertIn('e', c.__dict__)
1568 self.assertEqual(c.e, 0)
1569
1570 def test_alternate_classmethod_constructor(self):
1571 # Since __post_init__ can't take params, use a classmethod
1572 # alternate constructor. This is mostly an example to show how
1573 # to use this technique.
1574 @dataclass
1575 class C:
1576 x: int
1577 @classmethod
1578 def from_file(cls, filename):
1579 # In a real example, create a new instance
1580 # and populate 'x' from contents of a file.
1581 value_in_file = 20
1582 return cls(value_in_file)
1583
1584 self.assertEqual(C.from_file('filename').x, 20)
1585
1586 def test_field_metadata_default(self):
1587 # Make sure the default metadata is read-only and of
1588 # zero length.
1589 @dataclass
1590 class C:
1591 i: int
1592
1593 self.assertFalse(fields(C)[0].metadata)
1594 self.assertEqual(len(fields(C)[0].metadata), 0)
1595 with self.assertRaisesRegex(TypeError,
1596 'does not support item assignment'):
1597 fields(C)[0].metadata['test'] = 3
1598
1599 def test_field_metadata_mapping(self):
1600 # Make sure only a mapping can be passed as metadata
1601 # zero length.
1602 with self.assertRaises(TypeError):
1603 @dataclass
1604 class C:
1605 i: int = field(metadata=0)
1606
1607 # Make sure an empty dict works
1608 @dataclass
1609 class C:
1610 i: int = field(metadata={})
1611 self.assertFalse(fields(C)[0].metadata)
1612 self.assertEqual(len(fields(C)[0].metadata), 0)
1613 with self.assertRaisesRegex(TypeError,
1614 'does not support item assignment'):
1615 fields(C)[0].metadata['test'] = 3
1616
1617 # Make sure a non-empty dict works.
1618 @dataclass
1619 class C:
1620 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1621 self.assertEqual(len(fields(C)[0].metadata), 3)
1622 self.assertEqual(fields(C)[0].metadata['test'], 10)
1623 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1624 self.assertEqual(fields(C)[0].metadata[3], 'three')
1625 with self.assertRaises(KeyError):
1626 # Non-existent key.
1627 fields(C)[0].metadata['baz']
1628 with self.assertRaisesRegex(TypeError,
1629 'does not support item assignment'):
1630 fields(C)[0].metadata['test'] = 3
1631
1632 def test_field_metadata_custom_mapping(self):
1633 # Try a custom mapping.
1634 class SimpleNameSpace:
1635 def __init__(self, **kw):
1636 self.__dict__.update(kw)
1637
1638 def __getitem__(self, item):
1639 if item == 'xyzzy':
1640 return 'plugh'
1641 return getattr(self, item)
1642
1643 def __len__(self):
1644 return self.__dict__.__len__()
1645
1646 @dataclass
1647 class C:
1648 i: int = field(metadata=SimpleNameSpace(a=10))
1649
1650 self.assertEqual(len(fields(C)[0].metadata), 1)
1651 self.assertEqual(fields(C)[0].metadata['a'], 10)
1652 with self.assertRaises(AttributeError):
1653 fields(C)[0].metadata['b']
1654 # Make sure we're still talking to our custom mapping.
1655 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1656
1657 def test_generic_dataclasses(self):
1658 T = TypeVar('T')
1659
1660 @dataclass
1661 class LabeledBox(Generic[T]):
1662 content: T
1663 label: str = '<unknown>'
1664
1665 box = LabeledBox(42)
1666 self.assertEqual(box.content, 42)
1667 self.assertEqual(box.label, '<unknown>')
1668
1669 # subscripting the resulting class should work, etc.
1670 Alias = List[LabeledBox[int]]
1671
1672 def test_generic_extending(self):
1673 S = TypeVar('S')
1674 T = TypeVar('T')
1675
1676 @dataclass
1677 class Base(Generic[T, S]):
1678 x: T
1679 y: S
1680
1681 @dataclass
1682 class DataDerived(Base[int, T]):
1683 new_field: str
1684 Alias = DataDerived[str]
1685 c = Alias(0, 'test1', 'test2')
1686 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1687
1688 class NonDataDerived(Base[int, T]):
1689 def new_method(self):
1690 return self.y
1691 Alias = NonDataDerived[float]
1692 c = Alias(10, 1.0)
1693 self.assertEqual(c.new_method(), 1.0)
1694
1695 def test_helper_replace(self):
1696 @dataclass(frozen=True)
1697 class C:
1698 x: int
1699 y: int
1700
1701 c = C(1, 2)
1702 c1 = replace(c, x=3)
1703 self.assertEqual(c1.x, 3)
1704 self.assertEqual(c1.y, 2)
1705
1706 def test_helper_replace_frozen(self):
1707 @dataclass(frozen=True)
1708 class C:
1709 x: int
1710 y: int
1711 z: int = field(init=False, default=10)
1712 t: int = field(init=False, default=100)
1713
1714 c = C(1, 2)
1715 c1 = replace(c, x=3)
1716 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1717 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1718
1719
1720 with self.assertRaisesRegex(ValueError, 'init=False'):
1721 replace(c, x=3, z=20, t=50)
1722 with self.assertRaisesRegex(ValueError, 'init=False'):
1723 replace(c, z=20)
1724 replace(c, x=3, z=20, t=50)
1725
1726 # Make sure the result is still frozen.
1727 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1728 c1.x = 3
1729
1730 # Make sure we can't replace an attribute that doesn't exist,
1731 # if we're also replacing one that does exist. Test this
1732 # here, because setting attributes on frozen instances is
1733 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001734 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001735 "keyword argument 'a'"):
1736 c1 = replace(c, x=20, a=5)
1737
1738 def test_helper_replace_invalid_field_name(self):
1739 @dataclass(frozen=True)
1740 class C:
1741 x: int
1742 y: int
1743
1744 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001745 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001746 "keyword argument 'z'"):
1747 c1 = replace(c, z=3)
1748
1749 def test_helper_replace_invalid_object(self):
1750 @dataclass(frozen=True)
1751 class C:
1752 x: int
1753 y: int
1754
1755 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1756 replace(C, x=3)
1757
1758 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1759 replace(0, x=3)
1760
1761 def test_helper_replace_no_init(self):
1762 @dataclass
1763 class C:
1764 x: int
1765 y: int = field(init=False, default=10)
1766
1767 c = C(1)
1768 c.y = 20
1769
1770 # Make sure y gets the default value.
1771 c1 = replace(c, x=5)
1772 self.assertEqual((c1.x, c1.y), (5, 10))
1773
1774 # Trying to replace y is an error.
1775 with self.assertRaisesRegex(ValueError, 'init=False'):
1776 replace(c, x=2, y=30)
1777 with self.assertRaisesRegex(ValueError, 'init=False'):
1778 replace(c, y=30)
1779
1780 def test_dataclassses_pickleable(self):
1781 global P, Q, R
1782 @dataclass
1783 class P:
1784 x: int
1785 y: int = 0
1786 @dataclass
1787 class Q:
1788 x: int
1789 y: int = field(default=0, init=False)
1790 @dataclass
1791 class R:
1792 x: int
1793 y: List[int] = field(default_factory=list)
1794 q = Q(1)
1795 q.y = 2
1796 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1797 for sample in samples:
1798 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1799 with self.subTest(sample=sample, proto=proto):
1800 new_sample = pickle.loads(pickle.dumps(sample, proto))
1801 self.assertEqual(sample.x, new_sample.x)
1802 self.assertEqual(sample.y, new_sample.y)
1803 self.assertIsNot(sample, new_sample)
1804 new_sample.x = 42
1805 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1806 self.assertEqual(new_sample.x, another_new_sample.x)
1807 self.assertEqual(sample.y, another_new_sample.y)
1808
1809 def test_helper_make_dataclass(self):
1810 C = make_dataclass('C',
1811 [('x', int),
1812 ('y', int, field(default=5))],
1813 namespace={'add_one': lambda self: self.x + 1})
1814 c = C(10)
1815 self.assertEqual((c.x, c.y), (10, 5))
1816 self.assertEqual(c.add_one(), 11)
1817
1818
1819 def test_helper_make_dataclass_no_mutate_namespace(self):
1820 # Make sure a provided namespace isn't mutated.
1821 ns = {}
1822 C = make_dataclass('C',
1823 [('x', int),
1824 ('y', int, field(default=5))],
1825 namespace=ns)
1826 self.assertEqual(ns, {})
1827
1828 def test_helper_make_dataclass_base(self):
1829 class Base1:
1830 pass
1831 class Base2:
1832 pass
1833 C = make_dataclass('C',
1834 [('x', int)],
1835 bases=(Base1, Base2))
1836 c = C(2)
1837 self.assertIsInstance(c, C)
1838 self.assertIsInstance(c, Base1)
1839 self.assertIsInstance(c, Base2)
1840
1841 def test_helper_make_dataclass_base_dataclass(self):
1842 @dataclass
1843 class Base1:
1844 x: int
1845 class Base2:
1846 pass
1847 C = make_dataclass('C',
1848 [('y', int)],
1849 bases=(Base1, Base2))
1850 with self.assertRaisesRegex(TypeError, 'required positional'):
1851 c = C(2)
1852 c = C(1, 2)
1853 self.assertIsInstance(c, C)
1854 self.assertIsInstance(c, Base1)
1855 self.assertIsInstance(c, Base2)
1856
1857 self.assertEqual((c.x, c.y), (1, 2))
1858
1859 def test_helper_make_dataclass_init_var(self):
1860 def post_init(self, y):
1861 self.x *= y
1862
1863 C = make_dataclass('C',
1864 [('x', int),
1865 ('y', InitVar[int]),
1866 ],
1867 namespace={'__post_init__': post_init},
1868 )
1869 c = C(2, 3)
1870 self.assertEqual(vars(c), {'x': 6})
1871 self.assertEqual(len(fields(c)), 1)
1872
1873 def test_helper_make_dataclass_class_var(self):
1874 C = make_dataclass('C',
1875 [('x', int),
1876 ('y', ClassVar[int], 10),
1877 ('z', ClassVar[int], field(default=20)),
1878 ])
1879 c = C(1)
1880 self.assertEqual(vars(c), {'x': 1})
1881 self.assertEqual(len(fields(c)), 1)
1882 self.assertEqual(C.y, 10)
1883 self.assertEqual(C.z, 20)
1884
Eric V. Smithd80b4432018-01-06 17:09:58 -05001885 def test_helper_make_dataclass_other_params(self):
1886 C = make_dataclass('C',
1887 [('x', int),
1888 ('y', ClassVar[int], 10),
1889 ('z', ClassVar[int], field(default=20)),
1890 ],
1891 init=False)
1892 # Make sure we have a repr, but no init.
1893 self.assertNotIn('__init__', vars(C))
1894 self.assertIn('__repr__', vars(C))
1895
1896 # Make sure random other params don't work.
1897 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1898 C = make_dataclass('C',
1899 [],
1900 xxinit=False)
1901
Eric V. Smithed7d4292018-01-06 16:14:03 -05001902 def test_helper_make_dataclass_no_types(self):
1903 C = make_dataclass('Point', ['x', 'y', 'z'])
1904 c = C(1, 2, 3)
1905 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1906 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1907 'y': 'typing.Any',
1908 'z': 'typing.Any'})
1909
1910 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1911 c = C(1, 2, 3)
1912 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1913 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1914 'y': int,
1915 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001916
Eric V. Smithea8fc522018-01-27 19:07:40 -05001917
class TestFieldNoAnnotation(unittest.TestCase):
    def _assert_unannotated_field_rejected(self, *bases):
        # Decorating a class whose body assigns field() to 'f' without
        # annotating 'f' in that same body must raise TypeError, whatever
        # annotations the bases carry.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            @dataclass
            class C(*bases):
                f = field()

    def test_field_without_annotation(self):
        self._assert_unannotated_field_rejected()

    def test_field_without_annotation_but_annotation_in_base(self):
        # The annotation lives on a dataclass base and must not be borrowed.
        @dataclass
        class B:
            f: int

        self._assert_unannotated_field_rejected(B)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same again, with a base that is not a dataclass.
        class B:
            f: int

        self._assert_unannotated_field_rejected(B)
1951
1952
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001953class TestDocString(unittest.TestCase):
1954 def assertDocStrEqual(self, a, b):
1955 # Because 3.6 and 3.7 differ in how inspect.signature work
1956 # (see bpo #32108), for the time being just compare them with
1957 # whitespace stripped.
1958 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1959
1960 def test_existing_docstring_not_overridden(self):
1961 @dataclass
1962 class C:
1963 """Lorem ipsum"""
1964 x: int
1965
1966 self.assertEqual(C.__doc__, "Lorem ipsum")
1967
1968 def test_docstring_no_fields(self):
1969 @dataclass
1970 class C:
1971 pass
1972
1973 self.assertDocStrEqual(C.__doc__, "C()")
1974
1975 def test_docstring_one_field(self):
1976 @dataclass
1977 class C:
1978 x: int
1979
1980 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1981
1982 def test_docstring_two_fields(self):
1983 @dataclass
1984 class C:
1985 x: int
1986 y: int
1987
1988 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1989
1990 def test_docstring_three_fields(self):
1991 @dataclass
1992 class C:
1993 x: int
1994 y: int
1995 z: str
1996
1997 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1998
1999 def test_docstring_one_field_with_default(self):
2000 @dataclass
2001 class C:
2002 x: int = 3
2003
2004 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2005
2006 def test_docstring_one_field_with_default_none(self):
2007 @dataclass
2008 class C:
2009 x: Union[int, type(None)] = None
2010
2011 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2012
2013 def test_docstring_list_field(self):
2014 @dataclass
2015 class C:
2016 x: List[int]
2017
2018 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2019
2020 def test_docstring_list_field_with_default_factory(self):
2021 @dataclass
2022 class C:
2023 x: List[int] = field(default_factory=list)
2024
2025 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2026
2027 def test_docstring_deque_field(self):
2028 @dataclass
2029 class C:
2030 x: deque
2031
2032 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2033
2034 def test_docstring_deque_field_with_default_factory(self):
2035 @dataclass
2036 class C:
2037 x: deque = field(default_factory=deque)
2038
2039 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2040
2041
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring a dataclass on top of a base with its own __init__ must
        # not raise. The generated __init__ takes over and does not call
        # B.__init__, so 'z' is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False nothing is generated, so B.__init__ runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The '@' matters here: a bare ``dataclass(init=False)`` statement
        # would just build a decorator and discard it, leaving C a plain
        # class and the test checking nothing about dataclasses.

        # No __init__ in the body: none is generated, and the field's
        # default is read from the class attribute.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # The body defines __init__: the user's method is used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2106
2107
2108class TestRepr(unittest.TestCase):
2109 def test_repr(self):
2110 @dataclass
2111 class B:
2112 x: int
2113
2114 @dataclass
2115 class C(B):
2116 y: int = 10
2117
2118 o = C(4)
2119 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
2120
2121 @dataclass
2122 class D(C):
2123 x: int = 20
2124 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
2125
2126 @dataclass
2127 class C:
2128 @dataclass
2129 class D:
2130 i: int
2131 @dataclass
2132 class E:
2133 pass
2134 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
2135 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
2136
2137 def test_no_repr(self):
2138 # Test a class with no __repr__ and repr=False.
2139 @dataclass(repr=False)
2140 class C:
2141 x: int
2142 self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
2143 repr(C(3)))
2144
2145 # Test a class with a __repr__ and repr=False.
2146 @dataclass(repr=False)
2147 class C:
2148 x: int
2149 def __repr__(self):
2150 return 'C-class'
2151 self.assertEqual(repr(C(3)), 'C-class')
2152
2153 def test_overwriting_repr(self):
2154 # If the class has __repr__, use it no matter the value of
2155 # repr=.
2156
2157 @dataclass
2158 class C:
2159 x: int
2160 def __repr__(self):
2161 return 'x'
2162 self.assertEqual(repr(C(0)), 'x')
2163
2164 @dataclass(repr=True)
2165 class C:
2166 x: int
2167 def __repr__(self):
2168 return 'x'
2169 self.assertEqual(repr(C(0)), 'x')
2170
2171 @dataclass(repr=False)
2172 class C:
2173 x: int
2174 def __repr__(self):
2175 return 'x'
2176 self.assertEqual(repr(C(0)), 'x')
2177
2178
class TestFrozen(unittest.TestCase):
    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__, so a
        # class body that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # Without frozen=True a user __setattr__ is kept and is what the
        # generated __init__ goes through.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        self.assertEqual(C(10).x, 20)
2204
2205
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no __eq__ in the body: comparison falls back to
        # object identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False but the body defines __eq__: that method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ defined in the body always wins, whatever eq= says.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorator, target in variants:
            @decorator
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2251
2252
class TestOrdering(unittest.TestCase):
    """Tests for order= and for user-defined comparison methods."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Neither the default decorator nor an explicit order=False adds
        # any ordering methods.
        for decorator in (dataclass, dataclass(order=False)):
            with self.subTest(decorator=decorator):
                @decorator
                class C:
                    x: int
                for name in ('__le__', '__lt__', '__ge__', '__gt__'):
                    self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ is kept and is actually called.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # Without this __lt__, '<' between two instances would raise
        # TypeError, so getting False shows the user's method ran.
        self.assertIs(C(0) < C(1), False)
        # Make sure the other methods aren't added.
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any comparison method the class
        # already defines, and points the user at total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self, other):
                    pass
2328
class TestHash(unittest.TestCase):
    """Tests for whether, and how, dataclass sets __hash__."""

    def test_unsafe_hash(self):
        # unsafe_hash=True generates a __hash__ over the field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a field-less dataclass with the given flags (and, if
            # with_hash, its own __hash__ returning 0), then check the
            # outcome named by `result`:
            #   'fn'        -- dataclass generated a __hash__
            #   ''          -- dataclass left __hash__ alone
            #   'none'      -- dataclass set __hash__ to None
            #   'exception' -- dataclass refused to create the class
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen, with_hash=with_hash):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated. With no
                    # fields it hashes like an empty tuple.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                    self.assertEqual(hash(C()), hash(()))

                elif result == '':
                    if with_hash:
                        # The class's own __hash__ must survive untouched.
                        self.assertEqual(hash(C()), 0)
                    else:
                        # __hash__ is not present in our class.
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception. The
                    # table only expects this when the class defines its
                    # own __hash__; check that explicitly, not with a
                    # bare assert that -O would strip.
                    self.assertTrue(with_hash,
                                    f'case {case}: exception expected '
                                    f'without a user __hash__')
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',   ''),
                (False, False, True,  '',   ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn', ''),
                (True,  False, False, 'fn', 'exception'),
                (True,  False, True,  'fn', 'exception'),
                (True,  True,  False, 'fn', 'exception'),
                (True,  True,  True,  'fn', 'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too. This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None. This is normal Python behavior, not
        # related to dataclasses. Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A field-less dataclass hashes like an empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A one-field dataclass hashes like a 1-tuple of its value.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument. This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq, base, expected in [
                (None,  None,  object, 'unhashable'),
                (None,  None,  Base,   'unhashable'),
                (None,  False, object, 'object'),
                (None,  False, Base,   'base'),
                (None,  True,  object, 'unhashable'),
                (None,  True,  Base,   'unhashable'),
                (False, None,  object, 'unhashable'),
                (False, None,  Base,   'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base,   'base'),
                (False, True,  object, 'unhashable'),
                (False, True,  Base,   'unhashable'),
                (True,  None,  object, 'tuple'),
                (True,  None,  Base,   'tuple'),
                (True,  False, object, 'object'),
                (True,  False, Base,   'base'),
                (True,  True,  object, 'tuple'),
                (True,  True,  Base,   'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash is derived from the instance's
                    # identity, so its value can't be predicted here.
                    # Instead, check that the class uses object's
                    # __hash__ unchanged.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2547
Eric V. Smithea8fc522018-01-27 19:07:40 -05002548
class TestFrozen(unittest.TestCase):
    """frozen=True: fields can't be reassigned, and frozen-ness must agree
    along a chain of dataclass bases."""

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        # Assignment is refused and leaves the value unchanged.
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        # Both the inherited field and the newly added one are read-only.
        for attr, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, attr, new_value)
        self.assertEqual((obj.i, obj.j), (0, 10))

    # The inheritance tests below each run twice: once with a plain
    # (non-dataclass) class between the base and the derived class, and
    # once deriving directly.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if not intermediate_class:
                    I = C
                else:
                    class I(C):
                        pass

                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if not intermediate_class:
                    I = C
                else:
                    class I(C):
                        pass

                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if not intermediate_class:
                    I = C
                else:
                    class I(C):
                        pass

                @dataclass(frozen=True)
                class D(I):
                    i: int

                # Plain (non-dataclass) bases don't stop D from being
                # frozen.
                with self.assertRaises(FrozenInstanceError):
                    D(10).i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953: a plain subclass of a frozen dataclass may add new
        # attributes, but the dataclass's own fields stay frozen.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        s.cached = True

        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, attr, 5)
        self.assertEqual((s.x, s.y, s.cached), (3, 10, True))
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002659
2660
class TestSlots(unittest.TestCase):
    """Dataclasses whose fields are also listed in __slots__."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # The slot puts a descriptor (types.MemberDescriptorType) on the
        # class. An earlier bug mistook it for the field's default value,
        # so check that x is still a required argument.
        with self.assertRaisesRegex(
                TypeError,
                r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # x can be set at construction and reassigned afterwards ...
        instance = C(10)
        self.assertEqual(instance.x, 10)
        instance.x = 5
        self.assertEqual(instance.x, 5)

        # ... but no other attribute can be created.
        with self.assertRaisesRegex(AttributeError,
                                    "'C' object has no attribute 'y'"):
            instance.y = 5

    def test_derived_added_field(self):
        # See bpo-33100: a subclass without __slots__ may redeclare the
        # slotted field and add another one.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__, so its instances can take on
        # new attributes.
        derived.z = 10
2701
2702
if __name__ == '__main__':
    # Run this module's test cases when executed as a script; '__main__'
    # is unittest.main's default module, spelled out here.
    unittest.main(module='__main__')