blob: 8aff8ae140a5cd1c20f630b06dff51504fe7197e [file] [log] [blame]
# Deliberately use "from dataclasses import *".  Every name in __all__
# is tested, so they all must be present.  This is a way to catch
# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
11from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
12from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
# A throwaway exception type that the tests can raise and catch.
class CustomError(Exception):
    pass
17
18class TestCase(unittest.TestCase):
19 def test_no_fields(self):
20 @dataclass
21 class C:
22 pass
23
24 o = C()
25 self.assertEqual(len(fields(C)), 0)
26
Miss Islington (bot)3b4c6b12018-03-22 13:58:59 -070027 def test_no_fields_but_member_variable(self):
28 @dataclass
29 class C:
30 i = 0
31
32 o = C()
33 self.assertEqual(len(fields(C)), 0)
34
Eric V. Smithf0db54a2017-12-04 16:58:55 -050035 def test_one_field_no_default(self):
36 @dataclass
37 class C:
38 x: int
39
40 o = C(42)
41 self.assertEqual(o.x, 42)
42
43 def test_named_init_params(self):
44 @dataclass
45 class C:
46 x: int
47
48 o = C(x=32)
49 self.assertEqual(o.x, 32)
50
51 def test_two_fields_one_default(self):
52 @dataclass
53 class C:
54 x: int
55 y: int = 0
56
57 o = C(3)
58 self.assertEqual((o.x, o.y), (3, 0))
59
60 # Non-defaults following defaults.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class C:
66 x: int = 0
67 y: int
68
69 # A derived class adds a non-default field after a default one.
70 with self.assertRaisesRegex(TypeError,
71 "non-default argument 'y' follows "
72 "default argument"):
73 @dataclass
74 class B:
75 x: int = 0
76
77 @dataclass
78 class C(B):
79 y: int
80
81 # Override a base class field and add a default to
82 # a field which didn't use to have a default.
83 with self.assertRaisesRegex(TypeError,
84 "non-default argument 'y' follows "
85 "default argument"):
86 @dataclass
87 class B:
88 x: int
89 y: int
90
91 @dataclass
92 class C(B):
93 x: int = 0
94
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080095 def test_overwrite_hash(self):
96 # Test that declaring this class isn't an error. It should
97 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
101 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800102 return 301
103 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # Test that declaring this class isn't an error. It should
106 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500107 @dataclass(frozen=True)
108 class C:
109 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800110 def __eq__(self, other):
111 return False
112 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500113
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800114 # But this one should generate an exception, because with
115 # unsafe_hash=True, it's an error to have a __hash__ defined.
116 with self.assertRaisesRegex(TypeError,
117 'Cannot overwrite attribute __hash__'):
118 @dataclass(unsafe_hash=True)
119 class C:
120 def __hash__(self):
121 pass
122
123 # Creating this class should not generate an exception,
124 # because even though __hash__ exists before @dataclass is
125 # called, (due to __eq__ being defined), since it's None
126 # that's okay.
127 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500128 class C:
129 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800130 def __eq__(self):
131 pass
132 # The generated hash function works as we'd expect.
133 self.assertEqual(hash(C(10)), hash((10,)))
134
135 # Creating this class should generate an exception, because
136 # __hash__ exists and is not None, which it would be if it had
137 # been auto-generated do due __eq__ being defined.
138 with self.assertRaisesRegex(TypeError,
139 'Cannot overwrite attribute __hash__'):
140 @dataclass(unsafe_hash=True)
141 class C:
142 x: int
143 def __eq__(self):
144 pass
145 def __hash__(self):
146 pass
147
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500148
def test_overwrite_fields_in_derived_class(self):
    # C1 redefines x, replacing Base's x, yet the field keeps the
    # position it had in Base.
    @dataclass
    class Base:
        x: Any = 15.0
        y: int = 0

    @dataclass
    class C1(Base):
        z: int = 10
        x: int = 15

    base_default = Base()
    self.assertEqual(repr(base_default), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')

    derived_default = C1()
    self.assertEqual(repr(derived_default), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')

    derived_custom = C1(x=5)
    self.assertEqual(repr(derived_custom), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
170
171 def test_field_named_self(self):
172 @dataclass
173 class C:
174 self: str
175 c=C('foo')
176 self.assertEqual(c.self, 'foo')
177
178 # Make sure the first parameter is not named 'self'.
179 sig = inspect.signature(C.__init__)
180 first = next(iter(sig.parameters))
181 self.assertNotEqual('self', first)
182
183 # But we do use 'self' if no field named self.
184 @dataclass
185 class C:
186 selfx: str
187
188 # Make sure the first parameter is named 'self'.
189 sig = inspect.signature(C.__init__)
190 first = next(iter(sig.parameters))
191 self.assertEqual('self', first)
192
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500193 def test_0_field_compare(self):
194 # Ensure that order=False is the default.
195 @dataclass
196 class C0:
197 pass
198
199 @dataclass(order=False)
200 class C1:
201 pass
202
203 for cls in [C0, C1]:
204 with self.subTest(cls=cls):
205 self.assertEqual(cls(), cls())
206 for idx, fn in enumerate([lambda a, b: a < b,
207 lambda a, b: a <= b,
208 lambda a, b: a > b,
209 lambda a, b: a >= b]):
210 with self.subTest(idx=idx):
211 with self.assertRaisesRegex(TypeError,
212 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
213 fn(cls(), cls())
214
215 @dataclass(order=True)
216 class C:
217 pass
218 self.assertLessEqual(C(), C())
219 self.assertGreaterEqual(C(), C())
220
221 def test_1_field_compare(self):
222 # Ensure that order=False is the default.
223 @dataclass
224 class C0:
225 x: int
226
227 @dataclass(order=False)
228 class C1:
229 x: int
230
231 for cls in [C0, C1]:
232 with self.subTest(cls=cls):
233 self.assertEqual(cls(1), cls(1))
234 self.assertNotEqual(cls(0), cls(1))
235 for idx, fn in enumerate([lambda a, b: a < b,
236 lambda a, b: a <= b,
237 lambda a, b: a > b,
238 lambda a, b: a >= b]):
239 with self.subTest(idx=idx):
240 with self.assertRaisesRegex(TypeError,
241 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
242 fn(cls(0), cls(0))
243
244 @dataclass(order=True)
245 class C:
246 x: int
247 self.assertLess(C(0), C(1))
248 self.assertLessEqual(C(0), C(1))
249 self.assertLessEqual(C(1), C(1))
250 self.assertGreater(C(1), C(0))
251 self.assertGreaterEqual(C(1), C(0))
252 self.assertGreaterEqual(C(1), C(1))
253
254 def test_simple_compare(self):
255 # Ensure that order=False is the default.
256 @dataclass
257 class C0:
258 x: int
259 y: int
260
261 @dataclass(order=False)
262 class C1:
263 x: int
264 y: int
265
266 for cls in [C0, C1]:
267 with self.subTest(cls=cls):
268 self.assertEqual(cls(0, 0), cls(0, 0))
269 self.assertEqual(cls(1, 2), cls(1, 2))
270 self.assertNotEqual(cls(1, 0), cls(0, 0))
271 self.assertNotEqual(cls(1, 0), cls(1, 1))
272 for idx, fn in enumerate([lambda a, b: a < b,
273 lambda a, b: a <= b,
274 lambda a, b: a > b,
275 lambda a, b: a >= b]):
276 with self.subTest(idx=idx):
277 with self.assertRaisesRegex(TypeError,
278 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
279 fn(cls(0, 0), cls(0, 0))
280
281 @dataclass(order=True)
282 class C:
283 x: int
284 y: int
285
286 for idx, fn in enumerate([lambda a, b: a == b,
287 lambda a, b: a <= b,
288 lambda a, b: a >= b]):
289 with self.subTest(idx=idx):
290 self.assertTrue(fn(C(0, 0), C(0, 0)))
291
292 for idx, fn in enumerate([lambda a, b: a < b,
293 lambda a, b: a <= b,
294 lambda a, b: a != b]):
295 with self.subTest(idx=idx):
296 self.assertTrue(fn(C(0, 0), C(0, 1)))
297 self.assertTrue(fn(C(0, 1), C(1, 0)))
298 self.assertTrue(fn(C(1, 0), C(1, 1)))
299
300 for idx, fn in enumerate([lambda a, b: a > b,
301 lambda a, b: a >= b,
302 lambda a, b: a != b]):
303 with self.subTest(idx=idx):
304 self.assertTrue(fn(C(0, 1), C(0, 0)))
305 self.assertTrue(fn(C(1, 0), C(0, 1)))
306 self.assertTrue(fn(C(1, 1), C(1, 0)))
307
308 def test_compare_subclasses(self):
309 # Comparisons fail for subclasses, even if no fields
310 # are added.
311 @dataclass
312 class B:
313 i: int
314
315 @dataclass
316 class C(B):
317 pass
318
319 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
320 (lambda a, b: a != b, True)]):
321 with self.subTest(idx=idx):
322 self.assertEqual(fn(B(0), C(0)), expected)
323
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 "not supported between instances of 'B' and 'C'"):
331 fn(B(0), C(0))
332
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500333 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500334 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500335 for (eq, order, result ) in [
336 (False, False, 'neither'),
337 (False, True, 'exception'),
338 (True, False, 'eq_only'),
339 (True, True, 'both'),
340 ]:
341 with self.subTest(eq=eq, order=order):
342 if result == 'exception':
343 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
344 @dataclass(eq=eq, order=order)
345 class C:
346 pass
347 else:
348 @dataclass(eq=eq, order=order)
349 class C:
350 pass
351
352 if result == 'neither':
353 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500354 self.assertNotIn('__lt__', C.__dict__)
355 self.assertNotIn('__le__', C.__dict__)
356 self.assertNotIn('__gt__', C.__dict__)
357 self.assertNotIn('__ge__', C.__dict__)
358 elif result == 'both':
359 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500360 self.assertIn('__lt__', C.__dict__)
361 self.assertIn('__le__', C.__dict__)
362 self.assertIn('__gt__', C.__dict__)
363 self.assertIn('__ge__', C.__dict__)
364 elif result == 'eq_only':
365 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500366 self.assertNotIn('__lt__', C.__dict__)
367 self.assertNotIn('__le__', C.__dict__)
368 self.assertNotIn('__gt__', C.__dict__)
369 self.assertNotIn('__ge__', C.__dict__)
370 else:
371 assert False, f'unknown result {result!r}'
372
373 def test_field_no_default(self):
374 @dataclass
375 class C:
376 x: int = field()
377
378 self.assertEqual(C(5).x, 5)
379
380 with self.assertRaisesRegex(TypeError,
381 r"__init__\(\) missing 1 required "
382 "positional argument: 'x'"):
383 C()
384
385 def test_field_default(self):
386 default = object()
387 @dataclass
388 class C:
389 x: object = field(default=default)
390
391 self.assertIs(C.x, default)
392 c = C(10)
393 self.assertEqual(c.x, 10)
394
395 # If we delete the instance attribute, we should then see the
396 # class attribute.
397 del c.x
398 self.assertIs(c.x, default)
399
400 self.assertIs(C().x, default)
401
def test_not_in_repr(self):
    # repr=False hides the field from __repr__ but leaves it
    # required by __init__.
    @dataclass
    class C:
        x: int = field(repr=False)
    with self.assertRaises(TypeError):
        C()
    hidden_only = C(10)
    self.assertEqual(repr(hidden_only), 'TestCase.test_not_in_repr.<locals>.C()')

    @dataclass
    class C:
        x: int = field(repr=False)
        y: int
    mixed = C(10, 20)
    self.assertEqual(repr(mixed), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
417
418 def test_not_in_compare(self):
419 @dataclass
420 class C:
421 x: int = 0
422 y: int = field(compare=False, default=4)
423
424 self.assertEqual(C(), C(0, 20))
425 self.assertEqual(C(1, 10), C(1, 20))
426 self.assertNotEqual(C(3), C(4, 10))
427 self.assertNotEqual(C(3, 10), C(4, 10))
428
429 def test_hash_field_rules(self):
430 # Test all 6 cases of:
431 # hash=True/False/None
432 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800433 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 (True, False, 'field' ),
435 (True, True, 'field' ),
436 (False, False, 'absent'),
437 (False, True, 'absent'),
438 (None, False, 'absent'),
439 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800440 ]:
441 with self.subTest(hash=hash_, compare=compare):
442 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500443 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800444 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500445
446 if result == 'field':
447 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800448 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500449 elif result == 'absent':
450 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800451 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500452 else:
453 assert False, f'unknown result {result!r}'
454
455 def test_init_false_no_default(self):
456 # If init=False and no default value, then the field won't be
457 # present in the instance.
458 @dataclass
459 class C:
460 x: int = field(init=False)
461
462 self.assertNotIn('x', C().__dict__)
463
464 @dataclass
465 class C:
466 x: int
467 y: int = 0
468 z: int = field(init=False)
469 t: int = 10
470
471 self.assertNotIn('z', C(0).__dict__)
472 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
473
474 def test_class_marker(self):
475 @dataclass
476 class C:
477 x: int
478 y: str = field(init=False, default=None)
479 z: str = field(repr=False)
480
481 the_fields = fields(C)
482 # the_fields is a tuple of 3 items, each value
483 # is in __annotations__.
484 self.assertIsInstance(the_fields, tuple)
485 for f in the_fields:
486 self.assertIs(type(f), Field)
487 self.assertIn(f.name, C.__annotations__)
488
489 self.assertEqual(len(the_fields), 3)
490
491 self.assertEqual(the_fields[0].name, 'x')
492 self.assertEqual(the_fields[0].type, int)
493 self.assertFalse(hasattr(C, 'x'))
494 self.assertTrue (the_fields[0].init)
495 self.assertTrue (the_fields[0].repr)
496 self.assertEqual(the_fields[1].name, 'y')
497 self.assertEqual(the_fields[1].type, str)
498 self.assertIsNone(getattr(C, 'y'))
499 self.assertFalse(the_fields[1].init)
500 self.assertTrue (the_fields[1].repr)
501 self.assertEqual(the_fields[2].name, 'z')
502 self.assertEqual(the_fields[2].type, str)
503 self.assertFalse(hasattr(C, 'z'))
504 self.assertTrue (the_fields[2].init)
505 self.assertFalse(the_fields[2].repr)
506
507 def test_field_order(self):
508 @dataclass
509 class B:
510 a: str = 'B:a'
511 b: str = 'B:b'
512 c: str = 'B:c'
513
514 @dataclass
515 class C(B):
516 b: str = 'C:b'
517
518 self.assertEqual([(f.name, f.default) for f in fields(C)],
519 [('a', 'B:a'),
520 ('b', 'C:b'),
521 ('c', 'B:c')])
522
523 @dataclass
524 class D(B):
525 c: str = 'D:c'
526
527 self.assertEqual([(f.name, f.default) for f in fields(D)],
528 [('a', 'B:a'),
529 ('b', 'B:b'),
530 ('c', 'D:c')])
531
532 @dataclass
533 class E(D):
534 a: str = 'E:a'
535 d: str = 'E:d'
536
537 self.assertEqual([(f.name, f.default) for f in fields(E)],
538 [('a', 'E:a'),
539 ('b', 'B:b'),
540 ('c', 'D:c'),
541 ('d', 'E:d')])
542
543 def test_class_attrs(self):
544 # We only have a class attribute if a default value is
545 # specified, either directly or via a field with a default.
546 default = object()
547 @dataclass
548 class C:
549 x: int
550 y: int = field(repr=False)
551 z: object = default
552 t: int = field(default=100)
553
554 self.assertFalse(hasattr(C, 'x'))
555 self.assertFalse(hasattr(C, 'y'))
556 self.assertIs (C.z, default)
557 self.assertEqual(C.t, 100)
558
559 def test_disallowed_mutable_defaults(self):
560 # For the known types, don't allow mutable default values.
561 for typ, empty, non_empty in [(list, [], [1]),
562 (dict, {}, {0:1}),
563 (set, set(), set([1])),
564 ]:
565 with self.subTest(typ=typ):
566 # Can't use a zero-length value.
567 with self.assertRaisesRegex(ValueError,
568 f'mutable default {typ} for field '
569 'x is not allowed'):
570 @dataclass
571 class Point:
572 x: typ = empty
573
574
575 # Nor a non-zero-length value
576 with self.assertRaisesRegex(ValueError,
577 f'mutable default {typ} for field '
578 'y is not allowed'):
579 @dataclass
580 class Point:
581 y: typ = non_empty
582
583 # Check subtypes also fail.
584 class Subclass(typ): pass
585
586 with self.assertRaisesRegex(ValueError,
587 f"mutable default .*Subclass'>"
588 ' for field z is not allowed'
589 ):
590 @dataclass
591 class Point:
592 z: typ = Subclass()
593
594 # Because this is a ClassVar, it can be mutable.
595 @dataclass
596 class C:
597 z: ClassVar[typ] = typ()
598
599 # Because this is a ClassVar, it can be mutable.
600 @dataclass
601 class C:
602 x: ClassVar[typ] = Subclass()
603
604
605 def test_deliberately_mutable_defaults(self):
606 # If a mutable default isn't in the known list of
607 # (list, dict, set), then it's okay.
608 class Mutable:
609 def __init__(self):
610 self.l = []
611
612 @dataclass
613 class C:
614 x: Mutable
615
616 # These 2 instances will share this value of x.
617 lst = Mutable()
618 o1 = C(lst)
619 o2 = C(lst)
620 self.assertEqual(o1, o2)
621 o1.x.l.extend([1, 2])
622 self.assertEqual(o1, o2)
623 self.assertEqual(o1.x.l, [1, 2])
624 self.assertIs(o1.x, o2.x)
625
626 def test_no_options(self):
627 # call with dataclass()
628 @dataclass()
629 class C:
630 x: int
631
632 self.assertEqual(C(42).x, 42)
633
634 def test_not_tuple(self):
635 # Make sure we can't be compared to a tuple.
636 @dataclass
637 class Point:
638 x: int
639 y: int
640 self.assertNotEqual(Point(1, 2), (1, 2))
641
642 # And that we can't compare to another unrelated dataclass
643 @dataclass
644 class C:
645 x: int
646 y: int
647 self.assertNotEqual(Point(1, 3), C(1, 3))
648
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500649 def test_not_tuple(self):
650 # Test that some of the problems with namedtuple don't happen
651 # here.
652 @dataclass
653 class Point3D:
654 x: int
655 y: int
656 z: int
657
658 @dataclass
659 class Date:
660 year: int
661 month: int
662 day: int
663
664 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
665 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
666
667 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200668 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500669 x, y, z = Point3D(4, 5, 6)
670
Eric V. Smith7c99e932018-01-28 19:18:55 -0500671 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500672 # equal.
673 @dataclass
674 class Point3Dv1:
675 x: int = 0
676 y: int = 0
677 z: int = 0
678 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
679
680 def test_function_annotations(self):
681 # Some dummy class and instance to use as a default.
682 class F:
683 pass
684 f = F()
685
686 def validate_class(cls):
687 # First, check __annotations__, even though they're not
688 # function annotations.
689 self.assertEqual(cls.__annotations__['i'], int)
690 self.assertEqual(cls.__annotations__['j'], str)
691 self.assertEqual(cls.__annotations__['k'], F)
692 self.assertEqual(cls.__annotations__['l'], float)
693 self.assertEqual(cls.__annotations__['z'], complex)
694
695 # Verify __init__.
696
697 signature = inspect.signature(cls.__init__)
698 # Check the return type, should be None
699 self.assertIs(signature.return_annotation, None)
700
701 # Check each parameter.
702 params = iter(signature.parameters.values())
703 param = next(params)
704 # This is testing an internal name, and probably shouldn't be tested.
705 self.assertEqual(param.name, 'self')
706 param = next(params)
707 self.assertEqual(param.name, 'i')
708 self.assertIs (param.annotation, int)
709 self.assertEqual(param.default, inspect.Parameter.empty)
710 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
711 param = next(params)
712 self.assertEqual(param.name, 'j')
713 self.assertIs (param.annotation, str)
714 self.assertEqual(param.default, inspect.Parameter.empty)
715 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
716 param = next(params)
717 self.assertEqual(param.name, 'k')
718 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500719 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
721 param = next(params)
722 self.assertEqual(param.name, 'l')
723 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500724 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500725 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
726 self.assertRaises(StopIteration, next, params)
727
728
729 @dataclass
730 class C:
731 i: int
732 j: str
733 k: F = f
734 l: float=field(default=None)
735 z: complex=field(default=3+4j, init=False)
736
737 validate_class(C)
738
739 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800740 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500741 class C:
742 i: int
743 j: str
744 k: F = f
745 l: float=field(default=None)
746 z: complex=field(default=3+4j, init=False)
747
748 validate_class(C)
749
Eric V. Smith03220fd2017-12-29 13:59:58 -0500750 def test_missing_default(self):
751 # Test that MISSING works the same as a default not being
752 # specified.
753 @dataclass
754 class C:
755 x: int=field(default=MISSING)
756 with self.assertRaisesRegex(TypeError,
757 r'__init__\(\) missing 1 required '
758 'positional argument'):
759 C()
760 self.assertNotIn('x', C.__dict__)
761
762 @dataclass
763 class D:
764 x: int
765 with self.assertRaisesRegex(TypeError,
766 r'__init__\(\) missing 1 required '
767 'positional argument'):
768 D()
769 self.assertNotIn('x', D.__dict__)
770
771 def test_missing_default_factory(self):
772 # Test that MISSING works the same as a default factory not
773 # being specified (which is really the same as a default not
774 # being specified, too).
775 @dataclass
776 class C:
777 x: int=field(default_factory=MISSING)
778 with self.assertRaisesRegex(TypeError,
779 r'__init__\(\) missing 1 required '
780 'positional argument'):
781 C()
782 self.assertNotIn('x', C.__dict__)
783
784 @dataclass
785 class D:
786 x: int=field(default=MISSING, default_factory=MISSING)
787 with self.assertRaisesRegex(TypeError,
788 r'__init__\(\) missing 1 required '
789 'positional argument'):
790 D()
791 self.assertNotIn('x', D.__dict__)
792
793 def test_missing_repr(self):
794 self.assertIn('MISSING_TYPE object', repr(MISSING))
795
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500796 def test_dont_include_other_annotations(self):
797 @dataclass
798 class C:
799 i: int
800 def foo(self) -> int:
801 return 4
802 @property
803 def bar(self) -> int:
804 return 5
805 self.assertEqual(list(C.__annotations__), ['i'])
806 self.assertEqual(C(10).foo(), 4)
807 self.assertEqual(C(10).bar, 5)
808
def test_post_init(self):
    # The generated __init__ must call __post_init__.
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # __post_init__ still runs here but does not raise, which shows
    # that it sees the initialised instance.
    C(5)

    # Without a generated __init__, __post_init__ is never called.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    # So instantiating must not raise.
    C()

    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # On a frozen class, __post_init__ cannot assign attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
855
def test_post_init_super(self):
    # The base class __post_init__ is not chained automatically.
    class B:
        def __post_init__(self):
            raise CustomError()

    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    instance = C()
    self.assertEqual(instance.x, 5)

    # An explicit super() call reaches it, and it raises.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # An inherited __post_init__ is called even when the dataclass
    # does not define one itself.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
886
887 def test_post_init_staticmethod(self):
888 flag = False
889 @dataclass
890 class C:
891 x: int
892 y: int
893 @staticmethod
894 def __post_init__():
895 nonlocal flag
896 flag = True
897
898 self.assertFalse(flag)
899 c = C(3, 4)
900 self.assertEqual((c.x, c.y), (3, 4))
901 self.assertTrue(flag)
902
903 def test_post_init_classmethod(self):
904 @dataclass
905 class C:
906 flag = False
907 x: int
908 y: int
909 @classmethod
910 def __post_init__(cls):
911 cls.flag = True
912
913 self.assertFalse(C.flag)
914 c = C(3, 4)
915 self.assertEqual((c.x, c.y), (3, 4))
916 self.assertTrue(C.flag)
917
def test_class_var(self):
    # ClassVars take no part in __init__, __repr__, and so on.
    @dataclass
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    instance = C(5)
    self.assertEqual(repr(instance), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)                 # 2 real fields
    self.assertEqual(len(C.__annotations__), 5)         # plus 3 ClassVars
    self.assertEqual(instance.z, 1000)
    self.assertEqual(instance.w, 2000)
    self.assertEqual(instance.t, 3000)

    # A change on the class is seen by existing and new instances.
    C.z += 1
    self.assertEqual(instance.z, 1001)
    instance = C(20)
    self.assertEqual((instance.x, instance.y), (20, 10))
    self.assertEqual(instance.z, 1001)
    self.assertEqual(instance.w, 2000)
    self.assertEqual(instance.t, 3000)
942
943 def test_class_var_no_default(self):
944 # If a ClassVar has no default value, it should not be set on the class.
945 @dataclass
946 class C:
947 x: ClassVar[int]
948
949 self.assertNotIn('x', C.__dict__)
950
951 def test_class_var_default_factory(self):
952 # It makes no sense for a ClassVar to have a default factory. When
953 # would it be called? Call it yourself, since it's class-wide.
954 with self.assertRaisesRegex(TypeError,
955 'cannot have a default factory'):
956 @dataclass
957 class C:
958 x: ClassVar[int] = field(default_factory=int)
959
960 self.assertNotIn('x', C.__dict__)
961
962 def test_class_var_with_default(self):
963 # If a ClassVar has a default value, it should be set on the class.
964 @dataclass
965 class C:
966 x: ClassVar[int] = 10
967 self.assertEqual(C.x, 10)
968
969 @dataclass
970 class C:
971 x: ClassVar[int] = field(default=10)
972 self.assertEqual(C.x, 10)
973
def test_class_var_frozen(self):
    # ClassVars must keep working on a frozen dataclass.
    @dataclass(frozen=True)
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    instance = C(5)
    self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)                 # 2 real fields
    self.assertEqual(len(C.__annotations__), 5)         # plus 3 ClassVars
    self.assertEqual(instance.z, 1000)
    self.assertEqual(instance.w, 2000)
    self.assertEqual(instance.t, 3000)

    # Only instances are frozen; the ClassVar on the class can
    # still be changed.
    C.z += 1
    self.assertEqual(instance.z, 1001)
    instance = C(20)
    self.assertEqual((instance.x, instance.y), (20, 10))
    self.assertEqual(instance.z, 1001)
    self.assertEqual(instance.w, 2000)
    self.assertEqual(instance.t, 3000)
1000
1001 def test_init_var_no_default(self):
1002 # If an InitVar has no default value, it should not be set on the class.
1003 @dataclass
1004 class C:
1005 x: InitVar[int]
1006
1007 self.assertNotIn('x', C.__dict__)
1008
1009 def test_init_var_default_factory(self):
1010 # It makes no sense for an InitVar to have a default factory. When
1011 # would it be called? Call it yourself, since it's class-wide.
1012 with self.assertRaisesRegex(TypeError,
1013 'cannot have a default factory'):
1014 @dataclass
1015 class C:
1016 x: InitVar[int] = field(default_factory=int)
1017
1018 self.assertNotIn('x', C.__dict__)
1019
1020 def test_init_var_with_default(self):
1021 # If an InitVar has a default value, it should be set on the class.
1022 @dataclass
1023 class C:
1024 x: InitVar[int] = 10
1025 self.assertEqual(C.x, 10)
1026
1027 @dataclass
1028 class C:
1029 x: InitVar[int] = field(default=10)
1030 self.assertEqual(C.x, 10)
1031
1032 def test_init_var(self):
1033 @dataclass
1034 class C:
1035 x: int = None
1036 init_param: InitVar[int] = None
1037
1038 def __post_init__(self, init_param):
1039 if self.x is None:
1040 self.x = init_param*2
1041
1042 c = C(init_param=10)
1043 self.assertEqual(c.x, 20)
1044
1045 def test_init_var_inheritance(self):
1046 # Note that this deliberately tests that a dataclass need not
1047 # have a __post_init__ function if it has an InitVar field.
1048 # It could just be used in a derived class, as shown here.
1049 @dataclass
1050 class Base:
1051 x: int
1052 init_base: InitVar[int]
1053
1054 # We can instantiate by passing the InitVar, even though
1055 # it's not used.
1056 b = Base(0, 10)
1057 self.assertEqual(vars(b), {'x': 0})
1058
1059 @dataclass
1060 class C(Base):
1061 y: int
1062 init_derived: InitVar[int]
1063
1064 def __post_init__(self, init_base, init_derived):
1065 self.x = self.x + init_base
1066 self.y = self.y + init_derived
1067
1068 c = C(10, 11, 50, 51)
1069 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1070
def test_default_factory(self):
    # A factory that builds a fresh list for every instance.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    first = C(3)
    second = C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIsNot(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # A factory that hands out one shared list.
    shared = []
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: shared)

    first = C(3)
    second = C(3)
    self.assertEqual(first.x, 3)
    self.assertEqual(first.y, [])
    self.assertEqual(first, second)
    self.assertIs(first.y, second.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # The remaining field flags, combined with a factory.
    # repr
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init (see also test_default_factory_with_no_init)
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1127
1128 def test_default_factory_with_no_init(self):
1129 # We need a factory with a side effect.
1130 factory = Mock()
1131
1132 @dataclass
1133 class C:
1134 x: list = field(default_factory=factory, init=False)
1135
1136 # Make sure the default factory is called for each new instance.
1137 C().x
1138 self.assertEqual(factory.call_count, 1)
1139 C().x
1140 self.assertEqual(factory.call_count, 2)
1141
1142 def test_default_factory_not_called_if_value_given(self):
1143 # We need a factory that we can test if it's been called.
1144 factory = Mock()
1145
1146 @dataclass
1147 class C:
1148 x: int = field(default_factory=factory)
1149
1150 # Make sure that if a field has a default factory function,
1151 # it's not called if a value is specified.
1152 C().x
1153 self.assertEqual(factory.call_count, 1)
1154 self.assertEqual(C(10).x, 10)
1155 self.assertEqual(factory.call_count, 1)
1156 C().x
1157 self.assertEqual(factory.call_count, 2)
1158
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001159 def test_default_factory_derived(self):
1160 # See bpo-32896.
1161 @dataclass
1162 class Foo:
1163 x: dict = field(default_factory=dict)
1164
1165 @dataclass
1166 class Bar(Foo):
1167 y: int = 1
1168
1169 self.assertEqual(Foo().x, {})
1170 self.assertEqual(Bar().x, {})
1171 self.assertEqual(Bar().y, 1)
1172
1173 @dataclass
1174 class Baz(Foo):
1175 pass
1176 self.assertEqual(Baz().x, {})
1177
1178 def test_intermediate_non_dataclass(self):
1179 # Test that an intermediate class that defines
1180 # annotations does not define fields.
1181
1182 @dataclass
1183 class A:
1184 x: int
1185
1186 class B(A):
1187 y: int
1188
1189 @dataclass
1190 class C(B):
1191 z: int
1192
1193 c = C(1, 3)
1194 self.assertEqual((c.x, c.z), (1, 3))
1195
1196 # .y was not initialized.
1197 with self.assertRaisesRegex(AttributeError,
1198 'object has no attribute'):
1199 c.y
1200
1201 # And if we again derive a non-dataclass, no fields are added.
1202 class D(C):
1203 t: int
1204 d = D(4, 5)
1205 self.assertEqual((d.x, d.z), (4, 5))
1206
1207
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001208 def x_test_classvar_default_factory(self):
1209 # XXX: it's an error for a ClassVar to have a factory function
1210 @dataclass
1211 class C:
1212 x: ClassVar[int] = field(default_factory=int)
1213
1214 self.assertIs(C().x, int)
1215
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001216 def test_is_dataclass(self):
1217 class NotDataClass:
1218 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001219
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001220 self.assertFalse(is_dataclass(0))
1221 self.assertFalse(is_dataclass(int))
1222 self.assertFalse(is_dataclass(NotDataClass))
1223 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001224
1225 @dataclass
1226 class C:
1227 x: int
1228
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001229 @dataclass
1230 class D:
1231 d: C
1232 e: int
1233
1234 c = C(10)
1235 d = D(c, 4)
1236
1237 self.assertTrue(is_dataclass(C))
1238 self.assertTrue(is_dataclass(c))
1239 self.assertFalse(is_dataclass(c.x))
1240 self.assertTrue(is_dataclass(d.d))
1241 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001242
1243 def test_helper_fields_with_class_instance(self):
1244 # Check that we can call fields() on either a class or instance,
1245 # and get back the same thing.
1246 @dataclass
1247 class C:
1248 x: int
1249 y: float
1250
1251 self.assertEqual(fields(C), fields(C(0, 0.0)))
1252
1253 def test_helper_fields_exception(self):
1254 # Check that TypeError is raised if not passed a dataclass or
1255 # instance.
1256 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1257 fields(0)
1258
1259 class C: pass
1260 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1261 fields(C)
1262 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1263 fields(C())
1264
1265 def test_helper_asdict(self):
1266 # Basic tests for asdict(), it should return a new dictionary
1267 @dataclass
1268 class C:
1269 x: int
1270 y: int
1271 c = C(1, 2)
1272
1273 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1274 self.assertEqual(asdict(c), asdict(c))
1275 self.assertIsNot(asdict(c), asdict(c))
1276 c.x = 42
1277 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1278 self.assertIs(type(asdict(c)), dict)
1279
1280 def test_helper_asdict_raises_on_classes(self):
1281 # asdict() should raise on a class object
1282 @dataclass
1283 class C:
1284 x: int
1285 y: int
1286 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1287 asdict(C)
1288 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1289 asdict(int)
1290
1291 def test_helper_asdict_copy_values(self):
1292 @dataclass
1293 class C:
1294 x: int
1295 y: List[int] = field(default_factory=list)
1296 initial = []
1297 c = C(1, initial)
1298 d = asdict(c)
1299 self.assertEqual(d['y'], initial)
1300 self.assertIsNot(d['y'], initial)
1301 c = C(1)
1302 d = asdict(c)
1303 d['y'].append(1)
1304 self.assertEqual(c.y, [])
1305
1306 def test_helper_asdict_nested(self):
1307 @dataclass
1308 class UserId:
1309 token: int
1310 group: int
1311 @dataclass
1312 class User:
1313 name: str
1314 id: UserId
1315 u = User('Joe', UserId(123, 1))
1316 d = asdict(u)
1317 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1318 self.assertIsNot(asdict(u), asdict(u))
1319 u.id.group = 2
1320 self.assertEqual(asdict(u), {'name': 'Joe',
1321 'id': {'token': 123, 'group': 2}})
1322
1323 def test_helper_asdict_builtin_containers(self):
1324 @dataclass
1325 class User:
1326 name: str
1327 id: int
1328 @dataclass
1329 class GroupList:
1330 id: int
1331 users: List[User]
1332 @dataclass
1333 class GroupTuple:
1334 id: int
1335 users: Tuple[User, ...]
1336 @dataclass
1337 class GroupDict:
1338 id: int
1339 users: Dict[str, User]
1340 a = User('Alice', 1)
1341 b = User('Bob', 2)
1342 gl = GroupList(0, [a, b])
1343 gt = GroupTuple(0, (a, b))
1344 gd = GroupDict(0, {'first': a, 'second': b})
1345 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1346 {'name': 'Bob', 'id': 2}]})
1347 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1348 {'name': 'Bob', 'id': 2})})
1349 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1350 'second': {'name': 'Bob', 'id': 2}}})
1351
1352 def test_helper_asdict_builtin_containers(self):
1353 @dataclass
1354 class Child:
1355 d: object
1356
1357 @dataclass
1358 class Parent:
1359 child: Child
1360
1361 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1362 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1363
1364 def test_helper_asdict_factory(self):
1365 @dataclass
1366 class C:
1367 x: int
1368 y: int
1369 c = C(1, 2)
1370 d = asdict(c, dict_factory=OrderedDict)
1371 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1372 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1373 c.x = 42
1374 d = asdict(c, dict_factory=OrderedDict)
1375 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1376 self.assertIs(type(d), OrderedDict)
1377
1378 def test_helper_astuple(self):
1379 # Basic tests for astuple(), it should return a new tuple
1380 @dataclass
1381 class C:
1382 x: int
1383 y: int = 0
1384 c = C(1)
1385
1386 self.assertEqual(astuple(c), (1, 0))
1387 self.assertEqual(astuple(c), astuple(c))
1388 self.assertIsNot(astuple(c), astuple(c))
1389 c.y = 42
1390 self.assertEqual(astuple(c), (1, 42))
1391 self.assertIs(type(astuple(c)), tuple)
1392
1393 def test_helper_astuple_raises_on_classes(self):
1394 # astuple() should raise on a class object
1395 @dataclass
1396 class C:
1397 x: int
1398 y: int
1399 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1400 astuple(C)
1401 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1402 astuple(int)
1403
1404 def test_helper_astuple_copy_values(self):
1405 @dataclass
1406 class C:
1407 x: int
1408 y: List[int] = field(default_factory=list)
1409 initial = []
1410 c = C(1, initial)
1411 t = astuple(c)
1412 self.assertEqual(t[1], initial)
1413 self.assertIsNot(t[1], initial)
1414 c = C(1)
1415 t = astuple(c)
1416 t[1].append(1)
1417 self.assertEqual(c.y, [])
1418
1419 def test_helper_astuple_nested(self):
1420 @dataclass
1421 class UserId:
1422 token: int
1423 group: int
1424 @dataclass
1425 class User:
1426 name: str
1427 id: UserId
1428 u = User('Joe', UserId(123, 1))
1429 t = astuple(u)
1430 self.assertEqual(t, ('Joe', (123, 1)))
1431 self.assertIsNot(astuple(u), astuple(u))
1432 u.id.group = 2
1433 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1434
1435 def test_helper_astuple_builtin_containers(self):
1436 @dataclass
1437 class User:
1438 name: str
1439 id: int
1440 @dataclass
1441 class GroupList:
1442 id: int
1443 users: List[User]
1444 @dataclass
1445 class GroupTuple:
1446 id: int
1447 users: Tuple[User, ...]
1448 @dataclass
1449 class GroupDict:
1450 id: int
1451 users: Dict[str, User]
1452 a = User('Alice', 1)
1453 b = User('Bob', 2)
1454 gl = GroupList(0, [a, b])
1455 gt = GroupTuple(0, (a, b))
1456 gd = GroupDict(0, {'first': a, 'second': b})
1457 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1458 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1459 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1460
1461 def test_helper_astuple_builtin_containers(self):
1462 @dataclass
1463 class Child:
1464 d: object
1465
1466 @dataclass
1467 class Parent:
1468 child: Child
1469
1470 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1471 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1472
1473 def test_helper_astuple_factory(self):
1474 @dataclass
1475 class C:
1476 x: int
1477 y: int
1478 NT = namedtuple('NT', 'x y')
1479 def nt(lst):
1480 return NT(*lst)
1481 c = C(1, 2)
1482 t = astuple(c, tuple_factory=nt)
1483 self.assertEqual(t, NT(1, 2))
1484 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1485 c.x = 42
1486 t = astuple(c, tuple_factory=nt)
1487 self.assertEqual(t, NT(42, 2))
1488 self.assertIs(type(t), NT)
1489
1490 def test_dynamic_class_creation(self):
1491 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1492 }
1493
1494 # Create the class.
1495 cls = type('C', (), cls_dict)
1496
1497 # Make it a dataclass.
1498 cls1 = dataclass(cls)
1499
1500 self.assertEqual(cls1, cls)
1501 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1502
1503 def test_dynamic_class_creation_using_field(self):
1504 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1505 'y': field(default=5),
1506 }
1507
1508 # Create the class.
1509 cls = type('C', (), cls_dict)
1510
1511 # Make it a dataclass.
1512 cls1 = dataclass(cls)
1513
1514 self.assertEqual(cls1, cls)
1515 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1516
1517 def test_init_in_order(self):
1518 @dataclass
1519 class C:
1520 a: int
1521 b: int = field()
1522 c: list = field(default_factory=list, init=False)
1523 d: list = field(default_factory=list)
1524 e: int = field(default=4, init=False)
1525 f: int = 4
1526
1527 calls = []
1528 def setattr(self, name, value):
1529 calls.append((name, value))
1530
1531 C.__setattr__ = setattr
1532 c = C(0, 1)
1533 self.assertEqual(('a', 0), calls[0])
1534 self.assertEqual(('b', 1), calls[1])
1535 self.assertEqual(('c', []), calls[2])
1536 self.assertEqual(('d', []), calls[3])
1537 self.assertNotIn(('e', 4), calls)
1538 self.assertEqual(('f', 4), calls[4])
1539
1540 def test_items_in_dicts(self):
1541 @dataclass
1542 class C:
1543 a: int
1544 b: list = field(default_factory=list, init=False)
1545 c: list = field(default_factory=list)
1546 d: int = field(default=4, init=False)
1547 e: int = 0
1548
1549 c = C(0)
1550 # Class dict
1551 self.assertNotIn('a', C.__dict__)
1552 self.assertNotIn('b', C.__dict__)
1553 self.assertNotIn('c', C.__dict__)
1554 self.assertIn('d', C.__dict__)
1555 self.assertEqual(C.d, 4)
1556 self.assertIn('e', C.__dict__)
1557 self.assertEqual(C.e, 0)
1558 # Instance dict
1559 self.assertIn('a', c.__dict__)
1560 self.assertEqual(c.a, 0)
1561 self.assertIn('b', c.__dict__)
1562 self.assertEqual(c.b, [])
1563 self.assertIn('c', c.__dict__)
1564 self.assertEqual(c.c, [])
1565 self.assertNotIn('d', c.__dict__)
1566 self.assertIn('e', c.__dict__)
1567 self.assertEqual(c.e, 0)
1568
1569 def test_alternate_classmethod_constructor(self):
1570 # Since __post_init__ can't take params, use a classmethod
1571 # alternate constructor. This is mostly an example to show how
1572 # to use this technique.
1573 @dataclass
1574 class C:
1575 x: int
1576 @classmethod
1577 def from_file(cls, filename):
1578 # In a real example, create a new instance
1579 # and populate 'x' from contents of a file.
1580 value_in_file = 20
1581 return cls(value_in_file)
1582
1583 self.assertEqual(C.from_file('filename').x, 20)
1584
1585 def test_field_metadata_default(self):
1586 # Make sure the default metadata is read-only and of
1587 # zero length.
1588 @dataclass
1589 class C:
1590 i: int
1591
1592 self.assertFalse(fields(C)[0].metadata)
1593 self.assertEqual(len(fields(C)[0].metadata), 0)
1594 with self.assertRaisesRegex(TypeError,
1595 'does not support item assignment'):
1596 fields(C)[0].metadata['test'] = 3
1597
1598 def test_field_metadata_mapping(self):
1599 # Make sure only a mapping can be passed as metadata
1600 # zero length.
1601 with self.assertRaises(TypeError):
1602 @dataclass
1603 class C:
1604 i: int = field(metadata=0)
1605
1606 # Make sure an empty dict works
1607 @dataclass
1608 class C:
1609 i: int = field(metadata={})
1610 self.assertFalse(fields(C)[0].metadata)
1611 self.assertEqual(len(fields(C)[0].metadata), 0)
1612 with self.assertRaisesRegex(TypeError,
1613 'does not support item assignment'):
1614 fields(C)[0].metadata['test'] = 3
1615
1616 # Make sure a non-empty dict works.
1617 @dataclass
1618 class C:
1619 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1620 self.assertEqual(len(fields(C)[0].metadata), 3)
1621 self.assertEqual(fields(C)[0].metadata['test'], 10)
1622 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1623 self.assertEqual(fields(C)[0].metadata[3], 'three')
1624 with self.assertRaises(KeyError):
1625 # Non-existent key.
1626 fields(C)[0].metadata['baz']
1627 with self.assertRaisesRegex(TypeError,
1628 'does not support item assignment'):
1629 fields(C)[0].metadata['test'] = 3
1630
1631 def test_field_metadata_custom_mapping(self):
1632 # Try a custom mapping.
1633 class SimpleNameSpace:
1634 def __init__(self, **kw):
1635 self.__dict__.update(kw)
1636
1637 def __getitem__(self, item):
1638 if item == 'xyzzy':
1639 return 'plugh'
1640 return getattr(self, item)
1641
1642 def __len__(self):
1643 return self.__dict__.__len__()
1644
1645 @dataclass
1646 class C:
1647 i: int = field(metadata=SimpleNameSpace(a=10))
1648
1649 self.assertEqual(len(fields(C)[0].metadata), 1)
1650 self.assertEqual(fields(C)[0].metadata['a'], 10)
1651 with self.assertRaises(AttributeError):
1652 fields(C)[0].metadata['b']
1653 # Make sure we're still talking to our custom mapping.
1654 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1655
1656 def test_generic_dataclasses(self):
1657 T = TypeVar('T')
1658
1659 @dataclass
1660 class LabeledBox(Generic[T]):
1661 content: T
1662 label: str = '<unknown>'
1663
1664 box = LabeledBox(42)
1665 self.assertEqual(box.content, 42)
1666 self.assertEqual(box.label, '<unknown>')
1667
1668 # subscripting the resulting class should work, etc.
1669 Alias = List[LabeledBox[int]]
1670
1671 def test_generic_extending(self):
1672 S = TypeVar('S')
1673 T = TypeVar('T')
1674
1675 @dataclass
1676 class Base(Generic[T, S]):
1677 x: T
1678 y: S
1679
1680 @dataclass
1681 class DataDerived(Base[int, T]):
1682 new_field: str
1683 Alias = DataDerived[str]
1684 c = Alias(0, 'test1', 'test2')
1685 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1686
1687 class NonDataDerived(Base[int, T]):
1688 def new_method(self):
1689 return self.y
1690 Alias = NonDataDerived[float]
1691 c = Alias(10, 1.0)
1692 self.assertEqual(c.new_method(), 1.0)
1693
1694 def test_helper_replace(self):
1695 @dataclass(frozen=True)
1696 class C:
1697 x: int
1698 y: int
1699
1700 c = C(1, 2)
1701 c1 = replace(c, x=3)
1702 self.assertEqual(c1.x, 3)
1703 self.assertEqual(c1.y, 2)
1704
1705 def test_helper_replace_frozen(self):
1706 @dataclass(frozen=True)
1707 class C:
1708 x: int
1709 y: int
1710 z: int = field(init=False, default=10)
1711 t: int = field(init=False, default=100)
1712
1713 c = C(1, 2)
1714 c1 = replace(c, x=3)
1715 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1716 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1717
1718
1719 with self.assertRaisesRegex(ValueError, 'init=False'):
1720 replace(c, x=3, z=20, t=50)
1721 with self.assertRaisesRegex(ValueError, 'init=False'):
1722 replace(c, z=20)
1723 replace(c, x=3, z=20, t=50)
1724
1725 # Make sure the result is still frozen.
1726 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1727 c1.x = 3
1728
1729 # Make sure we can't replace an attribute that doesn't exist,
1730 # if we're also replacing one that does exist. Test this
1731 # here, because setting attributes on frozen instances is
1732 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001733 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001734 "keyword argument 'a'"):
1735 c1 = replace(c, x=20, a=5)
1736
1737 def test_helper_replace_invalid_field_name(self):
1738 @dataclass(frozen=True)
1739 class C:
1740 x: int
1741 y: int
1742
1743 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001744 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001745 "keyword argument 'z'"):
1746 c1 = replace(c, z=3)
1747
1748 def test_helper_replace_invalid_object(self):
1749 @dataclass(frozen=True)
1750 class C:
1751 x: int
1752 y: int
1753
1754 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1755 replace(C, x=3)
1756
1757 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1758 replace(0, x=3)
1759
1760 def test_helper_replace_no_init(self):
1761 @dataclass
1762 class C:
1763 x: int
1764 y: int = field(init=False, default=10)
1765
1766 c = C(1)
1767 c.y = 20
1768
1769 # Make sure y gets the default value.
1770 c1 = replace(c, x=5)
1771 self.assertEqual((c1.x, c1.y), (5, 10))
1772
1773 # Trying to replace y is an error.
1774 with self.assertRaisesRegex(ValueError, 'init=False'):
1775 replace(c, x=2, y=30)
1776 with self.assertRaisesRegex(ValueError, 'init=False'):
1777 replace(c, y=30)
1778
1779 def test_dataclassses_pickleable(self):
1780 global P, Q, R
1781 @dataclass
1782 class P:
1783 x: int
1784 y: int = 0
1785 @dataclass
1786 class Q:
1787 x: int
1788 y: int = field(default=0, init=False)
1789 @dataclass
1790 class R:
1791 x: int
1792 y: List[int] = field(default_factory=list)
1793 q = Q(1)
1794 q.y = 2
1795 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1796 for sample in samples:
1797 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1798 with self.subTest(sample=sample, proto=proto):
1799 new_sample = pickle.loads(pickle.dumps(sample, proto))
1800 self.assertEqual(sample.x, new_sample.x)
1801 self.assertEqual(sample.y, new_sample.y)
1802 self.assertIsNot(sample, new_sample)
1803 new_sample.x = 42
1804 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1805 self.assertEqual(new_sample.x, another_new_sample.x)
1806 self.assertEqual(sample.y, another_new_sample.y)
1807
1808 def test_helper_make_dataclass(self):
1809 C = make_dataclass('C',
1810 [('x', int),
1811 ('y', int, field(default=5))],
1812 namespace={'add_one': lambda self: self.x + 1})
1813 c = C(10)
1814 self.assertEqual((c.x, c.y), (10, 5))
1815 self.assertEqual(c.add_one(), 11)
1816
1817
1818 def test_helper_make_dataclass_no_mutate_namespace(self):
1819 # Make sure a provided namespace isn't mutated.
1820 ns = {}
1821 C = make_dataclass('C',
1822 [('x', int),
1823 ('y', int, field(default=5))],
1824 namespace=ns)
1825 self.assertEqual(ns, {})
1826
1827 def test_helper_make_dataclass_base(self):
1828 class Base1:
1829 pass
1830 class Base2:
1831 pass
1832 C = make_dataclass('C',
1833 [('x', int)],
1834 bases=(Base1, Base2))
1835 c = C(2)
1836 self.assertIsInstance(c, C)
1837 self.assertIsInstance(c, Base1)
1838 self.assertIsInstance(c, Base2)
1839
1840 def test_helper_make_dataclass_base_dataclass(self):
1841 @dataclass
1842 class Base1:
1843 x: int
1844 class Base2:
1845 pass
1846 C = make_dataclass('C',
1847 [('y', int)],
1848 bases=(Base1, Base2))
1849 with self.assertRaisesRegex(TypeError, 'required positional'):
1850 c = C(2)
1851 c = C(1, 2)
1852 self.assertIsInstance(c, C)
1853 self.assertIsInstance(c, Base1)
1854 self.assertIsInstance(c, Base2)
1855
1856 self.assertEqual((c.x, c.y), (1, 2))
1857
1858 def test_helper_make_dataclass_init_var(self):
1859 def post_init(self, y):
1860 self.x *= y
1861
1862 C = make_dataclass('C',
1863 [('x', int),
1864 ('y', InitVar[int]),
1865 ],
1866 namespace={'__post_init__': post_init},
1867 )
1868 c = C(2, 3)
1869 self.assertEqual(vars(c), {'x': 6})
1870 self.assertEqual(len(fields(c)), 1)
1871
1872 def test_helper_make_dataclass_class_var(self):
1873 C = make_dataclass('C',
1874 [('x', int),
1875 ('y', ClassVar[int], 10),
1876 ('z', ClassVar[int], field(default=20)),
1877 ])
1878 c = C(1)
1879 self.assertEqual(vars(c), {'x': 1})
1880 self.assertEqual(len(fields(c)), 1)
1881 self.assertEqual(C.y, 10)
1882 self.assertEqual(C.z, 20)
1883
Eric V. Smithd80b4432018-01-06 17:09:58 -05001884 def test_helper_make_dataclass_other_params(self):
1885 C = make_dataclass('C',
1886 [('x', int),
1887 ('y', ClassVar[int], 10),
1888 ('z', ClassVar[int], field(default=20)),
1889 ],
1890 init=False)
1891 # Make sure we have a repr, but no init.
1892 self.assertNotIn('__init__', vars(C))
1893 self.assertIn('__repr__', vars(C))
1894
1895 # Make sure random other params don't work.
1896 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1897 C = make_dataclass('C',
1898 [],
1899 xxinit=False)
1900
Eric V. Smithed7d4292018-01-06 16:14:03 -05001901 def test_helper_make_dataclass_no_types(self):
1902 C = make_dataclass('Point', ['x', 'y', 'z'])
1903 c = C(1, 2, 3)
1904 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1905 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1906 'y': 'typing.Any',
1907 'z': 'typing.Any'})
1908
1909 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1910 c = C(1, 2, 3)
1911 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1912 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1913 'y': int,
1914 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001915
Eric V. Smithea8fc522018-01-27 19:07:40 -05001916
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() to a name that the class itself does not
    # annotate is an error, whatever the base classes declare.

    def test_field_without_annotation(self):
        expected_message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected_message):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        expected_message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected_message):
            # This is still an error: the type annotation in the
            # base class must not be picked up.
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same test, but with the base class not a dataclass.
        class B:
            f: int

        expected_message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected_message):
            # This is still an error: the type annotation in the
            # base class must not be picked up.
            @dataclass
            class C(B):
                f = field()
