blob: 69ace36c2c58d647d417f8340de56f1b09d482ce [file] [log] [blame]
Miss Islington (bot)4ddc99d2018-03-21 14:44:23 -07001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
11from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
12from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
class CustomError(Exception):
    """Arbitrary custom exception that tests raise deliberately and catch."""
17
18class TestCase(unittest.TestCase):
19 def test_no_fields(self):
20 @dataclass
21 class C:
22 pass
23
24 o = C()
25 self.assertEqual(len(fields(C)), 0)
26
27 def test_one_field_no_default(self):
28 @dataclass
29 class C:
30 x: int
31
32 o = C(42)
33 self.assertEqual(o.x, 42)
34
35 def test_named_init_params(self):
36 @dataclass
37 class C:
38 x: int
39
40 o = C(x=32)
41 self.assertEqual(o.x, 32)
42
43 def test_two_fields_one_default(self):
44 @dataclass
45 class C:
46 x: int
47 y: int = 0
48
49 o = C(3)
50 self.assertEqual((o.x, o.y), (3, 0))
51
52 # Non-defaults following defaults.
53 with self.assertRaisesRegex(TypeError,
54 "non-default argument 'y' follows "
55 "default argument"):
56 @dataclass
57 class C:
58 x: int = 0
59 y: int
60
61 # A derived class adds a non-default field after a default one.
62 with self.assertRaisesRegex(TypeError,
63 "non-default argument 'y' follows "
64 "default argument"):
65 @dataclass
66 class B:
67 x: int = 0
68
69 @dataclass
70 class C(B):
71 y: int
72
73 # Override a base class field and add a default to
74 # a field which didn't use to have a default.
75 with self.assertRaisesRegex(TypeError,
76 "non-default argument 'y' follows "
77 "default argument"):
78 @dataclass
79 class B:
80 x: int
81 y: int
82
83 @dataclass
84 class C(B):
85 x: int = 0
86
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080087 def test_overwrite_hash(self):
88 # Test that declaring this class isn't an error. It should
89 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050090 @dataclass(frozen=True)
91 class C:
92 x: int
93 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080094 return 301
95 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -050096
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080097 # Test that declaring this class isn't an error. It should
98 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050099 @dataclass(frozen=True)
100 class C:
101 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800102 def __eq__(self, other):
103 return False
104 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500105
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800106 # But this one should generate an exception, because with
107 # unsafe_hash=True, it's an error to have a __hash__ defined.
108 with self.assertRaisesRegex(TypeError,
109 'Cannot overwrite attribute __hash__'):
110 @dataclass(unsafe_hash=True)
111 class C:
112 def __hash__(self):
113 pass
114
115 # Creating this class should not generate an exception,
116 # because even though __hash__ exists before @dataclass is
117 # called, (due to __eq__ being defined), since it's None
118 # that's okay.
119 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500120 class C:
121 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800122 def __eq__(self):
123 pass
124 # The generated hash function works as we'd expect.
125 self.assertEqual(hash(C(10)), hash((10,)))
126
127 # Creating this class should generate an exception, because
128 # __hash__ exists and is not None, which it would be if it had
129 # been auto-generated do due __eq__ being defined.
130 with self.assertRaisesRegex(TypeError,
131 'Cannot overwrite attribute __hash__'):
132 @dataclass(unsafe_hash=True)
133 class C:
134 x: int
135 def __eq__(self):
136 pass
137 def __hash__(self):
138 pass
139
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500140
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500141 def test_overwrite_fields_in_derived_class(self):
142 # Note that x from C1 replaces x in Base, but the order remains
143 # the same as defined in Base.
144 @dataclass
145 class Base:
146 x: Any = 15.0
147 y: int = 0
148
149 @dataclass
150 class C1(Base):
151 z: int = 10
152 x: int = 15
153
154 o = Base()
155 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
156
157 o = C1()
158 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
159
160 o = C1(x=5)
161 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
162
163 def test_field_named_self(self):
164 @dataclass
165 class C:
166 self: str
167 c=C('foo')
168 self.assertEqual(c.self, 'foo')
169
170 # Make sure the first parameter is not named 'self'.
171 sig = inspect.signature(C.__init__)
172 first = next(iter(sig.parameters))
173 self.assertNotEqual('self', first)
174
175 # But we do use 'self' if no field named self.
176 @dataclass
177 class C:
178 selfx: str
179
180 # Make sure the first parameter is named 'self'.
181 sig = inspect.signature(C.__init__)
182 first = next(iter(sig.parameters))
183 self.assertEqual('self', first)
184
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500185 def test_0_field_compare(self):
186 # Ensure that order=False is the default.
187 @dataclass
188 class C0:
189 pass
190
191 @dataclass(order=False)
192 class C1:
193 pass
194
195 for cls in [C0, C1]:
196 with self.subTest(cls=cls):
197 self.assertEqual(cls(), cls())
198 for idx, fn in enumerate([lambda a, b: a < b,
199 lambda a, b: a <= b,
200 lambda a, b: a > b,
201 lambda a, b: a >= b]):
202 with self.subTest(idx=idx):
203 with self.assertRaisesRegex(TypeError,
204 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
205 fn(cls(), cls())
206
207 @dataclass(order=True)
208 class C:
209 pass
210 self.assertLessEqual(C(), C())
211 self.assertGreaterEqual(C(), C())
212
213 def test_1_field_compare(self):
214 # Ensure that order=False is the default.
215 @dataclass
216 class C0:
217 x: int
218
219 @dataclass(order=False)
220 class C1:
221 x: int
222
223 for cls in [C0, C1]:
224 with self.subTest(cls=cls):
225 self.assertEqual(cls(1), cls(1))
226 self.assertNotEqual(cls(0), cls(1))
227 for idx, fn in enumerate([lambda a, b: a < b,
228 lambda a, b: a <= b,
229 lambda a, b: a > b,
230 lambda a, b: a >= b]):
231 with self.subTest(idx=idx):
232 with self.assertRaisesRegex(TypeError,
233 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
234 fn(cls(0), cls(0))
235
236 @dataclass(order=True)
237 class C:
238 x: int
239 self.assertLess(C(0), C(1))
240 self.assertLessEqual(C(0), C(1))
241 self.assertLessEqual(C(1), C(1))
242 self.assertGreater(C(1), C(0))
243 self.assertGreaterEqual(C(1), C(0))
244 self.assertGreaterEqual(C(1), C(1))
245
246 def test_simple_compare(self):
247 # Ensure that order=False is the default.
248 @dataclass
249 class C0:
250 x: int
251 y: int
252
253 @dataclass(order=False)
254 class C1:
255 x: int
256 y: int
257
258 for cls in [C0, C1]:
259 with self.subTest(cls=cls):
260 self.assertEqual(cls(0, 0), cls(0, 0))
261 self.assertEqual(cls(1, 2), cls(1, 2))
262 self.assertNotEqual(cls(1, 0), cls(0, 0))
263 self.assertNotEqual(cls(1, 0), cls(1, 1))
264 for idx, fn in enumerate([lambda a, b: a < b,
265 lambda a, b: a <= b,
266 lambda a, b: a > b,
267 lambda a, b: a >= b]):
268 with self.subTest(idx=idx):
269 with self.assertRaisesRegex(TypeError,
270 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
271 fn(cls(0, 0), cls(0, 0))
272
273 @dataclass(order=True)
274 class C:
275 x: int
276 y: int
277
278 for idx, fn in enumerate([lambda a, b: a == b,
279 lambda a, b: a <= b,
280 lambda a, b: a >= b]):
281 with self.subTest(idx=idx):
282 self.assertTrue(fn(C(0, 0), C(0, 0)))
283
284 for idx, fn in enumerate([lambda a, b: a < b,
285 lambda a, b: a <= b,
286 lambda a, b: a != b]):
287 with self.subTest(idx=idx):
288 self.assertTrue(fn(C(0, 0), C(0, 1)))
289 self.assertTrue(fn(C(0, 1), C(1, 0)))
290 self.assertTrue(fn(C(1, 0), C(1, 1)))
291
292 for idx, fn in enumerate([lambda a, b: a > b,
293 lambda a, b: a >= b,
294 lambda a, b: a != b]):
295 with self.subTest(idx=idx):
296 self.assertTrue(fn(C(0, 1), C(0, 0)))
297 self.assertTrue(fn(C(1, 0), C(0, 1)))
298 self.assertTrue(fn(C(1, 1), C(1, 0)))
299
300 def test_compare_subclasses(self):
301 # Comparisons fail for subclasses, even if no fields
302 # are added.
303 @dataclass
304 class B:
305 i: int
306
307 @dataclass
308 class C(B):
309 pass
310
311 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
312 (lambda a, b: a != b, True)]):
313 with self.subTest(idx=idx):
314 self.assertEqual(fn(B(0), C(0)), expected)
315
316 for idx, fn in enumerate([lambda a, b: a < b,
317 lambda a, b: a <= b,
318 lambda a, b: a > b,
319 lambda a, b: a >= b]):
320 with self.subTest(idx=idx):
321 with self.assertRaisesRegex(TypeError,
322 "not supported between instances of 'B' and 'C'"):
323 fn(B(0), C(0))
324
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500325 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500326 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500327 for (eq, order, result ) in [
328 (False, False, 'neither'),
329 (False, True, 'exception'),
330 (True, False, 'eq_only'),
331 (True, True, 'both'),
332 ]:
333 with self.subTest(eq=eq, order=order):
334 if result == 'exception':
335 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
336 @dataclass(eq=eq, order=order)
337 class C:
338 pass
339 else:
340 @dataclass(eq=eq, order=order)
341 class C:
342 pass
343
344 if result == 'neither':
345 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500346 self.assertNotIn('__lt__', C.__dict__)
347 self.assertNotIn('__le__', C.__dict__)
348 self.assertNotIn('__gt__', C.__dict__)
349 self.assertNotIn('__ge__', C.__dict__)
350 elif result == 'both':
351 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500352 self.assertIn('__lt__', C.__dict__)
353 self.assertIn('__le__', C.__dict__)
354 self.assertIn('__gt__', C.__dict__)
355 self.assertIn('__ge__', C.__dict__)
356 elif result == 'eq_only':
357 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500358 self.assertNotIn('__lt__', C.__dict__)
359 self.assertNotIn('__le__', C.__dict__)
360 self.assertNotIn('__gt__', C.__dict__)
361 self.assertNotIn('__ge__', C.__dict__)
362 else:
363 assert False, f'unknown result {result!r}'
364
365 def test_field_no_default(self):
366 @dataclass
367 class C:
368 x: int = field()
369
370 self.assertEqual(C(5).x, 5)
371
372 with self.assertRaisesRegex(TypeError,
373 r"__init__\(\) missing 1 required "
374 "positional argument: 'x'"):
375 C()
376
377 def test_field_default(self):
378 default = object()
379 @dataclass
380 class C:
381 x: object = field(default=default)
382
383 self.assertIs(C.x, default)
384 c = C(10)
385 self.assertEqual(c.x, 10)
386
387 # If we delete the instance attribute, we should then see the
388 # class attribute.
389 del c.x
390 self.assertIs(c.x, default)
391
392 self.assertIs(C().x, default)
393
394 def test_not_in_repr(self):
395 @dataclass
396 class C:
397 x: int = field(repr=False)
398 with self.assertRaises(TypeError):
399 C()
400 c = C(10)
401 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
402
403 @dataclass
404 class C:
405 x: int = field(repr=False)
406 y: int
407 c = C(10, 20)
408 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
409
410 def test_not_in_compare(self):
411 @dataclass
412 class C:
413 x: int = 0
414 y: int = field(compare=False, default=4)
415
416 self.assertEqual(C(), C(0, 20))
417 self.assertEqual(C(1, 10), C(1, 20))
418 self.assertNotEqual(C(3), C(4, 10))
419 self.assertNotEqual(C(3, 10), C(4, 10))
420
421 def test_hash_field_rules(self):
422 # Test all 6 cases of:
423 # hash=True/False/None
424 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800425 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500426 (True, False, 'field' ),
427 (True, True, 'field' ),
428 (False, False, 'absent'),
429 (False, True, 'absent'),
430 (None, False, 'absent'),
431 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800432 ]:
433 with self.subTest(hash=hash_, compare=compare):
434 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500435 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800436 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500437
438 if result == 'field':
439 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800440 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500441 elif result == 'absent':
442 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800443 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500444 else:
445 assert False, f'unknown result {result!r}'
446
447 def test_init_false_no_default(self):
448 # If init=False and no default value, then the field won't be
449 # present in the instance.
450 @dataclass
451 class C:
452 x: int = field(init=False)
453
454 self.assertNotIn('x', C().__dict__)
455
456 @dataclass
457 class C:
458 x: int
459 y: int = 0
460 z: int = field(init=False)
461 t: int = 10
462
463 self.assertNotIn('z', C(0).__dict__)
464 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
465
466 def test_class_marker(self):
467 @dataclass
468 class C:
469 x: int
470 y: str = field(init=False, default=None)
471 z: str = field(repr=False)
472
473 the_fields = fields(C)
474 # the_fields is a tuple of 3 items, each value
475 # is in __annotations__.
476 self.assertIsInstance(the_fields, tuple)
477 for f in the_fields:
478 self.assertIs(type(f), Field)
479 self.assertIn(f.name, C.__annotations__)
480
481 self.assertEqual(len(the_fields), 3)
482
483 self.assertEqual(the_fields[0].name, 'x')
484 self.assertEqual(the_fields[0].type, int)
485 self.assertFalse(hasattr(C, 'x'))
486 self.assertTrue (the_fields[0].init)
487 self.assertTrue (the_fields[0].repr)
488 self.assertEqual(the_fields[1].name, 'y')
489 self.assertEqual(the_fields[1].type, str)
490 self.assertIsNone(getattr(C, 'y'))
491 self.assertFalse(the_fields[1].init)
492 self.assertTrue (the_fields[1].repr)
493 self.assertEqual(the_fields[2].name, 'z')
494 self.assertEqual(the_fields[2].type, str)
495 self.assertFalse(hasattr(C, 'z'))
496 self.assertTrue (the_fields[2].init)
497 self.assertFalse(the_fields[2].repr)
498
499 def test_field_order(self):
500 @dataclass
501 class B:
502 a: str = 'B:a'
503 b: str = 'B:b'
504 c: str = 'B:c'
505
506 @dataclass
507 class C(B):
508 b: str = 'C:b'
509
510 self.assertEqual([(f.name, f.default) for f in fields(C)],
511 [('a', 'B:a'),
512 ('b', 'C:b'),
513 ('c', 'B:c')])
514
515 @dataclass
516 class D(B):
517 c: str = 'D:c'
518
519 self.assertEqual([(f.name, f.default) for f in fields(D)],
520 [('a', 'B:a'),
521 ('b', 'B:b'),
522 ('c', 'D:c')])
523
524 @dataclass
525 class E(D):
526 a: str = 'E:a'
527 d: str = 'E:d'
528
529 self.assertEqual([(f.name, f.default) for f in fields(E)],
530 [('a', 'E:a'),
531 ('b', 'B:b'),
532 ('c', 'D:c'),
533 ('d', 'E:d')])
534
535 def test_class_attrs(self):
536 # We only have a class attribute if a default value is
537 # specified, either directly or via a field with a default.
538 default = object()
539 @dataclass
540 class C:
541 x: int
542 y: int = field(repr=False)
543 z: object = default
544 t: int = field(default=100)
545
546 self.assertFalse(hasattr(C, 'x'))
547 self.assertFalse(hasattr(C, 'y'))
548 self.assertIs (C.z, default)
549 self.assertEqual(C.t, 100)
550
551 def test_disallowed_mutable_defaults(self):
552 # For the known types, don't allow mutable default values.
553 for typ, empty, non_empty in [(list, [], [1]),
554 (dict, {}, {0:1}),
555 (set, set(), set([1])),
556 ]:
557 with self.subTest(typ=typ):
558 # Can't use a zero-length value.
559 with self.assertRaisesRegex(ValueError,
560 f'mutable default {typ} for field '
561 'x is not allowed'):
562 @dataclass
563 class Point:
564 x: typ = empty
565
566
567 # Nor a non-zero-length value
568 with self.assertRaisesRegex(ValueError,
569 f'mutable default {typ} for field '
570 'y is not allowed'):
571 @dataclass
572 class Point:
573 y: typ = non_empty
574
575 # Check subtypes also fail.
576 class Subclass(typ): pass
577
578 with self.assertRaisesRegex(ValueError,
579 f"mutable default .*Subclass'>"
580 ' for field z is not allowed'
581 ):
582 @dataclass
583 class Point:
584 z: typ = Subclass()
585
586 # Because this is a ClassVar, it can be mutable.
587 @dataclass
588 class C:
589 z: ClassVar[typ] = typ()
590
591 # Because this is a ClassVar, it can be mutable.
592 @dataclass
593 class C:
594 x: ClassVar[typ] = Subclass()
595
596
597 def test_deliberately_mutable_defaults(self):
598 # If a mutable default isn't in the known list of
599 # (list, dict, set), then it's okay.
600 class Mutable:
601 def __init__(self):
602 self.l = []
603
604 @dataclass
605 class C:
606 x: Mutable
607
608 # These 2 instances will share this value of x.
609 lst = Mutable()
610 o1 = C(lst)
611 o2 = C(lst)
612 self.assertEqual(o1, o2)
613 o1.x.l.extend([1, 2])
614 self.assertEqual(o1, o2)
615 self.assertEqual(o1.x.l, [1, 2])
616 self.assertIs(o1.x, o2.x)
617
618 def test_no_options(self):
619 # call with dataclass()
620 @dataclass()
621 class C:
622 x: int
623
624 self.assertEqual(C(42).x, 42)
625
626 def test_not_tuple(self):
627 # Make sure we can't be compared to a tuple.
628 @dataclass
629 class Point:
630 x: int
631 y: int
632 self.assertNotEqual(Point(1, 2), (1, 2))
633
634 # And that we can't compare to another unrelated dataclass
635 @dataclass
636 class C:
637 x: int
638 y: int
639 self.assertNotEqual(Point(1, 3), C(1, 3))
640
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500641 def test_not_tuple(self):
642 # Test that some of the problems with namedtuple don't happen
643 # here.
644 @dataclass
645 class Point3D:
646 x: int
647 y: int
648 z: int
649
650 @dataclass
651 class Date:
652 year: int
653 month: int
654 day: int
655
656 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
657 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
658
659 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200660 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500661 x, y, z = Point3D(4, 5, 6)
662
Eric V. Smith7c99e932018-01-28 19:18:55 -0500663 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500664 # equal.
665 @dataclass
666 class Point3Dv1:
667 x: int = 0
668 y: int = 0
669 z: int = 0
670 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
671
672 def test_function_annotations(self):
673 # Some dummy class and instance to use as a default.
674 class F:
675 pass
676 f = F()
677
678 def validate_class(cls):
679 # First, check __annotations__, even though they're not
680 # function annotations.
681 self.assertEqual(cls.__annotations__['i'], int)
682 self.assertEqual(cls.__annotations__['j'], str)
683 self.assertEqual(cls.__annotations__['k'], F)
684 self.assertEqual(cls.__annotations__['l'], float)
685 self.assertEqual(cls.__annotations__['z'], complex)
686
687 # Verify __init__.
688
689 signature = inspect.signature(cls.__init__)
690 # Check the return type, should be None
691 self.assertIs(signature.return_annotation, None)
692
693 # Check each parameter.
694 params = iter(signature.parameters.values())
695 param = next(params)
696 # This is testing an internal name, and probably shouldn't be tested.
697 self.assertEqual(param.name, 'self')
698 param = next(params)
699 self.assertEqual(param.name, 'i')
700 self.assertIs (param.annotation, int)
701 self.assertEqual(param.default, inspect.Parameter.empty)
702 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
703 param = next(params)
704 self.assertEqual(param.name, 'j')
705 self.assertIs (param.annotation, str)
706 self.assertEqual(param.default, inspect.Parameter.empty)
707 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
708 param = next(params)
709 self.assertEqual(param.name, 'k')
710 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500711 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500712 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
713 param = next(params)
714 self.assertEqual(param.name, 'l')
715 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500716 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500717 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
718 self.assertRaises(StopIteration, next, params)
719
720
721 @dataclass
722 class C:
723 i: int
724 j: str
725 k: F = f
726 l: float=field(default=None)
727 z: complex=field(default=3+4j, init=False)
728
729 validate_class(C)
730
731 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800732 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500733 class C:
734 i: int
735 j: str
736 k: F = f
737 l: float=field(default=None)
738 z: complex=field(default=3+4j, init=False)
739
740 validate_class(C)
741
Eric V. Smith03220fd2017-12-29 13:59:58 -0500742 def test_missing_default(self):
743 # Test that MISSING works the same as a default not being
744 # specified.
745 @dataclass
746 class C:
747 x: int=field(default=MISSING)
748 with self.assertRaisesRegex(TypeError,
749 r'__init__\(\) missing 1 required '
750 'positional argument'):
751 C()
752 self.assertNotIn('x', C.__dict__)
753
754 @dataclass
755 class D:
756 x: int
757 with self.assertRaisesRegex(TypeError,
758 r'__init__\(\) missing 1 required '
759 'positional argument'):
760 D()
761 self.assertNotIn('x', D.__dict__)
762
763 def test_missing_default_factory(self):
764 # Test that MISSING works the same as a default factory not
765 # being specified (which is really the same as a default not
766 # being specified, too).
767 @dataclass
768 class C:
769 x: int=field(default_factory=MISSING)
770 with self.assertRaisesRegex(TypeError,
771 r'__init__\(\) missing 1 required '
772 'positional argument'):
773 C()
774 self.assertNotIn('x', C.__dict__)
775
776 @dataclass
777 class D:
778 x: int=field(default=MISSING, default_factory=MISSING)
779 with self.assertRaisesRegex(TypeError,
780 r'__init__\(\) missing 1 required '
781 'positional argument'):
782 D()
783 self.assertNotIn('x', D.__dict__)
784
785 def test_missing_repr(self):
786 self.assertIn('MISSING_TYPE object', repr(MISSING))
787
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500788 def test_dont_include_other_annotations(self):
789 @dataclass
790 class C:
791 i: int
792 def foo(self) -> int:
793 return 4
794 @property
795 def bar(self) -> int:
796 return 5
797 self.assertEqual(list(C.__annotations__), ['i'])
798 self.assertEqual(C(10).foo(), 4)
799 self.assertEqual(C(10).bar, 5)
800
def test_post_init(self):
    # The generated __init__ calls __post_init__.
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    # __post_init__ runs after the fields are assigned, so it can read them.
    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # Here __post_init__ runs but does not raise; this only checks that
    # self is passed correctly.
    C(5)

    # With init=False there is no generated __init__, so __post_init__
    # is not called and creating an instance does not raise.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    C()

    # __post_init__ may modify field values.
    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    doubled = C(2)
    self.assertEqual(doubled.x, 4)

    # On a frozen dataclass, __post_init__ cannot assign attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
847
def test_post_init_super(self):
    # B is a plain (non-dataclass) base whose __post_init__ always raises.
    class B:
        def __post_init__(self):
            raise CustomError()

    # An override hides B's version; it is not called automatically.
    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(C().x, 5)

    # Delegating explicitly via super() reaches B's version, which raises.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # Without an override, the inherited __post_init__ is still called.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
878
879 def test_post_init_staticmethod(self):
880 flag = False
881 @dataclass
882 class C:
883 x: int
884 y: int
885 @staticmethod
886 def __post_init__():
887 nonlocal flag
888 flag = True
889
890 self.assertFalse(flag)
891 c = C(3, 4)
892 self.assertEqual((c.x, c.y), (3, 4))
893 self.assertTrue(flag)
894
895 def test_post_init_classmethod(self):
896 @dataclass
897 class C:
898 flag = False
899 x: int
900 y: int
901 @classmethod
902 def __post_init__(cls):
903 cls.flag = True
904
905 self.assertFalse(C.flag)
906 c = C(3, 4)
907 self.assertEqual((c.x, c.y), (3, 4))
908 self.assertTrue(C.flag)
909
910 def test_class_var(self):
911 # Make sure ClassVars are ignored in __init__, __repr__, etc.
912 @dataclass
913 class C:
914 x: int
915 y: int = 10
916 z: ClassVar[int] = 1000
917 w: ClassVar[int] = 2000
918 t: ClassVar[int] = 3000
919
920 c = C(5)
921 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
922 self.assertEqual(len(fields(C)), 2) # We have 2 fields
923 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
924 self.assertEqual(c.z, 1000)
925 self.assertEqual(c.w, 2000)
926 self.assertEqual(c.t, 3000)
927 C.z += 1
928 self.assertEqual(c.z, 1001)
929 c = C(20)
930 self.assertEqual((c.x, c.y), (20, 10))
931 self.assertEqual(c.z, 1001)
932 self.assertEqual(c.w, 2000)
933 self.assertEqual(c.t, 3000)
934
935 def test_class_var_no_default(self):
936 # If a ClassVar has no default value, it should not be set on the class.
937 @dataclass
938 class C:
939 x: ClassVar[int]
940
941 self.assertNotIn('x', C.__dict__)
942
943 def test_class_var_default_factory(self):
944 # It makes no sense for a ClassVar to have a default factory. When
945 # would it be called? Call it yourself, since it's class-wide.
946 with self.assertRaisesRegex(TypeError,
947 'cannot have a default factory'):
948 @dataclass
949 class C:
950 x: ClassVar[int] = field(default_factory=int)
951
952 self.assertNotIn('x', C.__dict__)
953
954 def test_class_var_with_default(self):
955 # If a ClassVar has a default value, it should be set on the class.
956 @dataclass
957 class C:
958 x: ClassVar[int] = 10
959 self.assertEqual(C.x, 10)
960
961 @dataclass
962 class C:
963 x: ClassVar[int] = field(default=10)
964 self.assertEqual(C.x, 10)
965
966 def test_class_var_frozen(self):
967 # Make sure ClassVars work even if we're frozen.
968 @dataclass(frozen=True)
969 class C:
970 x: int
971 y: int = 10
972 z: ClassVar[int] = 1000
973 w: ClassVar[int] = 2000
974 t: ClassVar[int] = 3000
975
976 c = C(5)
977 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
978 self.assertEqual(len(fields(C)), 2) # We have 2 fields
979 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
980 self.assertEqual(c.z, 1000)
981 self.assertEqual(c.w, 2000)
982 self.assertEqual(c.t, 3000)
983 # We can still modify the ClassVar, it's only instances that are
984 # frozen.
985 C.z += 1
986 self.assertEqual(c.z, 1001)
987 c = C(20)
988 self.assertEqual((c.x, c.y), (20, 10))
989 self.assertEqual(c.z, 1001)
990 self.assertEqual(c.w, 2000)
991 self.assertEqual(c.t, 3000)
992
993 def test_init_var_no_default(self):
994 # If an InitVar has no default value, it should not be set on the class.
995 @dataclass
996 class C:
997 x: InitVar[int]
998
999 self.assertNotIn('x', C.__dict__)
1000
1001 def test_init_var_default_factory(self):
1002 # It makes no sense for an InitVar to have a default factory. When
1003 # would it be called? Call it yourself, since it's class-wide.
1004 with self.assertRaisesRegex(TypeError,
1005 'cannot have a default factory'):
1006 @dataclass
1007 class C:
1008 x: InitVar[int] = field(default_factory=int)
1009
1010 self.assertNotIn('x', C.__dict__)
1011
1012 def test_init_var_with_default(self):
1013 # If an InitVar has a default value, it should be set on the class.
1014 @dataclass
1015 class C:
1016 x: InitVar[int] = 10
1017 self.assertEqual(C.x, 10)
1018
1019 @dataclass
1020 class C:
1021 x: InitVar[int] = field(default=10)
1022 self.assertEqual(C.x, 10)
1023
1024 def test_init_var(self):
1025 @dataclass
1026 class C:
1027 x: int = None
1028 init_param: InitVar[int] = None
1029
1030 def __post_init__(self, init_param):
1031 if self.x is None:
1032 self.x = init_param*2
1033
1034 c = C(init_param=10)
1035 self.assertEqual(c.x, 20)
1036
1037 def test_init_var_inheritance(self):
1038 # Note that this deliberately tests that a dataclass need not
1039 # have a __post_init__ function if it has an InitVar field.
1040 # It could just be used in a derived class, as shown here.
1041 @dataclass
1042 class Base:
1043 x: int
1044 init_base: InitVar[int]
1045
1046 # We can instantiate by passing the InitVar, even though
1047 # it's not used.
1048 b = Base(0, 10)
1049 self.assertEqual(vars(b), {'x': 0})
1050
1051 @dataclass
1052 class C(Base):
1053 y: int
1054 init_derived: InitVar[int]
1055
1056 def __post_init__(self, init_base, init_derived):
1057 self.x = self.x + init_base
1058 self.y = self.y + init_derived
1059
1060 c = C(10, 11, 50, 51)
1061 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1062
1063 def test_default_factory(self):
1064 # Test a factory that returns a new list.
1065 @dataclass
1066 class C:
1067 x: int
1068 y: list = field(default_factory=list)
1069
1070 c0 = C(3)
1071 c1 = C(3)
1072 self.assertEqual(c0.x, 3)
1073 self.assertEqual(c0.y, [])
1074 self.assertEqual(c0, c1)
1075 self.assertIsNot(c0.y, c1.y)
1076 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1077
1078 # Test a factory that returns a shared list.
1079 l = []
1080 @dataclass
1081 class C:
1082 x: int
1083 y: list = field(default_factory=lambda: l)
1084
1085 c0 = C(3)
1086 c1 = C(3)
1087 self.assertEqual(c0.x, 3)
1088 self.assertEqual(c0.y, [])
1089 self.assertEqual(c0, c1)
1090 self.assertIs(c0.y, c1.y)
1091 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1092
1093 # Test various other field flags.
1094 # repr
1095 @dataclass
1096 class C:
1097 x: list = field(default_factory=list, repr=False)
1098 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1099 self.assertEqual(C().x, [])
1100
1101 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001102 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001103 class C:
1104 x: list = field(default_factory=list, hash=False)
1105 self.assertEqual(astuple(C()), ([],))
1106 self.assertEqual(hash(C()), hash(()))
1107
1108 # init (see also test_default_factory_with_no_init)
1109 @dataclass
1110 class C:
1111 x: list = field(default_factory=list, init=False)
1112 self.assertEqual(astuple(C()), ([],))
1113
1114 # compare
1115 @dataclass
1116 class C:
1117 x: list = field(default_factory=list, compare=False)
1118 self.assertEqual(C(), C([1]))
1119
1120 def test_default_factory_with_no_init(self):
1121 # We need a factory with a side effect.
1122 factory = Mock()
1123
1124 @dataclass
1125 class C:
1126 x: list = field(default_factory=factory, init=False)
1127
1128 # Make sure the default factory is called for each new instance.
1129 C().x
1130 self.assertEqual(factory.call_count, 1)
1131 C().x
1132 self.assertEqual(factory.call_count, 2)
1133
1134 def test_default_factory_not_called_if_value_given(self):
1135 # We need a factory that we can test if it's been called.
1136 factory = Mock()
1137
1138 @dataclass
1139 class C:
1140 x: int = field(default_factory=factory)
1141
1142 # Make sure that if a field has a default factory function,
1143 # it's not called if a value is specified.
1144 C().x
1145 self.assertEqual(factory.call_count, 1)
1146 self.assertEqual(C(10).x, 10)
1147 self.assertEqual(factory.call_count, 1)
1148 C().x
1149 self.assertEqual(factory.call_count, 2)
1150
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001151 def test_default_factory_derived(self):
1152 # See bpo-32896.
1153 @dataclass
1154 class Foo:
1155 x: dict = field(default_factory=dict)
1156
1157 @dataclass
1158 class Bar(Foo):
1159 y: int = 1
1160
1161 self.assertEqual(Foo().x, {})
1162 self.assertEqual(Bar().x, {})
1163 self.assertEqual(Bar().y, 1)
1164
1165 @dataclass
1166 class Baz(Foo):
1167 pass
1168 self.assertEqual(Baz().x, {})
1169
1170 def test_intermediate_non_dataclass(self):
1171 # Test that an intermediate class that defines
1172 # annotations does not define fields.
1173
1174 @dataclass
1175 class A:
1176 x: int
1177
1178 class B(A):
1179 y: int
1180
1181 @dataclass
1182 class C(B):
1183 z: int
1184
1185 c = C(1, 3)
1186 self.assertEqual((c.x, c.z), (1, 3))
1187
1188 # .y was not initialized.
1189 with self.assertRaisesRegex(AttributeError,
1190 'object has no attribute'):
1191 c.y
1192
1193 # And if we again derive a non-dataclass, no fields are added.
1194 class D(C):
1195 t: int
1196 d = D(4, 5)
1197 self.assertEqual((d.x, d.z), (4, 5))
1198
1199
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001200 def x_test_classvar_default_factory(self):
1201 # XXX: it's an error for a ClassVar to have a factory function
1202 @dataclass
1203 class C:
1204 x: ClassVar[int] = field(default_factory=int)
1205
1206 self.assertIs(C().x, int)
1207
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001208 def test_is_dataclass(self):
1209 class NotDataClass:
1210 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001211
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001212 self.assertFalse(is_dataclass(0))
1213 self.assertFalse(is_dataclass(int))
1214 self.assertFalse(is_dataclass(NotDataClass))
1215 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001216
1217 @dataclass
1218 class C:
1219 x: int
1220
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001221 @dataclass
1222 class D:
1223 d: C
1224 e: int
1225
1226 c = C(10)
1227 d = D(c, 4)
1228
1229 self.assertTrue(is_dataclass(C))
1230 self.assertTrue(is_dataclass(c))
1231 self.assertFalse(is_dataclass(c.x))
1232 self.assertTrue(is_dataclass(d.d))
1233 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001234
1235 def test_helper_fields_with_class_instance(self):
1236 # Check that we can call fields() on either a class or instance,
1237 # and get back the same thing.
1238 @dataclass
1239 class C:
1240 x: int
1241 y: float
1242
1243 self.assertEqual(fields(C), fields(C(0, 0.0)))
1244
1245 def test_helper_fields_exception(self):
1246 # Check that TypeError is raised if not passed a dataclass or
1247 # instance.
1248 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1249 fields(0)
1250
1251 class C: pass
1252 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1253 fields(C)
1254 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1255 fields(C())
1256
1257 def test_helper_asdict(self):
1258 # Basic tests for asdict(), it should return a new dictionary
1259 @dataclass
1260 class C:
1261 x: int
1262 y: int
1263 c = C(1, 2)
1264
1265 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1266 self.assertEqual(asdict(c), asdict(c))
1267 self.assertIsNot(asdict(c), asdict(c))
1268 c.x = 42
1269 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1270 self.assertIs(type(asdict(c)), dict)
1271
1272 def test_helper_asdict_raises_on_classes(self):
1273 # asdict() should raise on a class object
1274 @dataclass
1275 class C:
1276 x: int
1277 y: int
1278 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1279 asdict(C)
1280 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1281 asdict(int)
1282
1283 def test_helper_asdict_copy_values(self):
1284 @dataclass
1285 class C:
1286 x: int
1287 y: List[int] = field(default_factory=list)
1288 initial = []
1289 c = C(1, initial)
1290 d = asdict(c)
1291 self.assertEqual(d['y'], initial)
1292 self.assertIsNot(d['y'], initial)
1293 c = C(1)
1294 d = asdict(c)
1295 d['y'].append(1)
1296 self.assertEqual(c.y, [])
1297
1298 def test_helper_asdict_nested(self):
1299 @dataclass
1300 class UserId:
1301 token: int
1302 group: int
1303 @dataclass
1304 class User:
1305 name: str
1306 id: UserId
1307 u = User('Joe', UserId(123, 1))
1308 d = asdict(u)
1309 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1310 self.assertIsNot(asdict(u), asdict(u))
1311 u.id.group = 2
1312 self.assertEqual(asdict(u), {'name': 'Joe',
1313 'id': {'token': 123, 'group': 2}})
1314
1315 def test_helper_asdict_builtin_containers(self):
1316 @dataclass
1317 class User:
1318 name: str
1319 id: int
1320 @dataclass
1321 class GroupList:
1322 id: int
1323 users: List[User]
1324 @dataclass
1325 class GroupTuple:
1326 id: int
1327 users: Tuple[User, ...]
1328 @dataclass
1329 class GroupDict:
1330 id: int
1331 users: Dict[str, User]
1332 a = User('Alice', 1)
1333 b = User('Bob', 2)
1334 gl = GroupList(0, [a, b])
1335 gt = GroupTuple(0, (a, b))
1336 gd = GroupDict(0, {'first': a, 'second': b})
1337 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1338 {'name': 'Bob', 'id': 2}]})
1339 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1340 {'name': 'Bob', 'id': 2})})
1341 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1342 'second': {'name': 'Bob', 'id': 2}}})
1343
1344 def test_helper_asdict_builtin_containers(self):
1345 @dataclass
1346 class Child:
1347 d: object
1348
1349 @dataclass
1350 class Parent:
1351 child: Child
1352
1353 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1354 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1355
1356 def test_helper_asdict_factory(self):
1357 @dataclass
1358 class C:
1359 x: int
1360 y: int
1361 c = C(1, 2)
1362 d = asdict(c, dict_factory=OrderedDict)
1363 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1364 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1365 c.x = 42
1366 d = asdict(c, dict_factory=OrderedDict)
1367 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1368 self.assertIs(type(d), OrderedDict)
1369
1370 def test_helper_astuple(self):
1371 # Basic tests for astuple(), it should return a new tuple
1372 @dataclass
1373 class C:
1374 x: int
1375 y: int = 0
1376 c = C(1)
1377
1378 self.assertEqual(astuple(c), (1, 0))
1379 self.assertEqual(astuple(c), astuple(c))
1380 self.assertIsNot(astuple(c), astuple(c))
1381 c.y = 42
1382 self.assertEqual(astuple(c), (1, 42))
1383 self.assertIs(type(astuple(c)), tuple)
1384
1385 def test_helper_astuple_raises_on_classes(self):
1386 # astuple() should raise on a class object
1387 @dataclass
1388 class C:
1389 x: int
1390 y: int
1391 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1392 astuple(C)
1393 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1394 astuple(int)
1395
1396 def test_helper_astuple_copy_values(self):
1397 @dataclass
1398 class C:
1399 x: int
1400 y: List[int] = field(default_factory=list)
1401 initial = []
1402 c = C(1, initial)
1403 t = astuple(c)
1404 self.assertEqual(t[1], initial)
1405 self.assertIsNot(t[1], initial)
1406 c = C(1)
1407 t = astuple(c)
1408 t[1].append(1)
1409 self.assertEqual(c.y, [])
1410
1411 def test_helper_astuple_nested(self):
1412 @dataclass
1413 class UserId:
1414 token: int
1415 group: int
1416 @dataclass
1417 class User:
1418 name: str
1419 id: UserId
1420 u = User('Joe', UserId(123, 1))
1421 t = astuple(u)
1422 self.assertEqual(t, ('Joe', (123, 1)))
1423 self.assertIsNot(astuple(u), astuple(u))
1424 u.id.group = 2
1425 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1426
1427 def test_helper_astuple_builtin_containers(self):
1428 @dataclass
1429 class User:
1430 name: str
1431 id: int
1432 @dataclass
1433 class GroupList:
1434 id: int
1435 users: List[User]
1436 @dataclass
1437 class GroupTuple:
1438 id: int
1439 users: Tuple[User, ...]
1440 @dataclass
1441 class GroupDict:
1442 id: int
1443 users: Dict[str, User]
1444 a = User('Alice', 1)
1445 b = User('Bob', 2)
1446 gl = GroupList(0, [a, b])
1447 gt = GroupTuple(0, (a, b))
1448 gd = GroupDict(0, {'first': a, 'second': b})
1449 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1450 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1451 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1452
1453 def test_helper_astuple_builtin_containers(self):
1454 @dataclass
1455 class Child:
1456 d: object
1457
1458 @dataclass
1459 class Parent:
1460 child: Child
1461
1462 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1463 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1464
1465 def test_helper_astuple_factory(self):
1466 @dataclass
1467 class C:
1468 x: int
1469 y: int
1470 NT = namedtuple('NT', 'x y')
1471 def nt(lst):
1472 return NT(*lst)
1473 c = C(1, 2)
1474 t = astuple(c, tuple_factory=nt)
1475 self.assertEqual(t, NT(1, 2))
1476 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1477 c.x = 42
1478 t = astuple(c, tuple_factory=nt)
1479 self.assertEqual(t, NT(42, 2))
1480 self.assertIs(type(t), NT)
1481
1482 def test_dynamic_class_creation(self):
1483 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1484 }
1485
1486 # Create the class.
1487 cls = type('C', (), cls_dict)
1488
1489 # Make it a dataclass.
1490 cls1 = dataclass(cls)
1491
1492 self.assertEqual(cls1, cls)
1493 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1494
1495 def test_dynamic_class_creation_using_field(self):
1496 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1497 'y': field(default=5),
1498 }
1499
1500 # Create the class.
1501 cls = type('C', (), cls_dict)
1502
1503 # Make it a dataclass.
1504 cls1 = dataclass(cls)
1505
1506 self.assertEqual(cls1, cls)
1507 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1508
1509 def test_init_in_order(self):
1510 @dataclass
1511 class C:
1512 a: int
1513 b: int = field()
1514 c: list = field(default_factory=list, init=False)
1515 d: list = field(default_factory=list)
1516 e: int = field(default=4, init=False)
1517 f: int = 4
1518
1519 calls = []
1520 def setattr(self, name, value):
1521 calls.append((name, value))
1522
1523 C.__setattr__ = setattr
1524 c = C(0, 1)
1525 self.assertEqual(('a', 0), calls[0])
1526 self.assertEqual(('b', 1), calls[1])
1527 self.assertEqual(('c', []), calls[2])
1528 self.assertEqual(('d', []), calls[3])
1529 self.assertNotIn(('e', 4), calls)
1530 self.assertEqual(('f', 4), calls[4])
1531
1532 def test_items_in_dicts(self):
1533 @dataclass
1534 class C:
1535 a: int
1536 b: list = field(default_factory=list, init=False)
1537 c: list = field(default_factory=list)
1538 d: int = field(default=4, init=False)
1539 e: int = 0
1540
1541 c = C(0)
1542 # Class dict
1543 self.assertNotIn('a', C.__dict__)
1544 self.assertNotIn('b', C.__dict__)
1545 self.assertNotIn('c', C.__dict__)
1546 self.assertIn('d', C.__dict__)
1547 self.assertEqual(C.d, 4)
1548 self.assertIn('e', C.__dict__)
1549 self.assertEqual(C.e, 0)
1550 # Instance dict
1551 self.assertIn('a', c.__dict__)
1552 self.assertEqual(c.a, 0)
1553 self.assertIn('b', c.__dict__)
1554 self.assertEqual(c.b, [])
1555 self.assertIn('c', c.__dict__)
1556 self.assertEqual(c.c, [])
1557 self.assertNotIn('d', c.__dict__)
1558 self.assertIn('e', c.__dict__)
1559 self.assertEqual(c.e, 0)
1560
1561 def test_alternate_classmethod_constructor(self):
1562 # Since __post_init__ can't take params, use a classmethod
1563 # alternate constructor. This is mostly an example to show how
1564 # to use this technique.
1565 @dataclass
1566 class C:
1567 x: int
1568 @classmethod
1569 def from_file(cls, filename):
1570 # In a real example, create a new instance
1571 # and populate 'x' from contents of a file.
1572 value_in_file = 20
1573 return cls(value_in_file)
1574
1575 self.assertEqual(C.from_file('filename').x, 20)
1576
1577 def test_field_metadata_default(self):
1578 # Make sure the default metadata is read-only and of
1579 # zero length.
1580 @dataclass
1581 class C:
1582 i: int
1583
1584 self.assertFalse(fields(C)[0].metadata)
1585 self.assertEqual(len(fields(C)[0].metadata), 0)
1586 with self.assertRaisesRegex(TypeError,
1587 'does not support item assignment'):
1588 fields(C)[0].metadata['test'] = 3
1589
1590 def test_field_metadata_mapping(self):
1591 # Make sure only a mapping can be passed as metadata
1592 # zero length.
1593 with self.assertRaises(TypeError):
1594 @dataclass
1595 class C:
1596 i: int = field(metadata=0)
1597
1598 # Make sure an empty dict works
1599 @dataclass
1600 class C:
1601 i: int = field(metadata={})
1602 self.assertFalse(fields(C)[0].metadata)
1603 self.assertEqual(len(fields(C)[0].metadata), 0)
1604 with self.assertRaisesRegex(TypeError,
1605 'does not support item assignment'):
1606 fields(C)[0].metadata['test'] = 3
1607
1608 # Make sure a non-empty dict works.
1609 @dataclass
1610 class C:
1611 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1612 self.assertEqual(len(fields(C)[0].metadata), 3)
1613 self.assertEqual(fields(C)[0].metadata['test'], 10)
1614 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1615 self.assertEqual(fields(C)[0].metadata[3], 'three')
1616 with self.assertRaises(KeyError):
1617 # Non-existent key.
1618 fields(C)[0].metadata['baz']
1619 with self.assertRaisesRegex(TypeError,
1620 'does not support item assignment'):
1621 fields(C)[0].metadata['test'] = 3
1622
1623 def test_field_metadata_custom_mapping(self):
1624 # Try a custom mapping.
1625 class SimpleNameSpace:
1626 def __init__(self, **kw):
1627 self.__dict__.update(kw)
1628
1629 def __getitem__(self, item):
1630 if item == 'xyzzy':
1631 return 'plugh'
1632 return getattr(self, item)
1633
1634 def __len__(self):
1635 return self.__dict__.__len__()
1636
1637 @dataclass
1638 class C:
1639 i: int = field(metadata=SimpleNameSpace(a=10))
1640
1641 self.assertEqual(len(fields(C)[0].metadata), 1)
1642 self.assertEqual(fields(C)[0].metadata['a'], 10)
1643 with self.assertRaises(AttributeError):
1644 fields(C)[0].metadata['b']
1645 # Make sure we're still talking to our custom mapping.
1646 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1647
1648 def test_generic_dataclasses(self):
1649 T = TypeVar('T')
1650
1651 @dataclass
1652 class LabeledBox(Generic[T]):
1653 content: T
1654 label: str = '<unknown>'
1655
1656 box = LabeledBox(42)
1657 self.assertEqual(box.content, 42)
1658 self.assertEqual(box.label, '<unknown>')
1659
1660 # subscripting the resulting class should work, etc.
1661 Alias = List[LabeledBox[int]]
1662
1663 def test_generic_extending(self):
1664 S = TypeVar('S')
1665 T = TypeVar('T')
1666
1667 @dataclass
1668 class Base(Generic[T, S]):
1669 x: T
1670 y: S
1671
1672 @dataclass
1673 class DataDerived(Base[int, T]):
1674 new_field: str
1675 Alias = DataDerived[str]
1676 c = Alias(0, 'test1', 'test2')
1677 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1678
1679 class NonDataDerived(Base[int, T]):
1680 def new_method(self):
1681 return self.y
1682 Alias = NonDataDerived[float]
1683 c = Alias(10, 1.0)
1684 self.assertEqual(c.new_method(), 1.0)
1685
1686 def test_helper_replace(self):
1687 @dataclass(frozen=True)
1688 class C:
1689 x: int
1690 y: int
1691
1692 c = C(1, 2)
1693 c1 = replace(c, x=3)
1694 self.assertEqual(c1.x, 3)
1695 self.assertEqual(c1.y, 2)
1696
1697 def test_helper_replace_frozen(self):
1698 @dataclass(frozen=True)
1699 class C:
1700 x: int
1701 y: int
1702 z: int = field(init=False, default=10)
1703 t: int = field(init=False, default=100)
1704
1705 c = C(1, 2)
1706 c1 = replace(c, x=3)
1707 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1708 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1709
1710
1711 with self.assertRaisesRegex(ValueError, 'init=False'):
1712 replace(c, x=3, z=20, t=50)
1713 with self.assertRaisesRegex(ValueError, 'init=False'):
1714 replace(c, z=20)
1715 replace(c, x=3, z=20, t=50)
1716
1717 # Make sure the result is still frozen.
1718 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1719 c1.x = 3
1720
1721 # Make sure we can't replace an attribute that doesn't exist,
1722 # if we're also replacing one that does exist. Test this
1723 # here, because setting attributes on frozen instances is
1724 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001725 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001726 "keyword argument 'a'"):
1727 c1 = replace(c, x=20, a=5)
1728
1729 def test_helper_replace_invalid_field_name(self):
1730 @dataclass(frozen=True)
1731 class C:
1732 x: int
1733 y: int
1734
1735 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001736 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001737 "keyword argument 'z'"):
1738 c1 = replace(c, z=3)
1739
1740 def test_helper_replace_invalid_object(self):
1741 @dataclass(frozen=True)
1742 class C:
1743 x: int
1744 y: int
1745
1746 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1747 replace(C, x=3)
1748
1749 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1750 replace(0, x=3)
1751
1752 def test_helper_replace_no_init(self):
1753 @dataclass
1754 class C:
1755 x: int
1756 y: int = field(init=False, default=10)
1757
1758 c = C(1)
1759 c.y = 20
1760
1761 # Make sure y gets the default value.
1762 c1 = replace(c, x=5)
1763 self.assertEqual((c1.x, c1.y), (5, 10))
1764
1765 # Trying to replace y is an error.
1766 with self.assertRaisesRegex(ValueError, 'init=False'):
1767 replace(c, x=2, y=30)
1768 with self.assertRaisesRegex(ValueError, 'init=False'):
1769 replace(c, y=30)
1770
1771 def test_dataclassses_pickleable(self):
1772 global P, Q, R
1773 @dataclass
1774 class P:
1775 x: int
1776 y: int = 0
1777 @dataclass
1778 class Q:
1779 x: int
1780 y: int = field(default=0, init=False)
1781 @dataclass
1782 class R:
1783 x: int
1784 y: List[int] = field(default_factory=list)
1785 q = Q(1)
1786 q.y = 2
1787 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1788 for sample in samples:
1789 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1790 with self.subTest(sample=sample, proto=proto):
1791 new_sample = pickle.loads(pickle.dumps(sample, proto))
1792 self.assertEqual(sample.x, new_sample.x)
1793 self.assertEqual(sample.y, new_sample.y)
1794 self.assertIsNot(sample, new_sample)
1795 new_sample.x = 42
1796 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1797 self.assertEqual(new_sample.x, another_new_sample.x)
1798 self.assertEqual(sample.y, another_new_sample.y)
1799
1800 def test_helper_make_dataclass(self):
1801 C = make_dataclass('C',
1802 [('x', int),
1803 ('y', int, field(default=5))],
1804 namespace={'add_one': lambda self: self.x + 1})
1805 c = C(10)
1806 self.assertEqual((c.x, c.y), (10, 5))
1807 self.assertEqual(c.add_one(), 11)
1808
1809
1810 def test_helper_make_dataclass_no_mutate_namespace(self):
1811 # Make sure a provided namespace isn't mutated.
1812 ns = {}
1813 C = make_dataclass('C',
1814 [('x', int),
1815 ('y', int, field(default=5))],
1816 namespace=ns)
1817 self.assertEqual(ns, {})
1818
1819 def test_helper_make_dataclass_base(self):
1820 class Base1:
1821 pass
1822 class Base2:
1823 pass
1824 C = make_dataclass('C',
1825 [('x', int)],
1826 bases=(Base1, Base2))
1827 c = C(2)
1828 self.assertIsInstance(c, C)
1829 self.assertIsInstance(c, Base1)
1830 self.assertIsInstance(c, Base2)
1831
1832 def test_helper_make_dataclass_base_dataclass(self):
1833 @dataclass
1834 class Base1:
1835 x: int
1836 class Base2:
1837 pass
1838 C = make_dataclass('C',
1839 [('y', int)],
1840 bases=(Base1, Base2))
1841 with self.assertRaisesRegex(TypeError, 'required positional'):
1842 c = C(2)
1843 c = C(1, 2)
1844 self.assertIsInstance(c, C)
1845 self.assertIsInstance(c, Base1)
1846 self.assertIsInstance(c, Base2)
1847
1848 self.assertEqual((c.x, c.y), (1, 2))
1849
1850 def test_helper_make_dataclass_init_var(self):
1851 def post_init(self, y):
1852 self.x *= y
1853
1854 C = make_dataclass('C',
1855 [('x', int),
1856 ('y', InitVar[int]),
1857 ],
1858 namespace={'__post_init__': post_init},
1859 )
1860 c = C(2, 3)
1861 self.assertEqual(vars(c), {'x': 6})
1862 self.assertEqual(len(fields(c)), 1)
1863
1864 def test_helper_make_dataclass_class_var(self):
1865 C = make_dataclass('C',
1866 [('x', int),
1867 ('y', ClassVar[int], 10),
1868 ('z', ClassVar[int], field(default=20)),
1869 ])
1870 c = C(1)
1871 self.assertEqual(vars(c), {'x': 1})
1872 self.assertEqual(len(fields(c)), 1)
1873 self.assertEqual(C.y, 10)
1874 self.assertEqual(C.z, 20)
1875
Eric V. Smithd80b4432018-01-06 17:09:58 -05001876 def test_helper_make_dataclass_other_params(self):
1877 C = make_dataclass('C',
1878 [('x', int),
1879 ('y', ClassVar[int], 10),
1880 ('z', ClassVar[int], field(default=20)),
1881 ],
1882 init=False)
1883 # Make sure we have a repr, but no init.
1884 self.assertNotIn('__init__', vars(C))
1885 self.assertIn('__repr__', vars(C))
1886
1887 # Make sure random other params don't work.
1888 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1889 C = make_dataclass('C',
1890 [],
1891 xxinit=False)
1892
Eric V. Smithed7d4292018-01-06 16:14:03 -05001893 def test_helper_make_dataclass_no_types(self):
1894 C = make_dataclass('Point', ['x', 'y', 'z'])
1895 c = C(1, 2, 3)
1896 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1897 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1898 'y': 'typing.Any',
1899 'z': 'typing.Any'})
1900
1901 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1902 c = C(1, 2, 3)
1903 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1904 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1905 'y': int,
1906 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001907
Eric V. Smithea8fc522018-01-27 19:07:40 -05001908
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001909class TestDocString(unittest.TestCase):
1910 def assertDocStrEqual(self, a, b):
1911 # Because 3.6 and 3.7 differ in how inspect.signature work
1912 # (see bpo #32108), for the time being just compare them with
1913 # whitespace stripped.
1914 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1915
1916 def test_existing_docstring_not_overridden(self):
1917 @dataclass
1918 class C:
1919 """Lorem ipsum"""
1920 x: int
1921
1922 self.assertEqual(C.__doc__, "Lorem ipsum")
1923
1924 def test_docstring_no_fields(self):
1925 @dataclass
1926 class C:
1927 pass
1928
1929 self.assertDocStrEqual(C.__doc__, "C()")
1930
1931 def test_docstring_one_field(self):
1932 @dataclass
1933 class C:
1934 x: int
1935
1936 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1937
1938 def test_docstring_two_fields(self):
1939 @dataclass
1940 class C:
1941 x: int
1942 y: int
1943
1944 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1945
1946 def test_docstring_three_fields(self):
1947 @dataclass
1948 class C:
1949 x: int
1950 y: int
1951 z: str
1952
1953 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1954
1955 def test_docstring_one_field_with_default(self):
1956 @dataclass
1957 class C:
1958 x: int = 3
1959
1960 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1961
1962 def test_docstring_one_field_with_default_none(self):
1963 @dataclass
1964 class C:
1965 x: Union[int, type(None)] = None
1966
1967 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1968
1969 def test_docstring_list_field(self):
1970 @dataclass
1971 class C:
1972 x: List[int]
1973
1974 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1975
1976 def test_docstring_list_field_with_default_factory(self):
1977 @dataclass
1978 class C:
1979 x: List[int] = field(default_factory=list)
1980
1981 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1982
1983 def test_docstring_deque_field(self):
1984 @dataclass
1985 class C:
1986 x: deque
1987
1988 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1989
1990 def test_docstring_deque_field_with_default_factory(self):
1991 @dataclass
1992 class C:
1993 x: deque = field(default_factory=deque)
1994
1995 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1996
1997
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # A base class that defines __init__ must not stop dataclass() from
        # adding __init__ to the derived class.  The generated __init__
        # does not call B.__init__, so 'z' is never set.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False nothing is generated, so the inherited
        # B.__init__ runs and 'x' comes from the class-level default.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # dataclass(init=False) must be applied with '@'.  As a bare
        # expression statement above the class it would have no effect,
        # and C would just be an ordinary class.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # An __init__ written in the class body is kept whatever init= is
        # set to.
        for decorate, arg in ((dataclass, 3),
                              (dataclass(init=True), 4),
                              (dataclass(init=False), 5)):
            @decorate
            class C:
                x: int
                def __init__(self, x):
                    self.x = 2 * x
            self.assertEqual(C(arg).x, 2 * arg)