1950
1951
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001952class TestDocString(unittest.TestCase):
1953 def assertDocStrEqual(self, a, b):
1954 # Because 3.6 and 3.7 differ in how inspect.signature work
1955 # (see bpo #32108), for the time being just compare them with
1956 # whitespace stripped.
1957 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1958
1959 def test_existing_docstring_not_overridden(self):
1960 @dataclass
1961 class C:
1962 """Lorem ipsum"""
1963 x: int
1964
1965 self.assertEqual(C.__doc__, "Lorem ipsum")
1966
1967 def test_docstring_no_fields(self):
1968 @dataclass
1969 class C:
1970 pass
1971
1972 self.assertDocStrEqual(C.__doc__, "C()")
1973
1974 def test_docstring_one_field(self):
1975 @dataclass
1976 class C:
1977 x: int
1978
1979 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1980
1981 def test_docstring_two_fields(self):
1982 @dataclass
1983 class C:
1984 x: int
1985 y: int
1986
1987 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1988
1989 def test_docstring_three_fields(self):
1990 @dataclass
1991 class C:
1992 x: int
1993 y: int
1994 z: str
1995
1996 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1997
1998 def test_docstring_one_field_with_default(self):
1999 @dataclass
2000 class C:
2001 x: int = 3
2002
2003 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2004
2005 def test_docstring_one_field_with_default_none(self):
2006 @dataclass
2007 class C:
2008 x: Union[int, type(None)] = None
2009
2010 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2011
2012 def test_docstring_list_field(self):
2013 @dataclass
2014 class C:
2015 x: List[int]
2016
2017 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2018
2019 def test_docstring_list_field_with_default_factory(self):
2020 @dataclass
2021 class C:
2022 x: List[int] = field(default_factory=list)
2023
2024 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2025
2026 def test_docstring_deque_field(self):
2027 @dataclass
2028 class C:
2029 x: deque
2030
2031 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2032
2033 def test_docstring_deque_field_with_default_factory(self):
2034 @dataclass
2035 class C:
2036 x: deque = field(default_factory=deque)
2037
2038 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2039
2040
class TestInit(unittest.TestCase):
    # How dataclass() treats __init__ methods that already exist on
    # the class or on one of its bases.

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        # The issue is that we can't override __init__ in our class,
        # but it should be okay to add __init__ to us if our base has
        # an __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call B.__init__.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The "@" matters: calling dataclass(init=False) as a bare
        # statement only builds a decorator and throws it away, so
        # the class below it would not be a dataclass at all.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # With init=False a hand-written __init__ is left in place.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2105