2062
2063
2064class TestRepr(unittest.TestCase):
2065 def test_repr(self):
2066 @dataclass
2067 class B:
2068 x: int
2069
2070 @dataclass
2071 class C(B):
2072 y: int = 10
2073
2074 o = C(4)
2075 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
2076
2077 @dataclass
2078 class D(C):
2079 x: int = 20
2080 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
2081
2082 @dataclass
2083 class C:
2084 @dataclass
2085 class D:
2086 i: int
2087 @dataclass
2088 class E:
2089 pass
2090 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
2091 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
2092
2093 def test_no_repr(self):
2094 # Test a class with no __repr__ and repr=False.
2095 @dataclass(repr=False)
2096 class C:
2097 x: int
2098 self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
2099 repr(C(3)))
2100
2101 # Test a class with a __repr__ and repr=False.
2102 @dataclass(repr=False)
2103 class C:
2104 x: int
2105 def __repr__(self):
2106 return 'C-class'
2107 self.assertEqual(repr(C(3)), 'C-class')
2108
2109 def test_overwriting_repr(self):
2110 # If the class has __repr__, use it no matter the value of
2111 # repr=.
2112
2113 @dataclass
2114 class C:
2115 x: int
2116 def __repr__(self):
2117 return 'x'
2118 self.assertEqual(repr(C(0)), 'x')
2119
2120 @dataclass(repr=True)
2121 class C:
2122 x: int
2123 def __repr__(self):
2124 return 'x'
2125 self.assertEqual(repr(C(0)), 'x')
2126
2127 @dataclass(repr=False)
2128 class C:
2129 x: int
2130 def __repr__(self):
2131 return 'x'
2132 self.assertEqual(repr(C(0)), 'x')
2133
2134
class TestFrozen(unittest.TestCase):
    def test_overwriting_frozen(self):
        # frozen=True needs to install its own __setattr__/__delattr__, so a
        # class that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int

                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int

                def __delattr__(self):
                    pass

        # Not frozen: a user __setattr__ is kept and the generated __init__
        # goes through it.
        @dataclass(frozen=False)
        class C:
            x: int

            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        doubled = C(10)
        self.assertEqual(doubled.x, 20)
2160
2161
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: instances compare by identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's version is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ written in the class body wins whatever eq= is set to.
        for decorate, target in ((dataclass, 3),
                                 (dataclass(eq=True), 4),
                                 (dataclass(eq=False), 5)):
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2207
2208
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering can fill in the remaining comparisons
        # from a user-defined __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted, so the assertions below only pass
                # if this method is the one being used.
                return self.x >= other

        probe = C(0)
        self.assertLess(probe, -1)
        self.assertLessEqual(probe, -1)
        self.assertGreater(probe, 1)
        self.assertGreaterEqual(probe, 1)

    def test_no_order(self):
        # order=False adds none of the rich comparison methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ does not cause the others to be added.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any user-defined comparison method
        # and points the user at functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2284
class TestHash(unittest.TestCase):
    """Tests for whether and how @dataclass sets __hash__."""

    def test_unsafe_hash(self):
        # unsafe_hash=True generates a __hash__ that hashes the tuple of
        # field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # The table only expects this when with_hash is True;
                    # use a unittest assertion (not a bare assert) so the
                    # sanity check survives python -O.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # self.fail, not assert: a typo in the table below must
                    # still be reported when assertions are stripped by -O.
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # With no fields, the generated hash is the hash of the empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # With one field, the generated hash is the hash of a 1-tuple.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no unsafe_hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object.__hash__ is identity-based, so there is no
                    # fixed value to compare hash(C(10)) against.  Instead,
                    # check that the inherited function is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # self.fail, not assert, so this guard survives -O.
                    self.fail(f'unknown value for expected={expected!r}')