2106
class TestRepr(unittest.TestCase):
    # How dataclass() generates __repr__, and how it treats a
    # __repr__ the class already defines.

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # A subclass that redeclares an inherited field keeps that
        # field's original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses are shown with their qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
        # The default object repr starts with the module name, so use
        # __name__ rather than a hard-coded one: the test then also
        # passes when this module is imported under another name.
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2176
2177
class TestFrozen(unittest.TestCase):
    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__,
        # so a class that defines either one is rejected.
        setattr_message = 'Cannot overwrite attribute __setattr__'
        with self.assertRaisesRegex(TypeError, setattr_message):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        delattr_message = 'Cannot overwrite attribute __delattr__'
        with self.assertRaisesRegex(TypeError, delattr_message):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False a user-defined __setattr__ is kept and is
        # what the generated __init__ ends up calling.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        doubled = C(10)
        self.assertEqual(doubled.x, 20)
2203
2204
class TestEq(unittest.TestCase):
    # How dataclass() generates __eq__, and how it treats an __eq__
    # the class already defines.

    def test_no_eq(self):
        # Test a class with no __eq__ and eq=False: instances only
        # compare equal to themselves.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # Test a class with an __eq__ and eq=False: the user's
        # method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # If the class has __eq__, use it no matter the value of
        # eq=.

        # eq left at its default.
        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        inst = C(1)
        self.assertEqual(inst, 3)
        self.assertNotEqual(inst, 1)

        # eq=True.
        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        inst = C(1)
        self.assertEqual(inst, 4)
        self.assertNotEqual(inst, 1)

        # eq=False.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        inst = C(1)
        self.assertEqual(inst, 5)
        self.assertNotEqual(inst, 1)
2250
2251
class TestOrdering(unittest.TestCase):
    # How dataclass() treats the ordering methods __lt__, __le__,
    # __gt__ and __ge__.

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        zero = C(0)
        self.assertLess(zero, -1)
        self.assertLessEqual(zero, -1)
        self.assertGreater(zero, 1)
        self.assertGreaterEqual(zero, 1)

    def test_no_order(self):
        # Test that no ordering functions are added with order=False.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        for method_name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(method_name, C.__dict__)

        # A class that supplies its own __lt__.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # Make sure other methods aren't added.
        for method_name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(method_name, C.__dict__)

    def test_overwriting_order(self):
        # With order=True, a class that already defines any one of
        # the four ordering methods is rejected, and the message
        # points at functools.total_ordering.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2327
class TestHash(unittest.TestCase):
    # Checks when @dataclass generates __hash__, sets it to None, or
    # leaves it alone, for the combinations of unsafe_hash=, eq= and
    # frozen= exercised below.

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something with the same truth value that is not a
            # bool (None is passed through unchanged).
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                # For every outcome except 'exception' the class can be
                # created up front; the 'exception' branch creates it
                # inside assertRaisesRegex instead.
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq,
                                   frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq,
                                   frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    # Use a unittest assertion rather than a bare assert
                    # so the check survives running under -O.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(
                            TypeError,
                            'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq,
                                   frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # A typo in the table below must fail loudly, even
                    # under -O.
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        table = [
            # unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash
            (False, False, False, '',     ''),
            (False, False, True,  '',     ''),
            (False, True,  False, 'none', ''),
            (False, True,  True,  'fn',   ''),
            (True,  False, False, 'fn',   'exception'),
            (True,  False, True,  'fn',   'exception'),
            (True,  True,  False, 'fn',   'exception'),
            (True,  True,  True,  'fn',   'exception'),
        ]
        for case, row in enumerate(table, 1):
            unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash = row
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen),
                 False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen),
                 True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq, base, expected in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base,
                              expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # I'm not sure what test to use here.  object's
                    # hash isn't based on id(), so calling hash()
                    # won't tell us much.  So, just check the function
                    # used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # A typo in the table above must fail loudly, even
                    # under -O.
                    self.fail(f'unknown value for expected={expected!r}')