class TestFrozen(unittest.TestCase):
    """Tests for frozen=True and for inheritance between frozen,
    non-frozen and ordinary classes."""

    @staticmethod
    def _maybe_wrap(cls, intermediate_class):
        # Return cls itself, or a plain (non-dataclass) subclass of it, so
        # each inheritance test can cover both shapes of hierarchy.
        if not intermediate_class:
            return cls

        class I(cls):
            pass
        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        # Assigning to a field of a frozen instance is rejected...
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # ...and the stored value is left untouched.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        # A frozen dataclass derived from a frozen dataclass keeps both the
        # inherited field and its own new field frozen.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        for attr, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, attr, new_value)
        self.assertEqual((obj.i, obj.j), (0, 10))

    def test_inherit_nonfrozen_from_frozen(self):
        # Checked both with and without an ordinary class in between.
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._maybe_wrap(C, intermediate_class)
                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        # Checked both with and without an ordinary class in between.
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = self._maybe_wrap(C, intermediate_class)
                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from an ordinary class, directly or
        # through an intermediate subclass, and is still frozen.
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                parent = self._maybe_wrap(C, intermediate_class)

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953: an ordinary subclass of a frozen dataclass may set
        # new attributes, but the inherited fields stay frozen.
        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        s.cached = True

        for field_name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, field_name, 5)
        self.assertEqual((s.x, s.y, s.cached), (3, 10, True))


class TestSlots(unittest.TestCase):
    """Tests for dataclasses whose fields are also listed in __slots__."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed
        # to also have a default value (of type types.MemberDescriptorType).
        # The pattern is a raw string: in a plain literal, '\(' is an
        # invalid escape sequence (DeprecationWarning; SyntaxWarning since
        # Python 3.12).  The regex itself is unchanged.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.
        d.z = 10


if __name__ == '__main__':
    # Executed as a script: hand control to unittest.main().
    unittest.main()