2546
Eric V. Smithea8fc522018-01-27 19:07:40 -05002547
class TestFrozen(unittest.TestCase):
    # Checks frozen=True dataclasses: attribute assignment is rejected,
    # frozen and non-frozen dataclasses cannot inherit from one another,
    # and plain (non-dataclass) classes may sit anywhere in the hierarchy.

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    # Test both ways: with an intermediate normal (non-dataclass)
    # class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit non-frozen dataclass from a '
                        'frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(
                        TypeError,
                        'cannot inherit frozen dataclass from a '
                        'non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        # A plain subclass of a frozen dataclass may gain new attributes.
        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # An earlier class in this module is also named TestFrozen and
        # carries this same test, but this definition rebinds the name,
        # so the check is repeated here where it is actually collected.

        # frozen uses __setattr__ and __delattr__
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002658
2659
class TestSlots(unittest.TestCase):
    # Checks that dataclasses cooperate with classes that declare
    # __slots__.

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed
        # to also have a default value (of type types.MemberDescriptorType).
        # The pattern is a raw string so that \( and \) reach the regex
        # engine instead of being invalid string escapes.
        with self.assertRaisesRegex(
                TypeError,
                r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError,
                                    "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.
        d.z = 10
2700
2701
# Run the whole module's tests when executed as a script.
if __name__ == '__main__':
    unittest.main()