blob: db03ec1925f6742a3e28dfab601b61dc31a0080d [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
Eric V. Smithe7ba0132018-01-06 12:41:53 -05003 make_dataclass, replace, InitVar, Field, MISSING, is_dataclass,
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050012from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013
# A dedicated exception type, so tests can raise and catch something
# that the dataclasses machinery itself never raises.
class CustomError(Exception):
    """Marker exception raised from __post_init__ hooks in these tests."""
16
17class TestCase(unittest.TestCase):
18 def test_no_fields(self):
19 @dataclass
20 class C:
21 pass
22
23 o = C()
24 self.assertEqual(len(fields(C)), 0)
25
26 def test_one_field_no_default(self):
27 @dataclass
28 class C:
29 x: int
30
31 o = C(42)
32 self.assertEqual(o.x, 42)
33
34 def test_named_init_params(self):
35 @dataclass
36 class C:
37 x: int
38
39 o = C(x=32)
40 self.assertEqual(o.x, 32)
41
42 def test_two_fields_one_default(self):
43 @dataclass
44 class C:
45 x: int
46 y: int = 0
47
48 o = C(3)
49 self.assertEqual((o.x, o.y), (3, 0))
50
51 # Non-defaults following defaults.
52 with self.assertRaisesRegex(TypeError,
53 "non-default argument 'y' follows "
54 "default argument"):
55 @dataclass
56 class C:
57 x: int = 0
58 y: int
59
60 # A derived class adds a non-default field after a default one.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class B:
66 x: int = 0
67
68 @dataclass
69 class C(B):
70 y: int
71
72 # Override a base class field and add a default to
73 # a field which didn't use to have a default.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int
80 y: int
81
82 @dataclass
83 class C(B):
84 x: int = 0
85
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080086 def test_overwrite_hash(self):
87 # Test that declaring this class isn't an error. It should
88 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050089 @dataclass(frozen=True)
90 class C:
91 x: int
92 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080093 return 301
94 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -050095
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080096 # Test that declaring this class isn't an error. It should
97 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800101 def __eq__(self, other):
102 return False
103 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # But this one should generate an exception, because with
106 # unsafe_hash=True, it's an error to have a __hash__ defined.
107 with self.assertRaisesRegex(TypeError,
108 'Cannot overwrite attribute __hash__'):
109 @dataclass(unsafe_hash=True)
110 class C:
111 def __hash__(self):
112 pass
113
114 # Creating this class should not generate an exception,
115 # because even though __hash__ exists before @dataclass is
116 # called, (due to __eq__ being defined), since it's None
117 # that's okay.
118 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500119 class C:
120 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800121 def __eq__(self):
122 pass
123 # The generated hash function works as we'd expect.
124 self.assertEqual(hash(C(10)), hash((10,)))
125
126 # Creating this class should generate an exception, because
127 # __hash__ exists and is not None, which it would be if it had
128 # been auto-generated do due __eq__ being defined.
129 with self.assertRaisesRegex(TypeError,
130 'Cannot overwrite attribute __hash__'):
131 @dataclass(unsafe_hash=True)
132 class C:
133 x: int
134 def __eq__(self):
135 pass
136 def __hash__(self):
137 pass
138
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500139
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500140 def test_overwrite_fields_in_derived_class(self):
141 # Note that x from C1 replaces x in Base, but the order remains
142 # the same as defined in Base.
143 @dataclass
144 class Base:
145 x: Any = 15.0
146 y: int = 0
147
148 @dataclass
149 class C1(Base):
150 z: int = 10
151 x: int = 15
152
153 o = Base()
154 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
155
156 o = C1()
157 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
158
159 o = C1(x=5)
160 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
161
162 def test_field_named_self(self):
163 @dataclass
164 class C:
165 self: str
166 c=C('foo')
167 self.assertEqual(c.self, 'foo')
168
169 # Make sure the first parameter is not named 'self'.
170 sig = inspect.signature(C.__init__)
171 first = next(iter(sig.parameters))
172 self.assertNotEqual('self', first)
173
174 # But we do use 'self' if no field named self.
175 @dataclass
176 class C:
177 selfx: str
178
179 # Make sure the first parameter is named 'self'.
180 sig = inspect.signature(C.__init__)
181 first = next(iter(sig.parameters))
182 self.assertEqual('self', first)
183
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500184 def test_0_field_compare(self):
185 # Ensure that order=False is the default.
186 @dataclass
187 class C0:
188 pass
189
190 @dataclass(order=False)
191 class C1:
192 pass
193
194 for cls in [C0, C1]:
195 with self.subTest(cls=cls):
196 self.assertEqual(cls(), cls())
197 for idx, fn in enumerate([lambda a, b: a < b,
198 lambda a, b: a <= b,
199 lambda a, b: a > b,
200 lambda a, b: a >= b]):
201 with self.subTest(idx=idx):
202 with self.assertRaisesRegex(TypeError,
203 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
204 fn(cls(), cls())
205
206 @dataclass(order=True)
207 class C:
208 pass
209 self.assertLessEqual(C(), C())
210 self.assertGreaterEqual(C(), C())
211
212 def test_1_field_compare(self):
213 # Ensure that order=False is the default.
214 @dataclass
215 class C0:
216 x: int
217
218 @dataclass(order=False)
219 class C1:
220 x: int
221
222 for cls in [C0, C1]:
223 with self.subTest(cls=cls):
224 self.assertEqual(cls(1), cls(1))
225 self.assertNotEqual(cls(0), cls(1))
226 for idx, fn in enumerate([lambda a, b: a < b,
227 lambda a, b: a <= b,
228 lambda a, b: a > b,
229 lambda a, b: a >= b]):
230 with self.subTest(idx=idx):
231 with self.assertRaisesRegex(TypeError,
232 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
233 fn(cls(0), cls(0))
234
235 @dataclass(order=True)
236 class C:
237 x: int
238 self.assertLess(C(0), C(1))
239 self.assertLessEqual(C(0), C(1))
240 self.assertLessEqual(C(1), C(1))
241 self.assertGreater(C(1), C(0))
242 self.assertGreaterEqual(C(1), C(0))
243 self.assertGreaterEqual(C(1), C(1))
244
245 def test_simple_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 x: int
250 y: int
251
252 @dataclass(order=False)
253 class C1:
254 x: int
255 y: int
256
257 for cls in [C0, C1]:
258 with self.subTest(cls=cls):
259 self.assertEqual(cls(0, 0), cls(0, 0))
260 self.assertEqual(cls(1, 2), cls(1, 2))
261 self.assertNotEqual(cls(1, 0), cls(0, 0))
262 self.assertNotEqual(cls(1, 0), cls(1, 1))
263 for idx, fn in enumerate([lambda a, b: a < b,
264 lambda a, b: a <= b,
265 lambda a, b: a > b,
266 lambda a, b: a >= b]):
267 with self.subTest(idx=idx):
268 with self.assertRaisesRegex(TypeError,
269 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
270 fn(cls(0, 0), cls(0, 0))
271
272 @dataclass(order=True)
273 class C:
274 x: int
275 y: int
276
277 for idx, fn in enumerate([lambda a, b: a == b,
278 lambda a, b: a <= b,
279 lambda a, b: a >= b]):
280 with self.subTest(idx=idx):
281 self.assertTrue(fn(C(0, 0), C(0, 0)))
282
283 for idx, fn in enumerate([lambda a, b: a < b,
284 lambda a, b: a <= b,
285 lambda a, b: a != b]):
286 with self.subTest(idx=idx):
287 self.assertTrue(fn(C(0, 0), C(0, 1)))
288 self.assertTrue(fn(C(0, 1), C(1, 0)))
289 self.assertTrue(fn(C(1, 0), C(1, 1)))
290
291 for idx, fn in enumerate([lambda a, b: a > b,
292 lambda a, b: a >= b,
293 lambda a, b: a != b]):
294 with self.subTest(idx=idx):
295 self.assertTrue(fn(C(0, 1), C(0, 0)))
296 self.assertTrue(fn(C(1, 0), C(0, 1)))
297 self.assertTrue(fn(C(1, 1), C(1, 0)))
298
299 def test_compare_subclasses(self):
300 # Comparisons fail for subclasses, even if no fields
301 # are added.
302 @dataclass
303 class B:
304 i: int
305
306 @dataclass
307 class C(B):
308 pass
309
310 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
311 (lambda a, b: a != b, True)]):
312 with self.subTest(idx=idx):
313 self.assertEqual(fn(B(0), C(0)), expected)
314
315 for idx, fn in enumerate([lambda a, b: a < b,
316 lambda a, b: a <= b,
317 lambda a, b: a > b,
318 lambda a, b: a >= b]):
319 with self.subTest(idx=idx):
320 with self.assertRaisesRegex(TypeError,
321 "not supported between instances of 'B' and 'C'"):
322 fn(B(0), C(0))
323
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500324 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500325 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500326 for (eq, order, result ) in [
327 (False, False, 'neither'),
328 (False, True, 'exception'),
329 (True, False, 'eq_only'),
330 (True, True, 'both'),
331 ]:
332 with self.subTest(eq=eq, order=order):
333 if result == 'exception':
334 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
335 @dataclass(eq=eq, order=order)
336 class C:
337 pass
338 else:
339 @dataclass(eq=eq, order=order)
340 class C:
341 pass
342
343 if result == 'neither':
344 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500345 self.assertNotIn('__lt__', C.__dict__)
346 self.assertNotIn('__le__', C.__dict__)
347 self.assertNotIn('__gt__', C.__dict__)
348 self.assertNotIn('__ge__', C.__dict__)
349 elif result == 'both':
350 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500351 self.assertIn('__lt__', C.__dict__)
352 self.assertIn('__le__', C.__dict__)
353 self.assertIn('__gt__', C.__dict__)
354 self.assertIn('__ge__', C.__dict__)
355 elif result == 'eq_only':
356 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500357 self.assertNotIn('__lt__', C.__dict__)
358 self.assertNotIn('__le__', C.__dict__)
359 self.assertNotIn('__gt__', C.__dict__)
360 self.assertNotIn('__ge__', C.__dict__)
361 else:
362 assert False, f'unknown result {result!r}'
363
364 def test_field_no_default(self):
365 @dataclass
366 class C:
367 x: int = field()
368
369 self.assertEqual(C(5).x, 5)
370
371 with self.assertRaisesRegex(TypeError,
372 r"__init__\(\) missing 1 required "
373 "positional argument: 'x'"):
374 C()
375
376 def test_field_default(self):
377 default = object()
378 @dataclass
379 class C:
380 x: object = field(default=default)
381
382 self.assertIs(C.x, default)
383 c = C(10)
384 self.assertEqual(c.x, 10)
385
386 # If we delete the instance attribute, we should then see the
387 # class attribute.
388 del c.x
389 self.assertIs(c.x, default)
390
391 self.assertIs(C().x, default)
392
393 def test_not_in_repr(self):
394 @dataclass
395 class C:
396 x: int = field(repr=False)
397 with self.assertRaises(TypeError):
398 C()
399 c = C(10)
400 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
401
402 @dataclass
403 class C:
404 x: int = field(repr=False)
405 y: int
406 c = C(10, 20)
407 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
408
409 def test_not_in_compare(self):
410 @dataclass
411 class C:
412 x: int = 0
413 y: int = field(compare=False, default=4)
414
415 self.assertEqual(C(), C(0, 20))
416 self.assertEqual(C(1, 10), C(1, 20))
417 self.assertNotEqual(C(3), C(4, 10))
418 self.assertNotEqual(C(3, 10), C(4, 10))
419
420 def test_hash_field_rules(self):
421 # Test all 6 cases of:
422 # hash=True/False/None
423 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800424 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500425 (True, False, 'field' ),
426 (True, True, 'field' ),
427 (False, False, 'absent'),
428 (False, True, 'absent'),
429 (None, False, 'absent'),
430 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800431 ]:
432 with self.subTest(hash=hash_, compare=compare):
433 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800435 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500436
437 if result == 'field':
438 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800439 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500440 elif result == 'absent':
441 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800442 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500443 else:
444 assert False, f'unknown result {result!r}'
445
446 def test_init_false_no_default(self):
447 # If init=False and no default value, then the field won't be
448 # present in the instance.
449 @dataclass
450 class C:
451 x: int = field(init=False)
452
453 self.assertNotIn('x', C().__dict__)
454
455 @dataclass
456 class C:
457 x: int
458 y: int = 0
459 z: int = field(init=False)
460 t: int = 10
461
462 self.assertNotIn('z', C(0).__dict__)
463 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
464
465 def test_class_marker(self):
466 @dataclass
467 class C:
468 x: int
469 y: str = field(init=False, default=None)
470 z: str = field(repr=False)
471
472 the_fields = fields(C)
473 # the_fields is a tuple of 3 items, each value
474 # is in __annotations__.
475 self.assertIsInstance(the_fields, tuple)
476 for f in the_fields:
477 self.assertIs(type(f), Field)
478 self.assertIn(f.name, C.__annotations__)
479
480 self.assertEqual(len(the_fields), 3)
481
482 self.assertEqual(the_fields[0].name, 'x')
483 self.assertEqual(the_fields[0].type, int)
484 self.assertFalse(hasattr(C, 'x'))
485 self.assertTrue (the_fields[0].init)
486 self.assertTrue (the_fields[0].repr)
487 self.assertEqual(the_fields[1].name, 'y')
488 self.assertEqual(the_fields[1].type, str)
489 self.assertIsNone(getattr(C, 'y'))
490 self.assertFalse(the_fields[1].init)
491 self.assertTrue (the_fields[1].repr)
492 self.assertEqual(the_fields[2].name, 'z')
493 self.assertEqual(the_fields[2].type, str)
494 self.assertFalse(hasattr(C, 'z'))
495 self.assertTrue (the_fields[2].init)
496 self.assertFalse(the_fields[2].repr)
497
498 def test_field_order(self):
499 @dataclass
500 class B:
501 a: str = 'B:a'
502 b: str = 'B:b'
503 c: str = 'B:c'
504
505 @dataclass
506 class C(B):
507 b: str = 'C:b'
508
509 self.assertEqual([(f.name, f.default) for f in fields(C)],
510 [('a', 'B:a'),
511 ('b', 'C:b'),
512 ('c', 'B:c')])
513
514 @dataclass
515 class D(B):
516 c: str = 'D:c'
517
518 self.assertEqual([(f.name, f.default) for f in fields(D)],
519 [('a', 'B:a'),
520 ('b', 'B:b'),
521 ('c', 'D:c')])
522
523 @dataclass
524 class E(D):
525 a: str = 'E:a'
526 d: str = 'E:d'
527
528 self.assertEqual([(f.name, f.default) for f in fields(E)],
529 [('a', 'E:a'),
530 ('b', 'B:b'),
531 ('c', 'D:c'),
532 ('d', 'E:d')])
533
534 def test_class_attrs(self):
535 # We only have a class attribute if a default value is
536 # specified, either directly or via a field with a default.
537 default = object()
538 @dataclass
539 class C:
540 x: int
541 y: int = field(repr=False)
542 z: object = default
543 t: int = field(default=100)
544
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertFalse(hasattr(C, 'y'))
547 self.assertIs (C.z, default)
548 self.assertEqual(C.t, 100)
549
550 def test_disallowed_mutable_defaults(self):
551 # For the known types, don't allow mutable default values.
552 for typ, empty, non_empty in [(list, [], [1]),
553 (dict, {}, {0:1}),
554 (set, set(), set([1])),
555 ]:
556 with self.subTest(typ=typ):
557 # Can't use a zero-length value.
558 with self.assertRaisesRegex(ValueError,
559 f'mutable default {typ} for field '
560 'x is not allowed'):
561 @dataclass
562 class Point:
563 x: typ = empty
564
565
566 # Nor a non-zero-length value
567 with self.assertRaisesRegex(ValueError,
568 f'mutable default {typ} for field '
569 'y is not allowed'):
570 @dataclass
571 class Point:
572 y: typ = non_empty
573
574 # Check subtypes also fail.
575 class Subclass(typ): pass
576
577 with self.assertRaisesRegex(ValueError,
578 f"mutable default .*Subclass'>"
579 ' for field z is not allowed'
580 ):
581 @dataclass
582 class Point:
583 z: typ = Subclass()
584
585 # Because this is a ClassVar, it can be mutable.
586 @dataclass
587 class C:
588 z: ClassVar[typ] = typ()
589
590 # Because this is a ClassVar, it can be mutable.
591 @dataclass
592 class C:
593 x: ClassVar[typ] = Subclass()
594
595
596 def test_deliberately_mutable_defaults(self):
597 # If a mutable default isn't in the known list of
598 # (list, dict, set), then it's okay.
599 class Mutable:
600 def __init__(self):
601 self.l = []
602
603 @dataclass
604 class C:
605 x: Mutable
606
607 # These 2 instances will share this value of x.
608 lst = Mutable()
609 o1 = C(lst)
610 o2 = C(lst)
611 self.assertEqual(o1, o2)
612 o1.x.l.extend([1, 2])
613 self.assertEqual(o1, o2)
614 self.assertEqual(o1.x.l, [1, 2])
615 self.assertIs(o1.x, o2.x)
616
617 def test_no_options(self):
618 # call with dataclass()
619 @dataclass()
620 class C:
621 x: int
622
623 self.assertEqual(C(42).x, 42)
624
625 def test_not_tuple(self):
626 # Make sure we can't be compared to a tuple.
627 @dataclass
628 class Point:
629 x: int
630 y: int
631 self.assertNotEqual(Point(1, 2), (1, 2))
632
633 # And that we can't compare to another unrelated dataclass
634 @dataclass
635 class C:
636 x: int
637 y: int
638 self.assertNotEqual(Point(1, 3), C(1, 3))
639
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500640 def test_not_tuple(self):
641 # Test that some of the problems with namedtuple don't happen
642 # here.
643 @dataclass
644 class Point3D:
645 x: int
646 y: int
647 z: int
648
649 @dataclass
650 class Date:
651 year: int
652 month: int
653 day: int
654
655 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
656 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
657
658 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200659 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500660 x, y, z = Point3D(4, 5, 6)
661
Eric V. Smith7c99e932018-01-28 19:18:55 -0500662 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500663 # equal.
664 @dataclass
665 class Point3Dv1:
666 x: int = 0
667 y: int = 0
668 z: int = 0
669 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
670
671 def test_function_annotations(self):
672 # Some dummy class and instance to use as a default.
673 class F:
674 pass
675 f = F()
676
677 def validate_class(cls):
678 # First, check __annotations__, even though they're not
679 # function annotations.
680 self.assertEqual(cls.__annotations__['i'], int)
681 self.assertEqual(cls.__annotations__['j'], str)
682 self.assertEqual(cls.__annotations__['k'], F)
683 self.assertEqual(cls.__annotations__['l'], float)
684 self.assertEqual(cls.__annotations__['z'], complex)
685
686 # Verify __init__.
687
688 signature = inspect.signature(cls.__init__)
689 # Check the return type, should be None
690 self.assertIs(signature.return_annotation, None)
691
692 # Check each parameter.
693 params = iter(signature.parameters.values())
694 param = next(params)
695 # This is testing an internal name, and probably shouldn't be tested.
696 self.assertEqual(param.name, 'self')
697 param = next(params)
698 self.assertEqual(param.name, 'i')
699 self.assertIs (param.annotation, int)
700 self.assertEqual(param.default, inspect.Parameter.empty)
701 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
702 param = next(params)
703 self.assertEqual(param.name, 'j')
704 self.assertIs (param.annotation, str)
705 self.assertEqual(param.default, inspect.Parameter.empty)
706 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
707 param = next(params)
708 self.assertEqual(param.name, 'k')
709 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500710 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500711 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
712 param = next(params)
713 self.assertEqual(param.name, 'l')
714 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500715 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500716 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
717 self.assertRaises(StopIteration, next, params)
718
719
720 @dataclass
721 class C:
722 i: int
723 j: str
724 k: F = f
725 l: float=field(default=None)
726 z: complex=field(default=3+4j, init=False)
727
728 validate_class(C)
729
730 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800731 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500732 class C:
733 i: int
734 j: str
735 k: F = f
736 l: float=field(default=None)
737 z: complex=field(default=3+4j, init=False)
738
739 validate_class(C)
740
Eric V. Smith03220fd2017-12-29 13:59:58 -0500741 def test_missing_default(self):
742 # Test that MISSING works the same as a default not being
743 # specified.
744 @dataclass
745 class C:
746 x: int=field(default=MISSING)
747 with self.assertRaisesRegex(TypeError,
748 r'__init__\(\) missing 1 required '
749 'positional argument'):
750 C()
751 self.assertNotIn('x', C.__dict__)
752
753 @dataclass
754 class D:
755 x: int
756 with self.assertRaisesRegex(TypeError,
757 r'__init__\(\) missing 1 required '
758 'positional argument'):
759 D()
760 self.assertNotIn('x', D.__dict__)
761
762 def test_missing_default_factory(self):
763 # Test that MISSING works the same as a default factory not
764 # being specified (which is really the same as a default not
765 # being specified, too).
766 @dataclass
767 class C:
768 x: int=field(default_factory=MISSING)
769 with self.assertRaisesRegex(TypeError,
770 r'__init__\(\) missing 1 required '
771 'positional argument'):
772 C()
773 self.assertNotIn('x', C.__dict__)
774
775 @dataclass
776 class D:
777 x: int=field(default=MISSING, default_factory=MISSING)
778 with self.assertRaisesRegex(TypeError,
779 r'__init__\(\) missing 1 required '
780 'positional argument'):
781 D()
782 self.assertNotIn('x', D.__dict__)
783
784 def test_missing_repr(self):
785 self.assertIn('MISSING_TYPE object', repr(MISSING))
786
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500787 def test_dont_include_other_annotations(self):
788 @dataclass
789 class C:
790 i: int
791 def foo(self) -> int:
792 return 4
793 @property
794 def bar(self) -> int:
795 return 5
796 self.assertEqual(list(C.__annotations__), ['i'])
797 self.assertEqual(C(10).foo(), 4)
798 self.assertEqual(C(10).bar, 5)
799
800 def test_post_init(self):
801 # Just make sure it gets called
802 @dataclass
803 class C:
804 def __post_init__(self):
805 raise CustomError()
806 with self.assertRaises(CustomError):
807 C()
808
809 @dataclass
810 class C:
811 i: int = 10
812 def __post_init__(self):
813 if self.i == 10:
814 raise CustomError()
815 with self.assertRaises(CustomError):
816 C()
817 # post-init gets called, but doesn't raise. This is just
818 # checking that self is used correctly.
819 C(5)
820
821 # If there's not an __init__, then post-init won't get called.
822 @dataclass(init=False)
823 class C:
824 def __post_init__(self):
825 raise CustomError()
826 # Creating the class won't raise
827 C()
828
829 @dataclass
830 class C:
831 x: int = 0
832 def __post_init__(self):
833 self.x *= 2
834 self.assertEqual(C().x, 0)
835 self.assertEqual(C(2).x, 4)
836
Mike53f7a7c2017-12-14 14:04:53 +0300837 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500838 # attributes.
839 @dataclass(frozen=True)
840 class C:
841 x: int = 0
842 def __post_init__(self):
843 self.x *= 2
844 with self.assertRaises(FrozenInstanceError):
845 C()
846
847 def test_post_init_super(self):
848 # Make sure super() post-init isn't called by default.
849 class B:
850 def __post_init__(self):
851 raise CustomError()
852
853 @dataclass
854 class C(B):
855 def __post_init__(self):
856 self.x = 5
857
858 self.assertEqual(C().x, 5)
859
860 # Now call super(), and it will raise
861 @dataclass
862 class C(B):
863 def __post_init__(self):
864 super().__post_init__()
865
866 with self.assertRaises(CustomError):
867 C()
868
869 # Make sure post-init is called, even if not defined in our
870 # class.
871 @dataclass
872 class C(B):
873 pass
874
875 with self.assertRaises(CustomError):
876 C()
877
878 def test_post_init_staticmethod(self):
879 flag = False
880 @dataclass
881 class C:
882 x: int
883 y: int
884 @staticmethod
885 def __post_init__():
886 nonlocal flag
887 flag = True
888
889 self.assertFalse(flag)
890 c = C(3, 4)
891 self.assertEqual((c.x, c.y), (3, 4))
892 self.assertTrue(flag)
893
894 def test_post_init_classmethod(self):
895 @dataclass
896 class C:
897 flag = False
898 x: int
899 y: int
900 @classmethod
901 def __post_init__(cls):
902 cls.flag = True
903
904 self.assertFalse(C.flag)
905 c = C(3, 4)
906 self.assertEqual((c.x, c.y), (3, 4))
907 self.assertTrue(C.flag)
908
909 def test_class_var(self):
910 # Make sure ClassVars are ignored in __init__, __repr__, etc.
911 @dataclass
912 class C:
913 x: int
914 y: int = 10
915 z: ClassVar[int] = 1000
916 w: ClassVar[int] = 2000
917 t: ClassVar[int] = 3000
918
919 c = C(5)
920 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
921 self.assertEqual(len(fields(C)), 2) # We have 2 fields
922 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
923 self.assertEqual(c.z, 1000)
924 self.assertEqual(c.w, 2000)
925 self.assertEqual(c.t, 3000)
926 C.z += 1
927 self.assertEqual(c.z, 1001)
928 c = C(20)
929 self.assertEqual((c.x, c.y), (20, 10))
930 self.assertEqual(c.z, 1001)
931 self.assertEqual(c.w, 2000)
932 self.assertEqual(c.t, 3000)
933
934 def test_class_var_no_default(self):
935 # If a ClassVar has no default value, it should not be set on the class.
936 @dataclass
937 class C:
938 x: ClassVar[int]
939
940 self.assertNotIn('x', C.__dict__)
941
942 def test_class_var_default_factory(self):
943 # It makes no sense for a ClassVar to have a default factory. When
944 # would it be called? Call it yourself, since it's class-wide.
945 with self.assertRaisesRegex(TypeError,
946 'cannot have a default factory'):
947 @dataclass
948 class C:
949 x: ClassVar[int] = field(default_factory=int)
950
951 self.assertNotIn('x', C.__dict__)
952
953 def test_class_var_with_default(self):
954 # If a ClassVar has a default value, it should be set on the class.
955 @dataclass
956 class C:
957 x: ClassVar[int] = 10
958 self.assertEqual(C.x, 10)
959
960 @dataclass
961 class C:
962 x: ClassVar[int] = field(default=10)
963 self.assertEqual(C.x, 10)
964
965 def test_class_var_frozen(self):
966 # Make sure ClassVars work even if we're frozen.
967 @dataclass(frozen=True)
968 class C:
969 x: int
970 y: int = 10
971 z: ClassVar[int] = 1000
972 w: ClassVar[int] = 2000
973 t: ClassVar[int] = 3000
974
975 c = C(5)
976 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
977 self.assertEqual(len(fields(C)), 2) # We have 2 fields
978 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
979 self.assertEqual(c.z, 1000)
980 self.assertEqual(c.w, 2000)
981 self.assertEqual(c.t, 3000)
982 # We can still modify the ClassVar, it's only instances that are
983 # frozen.
984 C.z += 1
985 self.assertEqual(c.z, 1001)
986 c = C(20)
987 self.assertEqual((c.x, c.y), (20, 10))
988 self.assertEqual(c.z, 1001)
989 self.assertEqual(c.w, 2000)
990 self.assertEqual(c.t, 3000)
991
992 def test_init_var_no_default(self):
993 # If an InitVar has no default value, it should not be set on the class.
994 @dataclass
995 class C:
996 x: InitVar[int]
997
998 self.assertNotIn('x', C.__dict__)
999
1000 def test_init_var_default_factory(self):
1001 # It makes no sense for an InitVar to have a default factory. When
1002 # would it be called? Call it yourself, since it's class-wide.
1003 with self.assertRaisesRegex(TypeError,
1004 'cannot have a default factory'):
1005 @dataclass
1006 class C:
1007 x: InitVar[int] = field(default_factory=int)
1008
1009 self.assertNotIn('x', C.__dict__)
1010
1011 def test_init_var_with_default(self):
1012 # If an InitVar has a default value, it should be set on the class.
1013 @dataclass
1014 class C:
1015 x: InitVar[int] = 10
1016 self.assertEqual(C.x, 10)
1017
1018 @dataclass
1019 class C:
1020 x: InitVar[int] = field(default=10)
1021 self.assertEqual(C.x, 10)
1022
1023 def test_init_var(self):
1024 @dataclass
1025 class C:
1026 x: int = None
1027 init_param: InitVar[int] = None
1028
1029 def __post_init__(self, init_param):
1030 if self.x is None:
1031 self.x = init_param*2
1032
1033 c = C(init_param=10)
1034 self.assertEqual(c.x, 20)
1035
1036 def test_init_var_inheritance(self):
1037 # Note that this deliberately tests that a dataclass need not
1038 # have a __post_init__ function if it has an InitVar field.
1039 # It could just be used in a derived class, as shown here.
1040 @dataclass
1041 class Base:
1042 x: int
1043 init_base: InitVar[int]
1044
1045 # We can instantiate by passing the InitVar, even though
1046 # it's not used.
1047 b = Base(0, 10)
1048 self.assertEqual(vars(b), {'x': 0})
1049
1050 @dataclass
1051 class C(Base):
1052 y: int
1053 init_derived: InitVar[int]
1054
1055 def __post_init__(self, init_base, init_derived):
1056 self.x = self.x + init_base
1057 self.y = self.y + init_derived
1058
1059 c = C(10, 11, 50, 51)
1060 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1061
1062 def test_default_factory(self):
1063 # Test a factory that returns a new list.
1064 @dataclass
1065 class C:
1066 x: int
1067 y: list = field(default_factory=list)
1068
1069 c0 = C(3)
1070 c1 = C(3)
1071 self.assertEqual(c0.x, 3)
1072 self.assertEqual(c0.y, [])
1073 self.assertEqual(c0, c1)
1074 self.assertIsNot(c0.y, c1.y)
1075 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1076
1077 # Test a factory that returns a shared list.
1078 l = []
1079 @dataclass
1080 class C:
1081 x: int
1082 y: list = field(default_factory=lambda: l)
1083
1084 c0 = C(3)
1085 c1 = C(3)
1086 self.assertEqual(c0.x, 3)
1087 self.assertEqual(c0.y, [])
1088 self.assertEqual(c0, c1)
1089 self.assertIs(c0.y, c1.y)
1090 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1091
1092 # Test various other field flags.
1093 # repr
1094 @dataclass
1095 class C:
1096 x: list = field(default_factory=list, repr=False)
1097 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1098 self.assertEqual(C().x, [])
1099
1100 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001101 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001102 class C:
1103 x: list = field(default_factory=list, hash=False)
1104 self.assertEqual(astuple(C()), ([],))
1105 self.assertEqual(hash(C()), hash(()))
1106
1107 # init (see also test_default_factory_with_no_init)
1108 @dataclass
1109 class C:
1110 x: list = field(default_factory=list, init=False)
1111 self.assertEqual(astuple(C()), ([],))
1112
1113 # compare
1114 @dataclass
1115 class C:
1116 x: list = field(default_factory=list, compare=False)
1117 self.assertEqual(C(), C([1]))
1118
1119 def test_default_factory_with_no_init(self):
1120 # We need a factory with a side effect.
1121 factory = Mock()
1122
1123 @dataclass
1124 class C:
1125 x: list = field(default_factory=factory, init=False)
1126
1127 # Make sure the default factory is called for each new instance.
1128 C().x
1129 self.assertEqual(factory.call_count, 1)
1130 C().x
1131 self.assertEqual(factory.call_count, 2)
1132
1133 def test_default_factory_not_called_if_value_given(self):
1134 # We need a factory that we can test if it's been called.
1135 factory = Mock()
1136
1137 @dataclass
1138 class C:
1139 x: int = field(default_factory=factory)
1140
1141 # Make sure that if a field has a default factory function,
1142 # it's not called if a value is specified.
1143 C().x
1144 self.assertEqual(factory.call_count, 1)
1145 self.assertEqual(C(10).x, 10)
1146 self.assertEqual(factory.call_count, 1)
1147 C().x
1148 self.assertEqual(factory.call_count, 2)
1149
def x_test_classvar_default_factory(self):
    # It's an error for a ClassVar to have a factory function, so
    # @dataclass must reject the class with a TypeError.
    # Still disabled via the 'x_' prefix; the expectation below encodes
    # that rule instead of expecting the class to build successfully.
    with self.assertRaisesRegex(TypeError, 'cannot have a default factory'):
        @dataclass
        class C:
            x: ClassVar[int] = field(default_factory=int)
1157
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001158 def test_is_dataclass(self):
1159 class NotDataClass:
1160 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001161
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001162 self.assertFalse(is_dataclass(0))
1163 self.assertFalse(is_dataclass(int))
1164 self.assertFalse(is_dataclass(NotDataClass))
1165 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001166
1167 @dataclass
1168 class C:
1169 x: int
1170
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001171 @dataclass
1172 class D:
1173 d: C
1174 e: int
1175
1176 c = C(10)
1177 d = D(c, 4)
1178
1179 self.assertTrue(is_dataclass(C))
1180 self.assertTrue(is_dataclass(c))
1181 self.assertFalse(is_dataclass(c.x))
1182 self.assertTrue(is_dataclass(d.d))
1183 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001184
1185 def test_helper_fields_with_class_instance(self):
1186 # Check that we can call fields() on either a class or instance,
1187 # and get back the same thing.
1188 @dataclass
1189 class C:
1190 x: int
1191 y: float
1192
1193 self.assertEqual(fields(C), fields(C(0, 0.0)))
1194
1195 def test_helper_fields_exception(self):
1196 # Check that TypeError is raised if not passed a dataclass or
1197 # instance.
1198 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1199 fields(0)
1200
1201 class C: pass
1202 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1203 fields(C)
1204 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1205 fields(C())
1206
1207 def test_helper_asdict(self):
1208 # Basic tests for asdict(), it should return a new dictionary
1209 @dataclass
1210 class C:
1211 x: int
1212 y: int
1213 c = C(1, 2)
1214
1215 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1216 self.assertEqual(asdict(c), asdict(c))
1217 self.assertIsNot(asdict(c), asdict(c))
1218 c.x = 42
1219 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1220 self.assertIs(type(asdict(c)), dict)
1221
1222 def test_helper_asdict_raises_on_classes(self):
1223 # asdict() should raise on a class object
1224 @dataclass
1225 class C:
1226 x: int
1227 y: int
1228 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1229 asdict(C)
1230 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1231 asdict(int)
1232
1233 def test_helper_asdict_copy_values(self):
1234 @dataclass
1235 class C:
1236 x: int
1237 y: List[int] = field(default_factory=list)
1238 initial = []
1239 c = C(1, initial)
1240 d = asdict(c)
1241 self.assertEqual(d['y'], initial)
1242 self.assertIsNot(d['y'], initial)
1243 c = C(1)
1244 d = asdict(c)
1245 d['y'].append(1)
1246 self.assertEqual(c.y, [])
1247
1248 def test_helper_asdict_nested(self):
1249 @dataclass
1250 class UserId:
1251 token: int
1252 group: int
1253 @dataclass
1254 class User:
1255 name: str
1256 id: UserId
1257 u = User('Joe', UserId(123, 1))
1258 d = asdict(u)
1259 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1260 self.assertIsNot(asdict(u), asdict(u))
1261 u.id.group = 2
1262 self.assertEqual(asdict(u), {'name': 'Joe',
1263 'id': {'token': 123, 'group': 2}})
1264
1265 def test_helper_asdict_builtin_containers(self):
1266 @dataclass
1267 class User:
1268 name: str
1269 id: int
1270 @dataclass
1271 class GroupList:
1272 id: int
1273 users: List[User]
1274 @dataclass
1275 class GroupTuple:
1276 id: int
1277 users: Tuple[User, ...]
1278 @dataclass
1279 class GroupDict:
1280 id: int
1281 users: Dict[str, User]
1282 a = User('Alice', 1)
1283 b = User('Bob', 2)
1284 gl = GroupList(0, [a, b])
1285 gt = GroupTuple(0, (a, b))
1286 gd = GroupDict(0, {'first': a, 'second': b})
1287 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1288 {'name': 'Bob', 'id': 2}]})
1289 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1290 {'name': 'Bob', 'id': 2})})
1291 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1292 'second': {'name': 'Bob', 'id': 2}}})
1293
1294 def test_helper_asdict_builtin_containers(self):
1295 @dataclass
1296 class Child:
1297 d: object
1298
1299 @dataclass
1300 class Parent:
1301 child: Child
1302
1303 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1304 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1305
1306 def test_helper_asdict_factory(self):
1307 @dataclass
1308 class C:
1309 x: int
1310 y: int
1311 c = C(1, 2)
1312 d = asdict(c, dict_factory=OrderedDict)
1313 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1314 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1315 c.x = 42
1316 d = asdict(c, dict_factory=OrderedDict)
1317 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1318 self.assertIs(type(d), OrderedDict)
1319
1320 def test_helper_astuple(self):
1321 # Basic tests for astuple(), it should return a new tuple
1322 @dataclass
1323 class C:
1324 x: int
1325 y: int = 0
1326 c = C(1)
1327
1328 self.assertEqual(astuple(c), (1, 0))
1329 self.assertEqual(astuple(c), astuple(c))
1330 self.assertIsNot(astuple(c), astuple(c))
1331 c.y = 42
1332 self.assertEqual(astuple(c), (1, 42))
1333 self.assertIs(type(astuple(c)), tuple)
1334
1335 def test_helper_astuple_raises_on_classes(self):
1336 # astuple() should raise on a class object
1337 @dataclass
1338 class C:
1339 x: int
1340 y: int
1341 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1342 astuple(C)
1343 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1344 astuple(int)
1345
1346 def test_helper_astuple_copy_values(self):
1347 @dataclass
1348 class C:
1349 x: int
1350 y: List[int] = field(default_factory=list)
1351 initial = []
1352 c = C(1, initial)
1353 t = astuple(c)
1354 self.assertEqual(t[1], initial)
1355 self.assertIsNot(t[1], initial)
1356 c = C(1)
1357 t = astuple(c)
1358 t[1].append(1)
1359 self.assertEqual(c.y, [])
1360
1361 def test_helper_astuple_nested(self):
1362 @dataclass
1363 class UserId:
1364 token: int
1365 group: int
1366 @dataclass
1367 class User:
1368 name: str
1369 id: UserId
1370 u = User('Joe', UserId(123, 1))
1371 t = astuple(u)
1372 self.assertEqual(t, ('Joe', (123, 1)))
1373 self.assertIsNot(astuple(u), astuple(u))
1374 u.id.group = 2
1375 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1376
1377 def test_helper_astuple_builtin_containers(self):
1378 @dataclass
1379 class User:
1380 name: str
1381 id: int
1382 @dataclass
1383 class GroupList:
1384 id: int
1385 users: List[User]
1386 @dataclass
1387 class GroupTuple:
1388 id: int
1389 users: Tuple[User, ...]
1390 @dataclass
1391 class GroupDict:
1392 id: int
1393 users: Dict[str, User]
1394 a = User('Alice', 1)
1395 b = User('Bob', 2)
1396 gl = GroupList(0, [a, b])
1397 gt = GroupTuple(0, (a, b))
1398 gd = GroupDict(0, {'first': a, 'second': b})
1399 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1400 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1401 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1402
1403 def test_helper_astuple_builtin_containers(self):
1404 @dataclass
1405 class Child:
1406 d: object
1407
1408 @dataclass
1409 class Parent:
1410 child: Child
1411
1412 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1413 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1414
1415 def test_helper_astuple_factory(self):
1416 @dataclass
1417 class C:
1418 x: int
1419 y: int
1420 NT = namedtuple('NT', 'x y')
1421 def nt(lst):
1422 return NT(*lst)
1423 c = C(1, 2)
1424 t = astuple(c, tuple_factory=nt)
1425 self.assertEqual(t, NT(1, 2))
1426 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1427 c.x = 42
1428 t = astuple(c, tuple_factory=nt)
1429 self.assertEqual(t, NT(42, 2))
1430 self.assertIs(type(t), NT)
1431
1432 def test_dynamic_class_creation(self):
1433 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1434 }
1435
1436 # Create the class.
1437 cls = type('C', (), cls_dict)
1438
1439 # Make it a dataclass.
1440 cls1 = dataclass(cls)
1441
1442 self.assertEqual(cls1, cls)
1443 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1444
1445 def test_dynamic_class_creation_using_field(self):
1446 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1447 'y': field(default=5),
1448 }
1449
1450 # Create the class.
1451 cls = type('C', (), cls_dict)
1452
1453 # Make it a dataclass.
1454 cls1 = dataclass(cls)
1455
1456 self.assertEqual(cls1, cls)
1457 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1458
1459 def test_init_in_order(self):
1460 @dataclass
1461 class C:
1462 a: int
1463 b: int = field()
1464 c: list = field(default_factory=list, init=False)
1465 d: list = field(default_factory=list)
1466 e: int = field(default=4, init=False)
1467 f: int = 4
1468
1469 calls = []
1470 def setattr(self, name, value):
1471 calls.append((name, value))
1472
1473 C.__setattr__ = setattr
1474 c = C(0, 1)
1475 self.assertEqual(('a', 0), calls[0])
1476 self.assertEqual(('b', 1), calls[1])
1477 self.assertEqual(('c', []), calls[2])
1478 self.assertEqual(('d', []), calls[3])
1479 self.assertNotIn(('e', 4), calls)
1480 self.assertEqual(('f', 4), calls[4])
1481
1482 def test_items_in_dicts(self):
1483 @dataclass
1484 class C:
1485 a: int
1486 b: list = field(default_factory=list, init=False)
1487 c: list = field(default_factory=list)
1488 d: int = field(default=4, init=False)
1489 e: int = 0
1490
1491 c = C(0)
1492 # Class dict
1493 self.assertNotIn('a', C.__dict__)
1494 self.assertNotIn('b', C.__dict__)
1495 self.assertNotIn('c', C.__dict__)
1496 self.assertIn('d', C.__dict__)
1497 self.assertEqual(C.d, 4)
1498 self.assertIn('e', C.__dict__)
1499 self.assertEqual(C.e, 0)
1500 # Instance dict
1501 self.assertIn('a', c.__dict__)
1502 self.assertEqual(c.a, 0)
1503 self.assertIn('b', c.__dict__)
1504 self.assertEqual(c.b, [])
1505 self.assertIn('c', c.__dict__)
1506 self.assertEqual(c.c, [])
1507 self.assertNotIn('d', c.__dict__)
1508 self.assertIn('e', c.__dict__)
1509 self.assertEqual(c.e, 0)
1510
1511 def test_alternate_classmethod_constructor(self):
1512 # Since __post_init__ can't take params, use a classmethod
1513 # alternate constructor. This is mostly an example to show how
1514 # to use this technique.
1515 @dataclass
1516 class C:
1517 x: int
1518 @classmethod
1519 def from_file(cls, filename):
1520 # In a real example, create a new instance
1521 # and populate 'x' from contents of a file.
1522 value_in_file = 20
1523 return cls(value_in_file)
1524
1525 self.assertEqual(C.from_file('filename').x, 20)
1526
1527 def test_field_metadata_default(self):
1528 # Make sure the default metadata is read-only and of
1529 # zero length.
1530 @dataclass
1531 class C:
1532 i: int
1533
1534 self.assertFalse(fields(C)[0].metadata)
1535 self.assertEqual(len(fields(C)[0].metadata), 0)
1536 with self.assertRaisesRegex(TypeError,
1537 'does not support item assignment'):
1538 fields(C)[0].metadata['test'] = 3
1539
1540 def test_field_metadata_mapping(self):
1541 # Make sure only a mapping can be passed as metadata
1542 # zero length.
1543 with self.assertRaises(TypeError):
1544 @dataclass
1545 class C:
1546 i: int = field(metadata=0)
1547
1548 # Make sure an empty dict works
1549 @dataclass
1550 class C:
1551 i: int = field(metadata={})
1552 self.assertFalse(fields(C)[0].metadata)
1553 self.assertEqual(len(fields(C)[0].metadata), 0)
1554 with self.assertRaisesRegex(TypeError,
1555 'does not support item assignment'):
1556 fields(C)[0].metadata['test'] = 3
1557
1558 # Make sure a non-empty dict works.
1559 @dataclass
1560 class C:
1561 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1562 self.assertEqual(len(fields(C)[0].metadata), 3)
1563 self.assertEqual(fields(C)[0].metadata['test'], 10)
1564 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1565 self.assertEqual(fields(C)[0].metadata[3], 'three')
1566 with self.assertRaises(KeyError):
1567 # Non-existent key.
1568 fields(C)[0].metadata['baz']
1569 with self.assertRaisesRegex(TypeError,
1570 'does not support item assignment'):
1571 fields(C)[0].metadata['test'] = 3
1572
1573 def test_field_metadata_custom_mapping(self):
1574 # Try a custom mapping.
1575 class SimpleNameSpace:
1576 def __init__(self, **kw):
1577 self.__dict__.update(kw)
1578
1579 def __getitem__(self, item):
1580 if item == 'xyzzy':
1581 return 'plugh'
1582 return getattr(self, item)
1583
1584 def __len__(self):
1585 return self.__dict__.__len__()
1586
1587 @dataclass
1588 class C:
1589 i: int = field(metadata=SimpleNameSpace(a=10))
1590
1591 self.assertEqual(len(fields(C)[0].metadata), 1)
1592 self.assertEqual(fields(C)[0].metadata['a'], 10)
1593 with self.assertRaises(AttributeError):
1594 fields(C)[0].metadata['b']
1595 # Make sure we're still talking to our custom mapping.
1596 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1597
1598 def test_generic_dataclasses(self):
1599 T = TypeVar('T')
1600
1601 @dataclass
1602 class LabeledBox(Generic[T]):
1603 content: T
1604 label: str = '<unknown>'
1605
1606 box = LabeledBox(42)
1607 self.assertEqual(box.content, 42)
1608 self.assertEqual(box.label, '<unknown>')
1609
1610 # subscripting the resulting class should work, etc.
1611 Alias = List[LabeledBox[int]]
1612
1613 def test_generic_extending(self):
1614 S = TypeVar('S')
1615 T = TypeVar('T')
1616
1617 @dataclass
1618 class Base(Generic[T, S]):
1619 x: T
1620 y: S
1621
1622 @dataclass
1623 class DataDerived(Base[int, T]):
1624 new_field: str
1625 Alias = DataDerived[str]
1626 c = Alias(0, 'test1', 'test2')
1627 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1628
1629 class NonDataDerived(Base[int, T]):
1630 def new_method(self):
1631 return self.y
1632 Alias = NonDataDerived[float]
1633 c = Alias(10, 1.0)
1634 self.assertEqual(c.new_method(), 1.0)
1635
1636 def test_helper_replace(self):
1637 @dataclass(frozen=True)
1638 class C:
1639 x: int
1640 y: int
1641
1642 c = C(1, 2)
1643 c1 = replace(c, x=3)
1644 self.assertEqual(c1.x, 3)
1645 self.assertEqual(c1.y, 2)
1646
1647 def test_helper_replace_frozen(self):
1648 @dataclass(frozen=True)
1649 class C:
1650 x: int
1651 y: int
1652 z: int = field(init=False, default=10)
1653 t: int = field(init=False, default=100)
1654
1655 c = C(1, 2)
1656 c1 = replace(c, x=3)
1657 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1658 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1659
1660
1661 with self.assertRaisesRegex(ValueError, 'init=False'):
1662 replace(c, x=3, z=20, t=50)
1663 with self.assertRaisesRegex(ValueError, 'init=False'):
1664 replace(c, z=20)
1665 replace(c, x=3, z=20, t=50)
1666
1667 # Make sure the result is still frozen.
1668 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1669 c1.x = 3
1670
1671 # Make sure we can't replace an attribute that doesn't exist,
1672 # if we're also replacing one that does exist. Test this
1673 # here, because setting attributes on frozen instances is
1674 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001675 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001676 "keyword argument 'a'"):
1677 c1 = replace(c, x=20, a=5)
1678
1679 def test_helper_replace_invalid_field_name(self):
1680 @dataclass(frozen=True)
1681 class C:
1682 x: int
1683 y: int
1684
1685 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001686 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001687 "keyword argument 'z'"):
1688 c1 = replace(c, z=3)
1689
1690 def test_helper_replace_invalid_object(self):
1691 @dataclass(frozen=True)
1692 class C:
1693 x: int
1694 y: int
1695
1696 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1697 replace(C, x=3)
1698
1699 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1700 replace(0, x=3)
1701
1702 def test_helper_replace_no_init(self):
1703 @dataclass
1704 class C:
1705 x: int
1706 y: int = field(init=False, default=10)
1707
1708 c = C(1)
1709 c.y = 20
1710
1711 # Make sure y gets the default value.
1712 c1 = replace(c, x=5)
1713 self.assertEqual((c1.x, c1.y), (5, 10))
1714
1715 # Trying to replace y is an error.
1716 with self.assertRaisesRegex(ValueError, 'init=False'):
1717 replace(c, x=2, y=30)
1718 with self.assertRaisesRegex(ValueError, 'init=False'):
1719 replace(c, y=30)
1720
1721 def test_dataclassses_pickleable(self):
1722 global P, Q, R
1723 @dataclass
1724 class P:
1725 x: int
1726 y: int = 0
1727 @dataclass
1728 class Q:
1729 x: int
1730 y: int = field(default=0, init=False)
1731 @dataclass
1732 class R:
1733 x: int
1734 y: List[int] = field(default_factory=list)
1735 q = Q(1)
1736 q.y = 2
1737 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1738 for sample in samples:
1739 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1740 with self.subTest(sample=sample, proto=proto):
1741 new_sample = pickle.loads(pickle.dumps(sample, proto))
1742 self.assertEqual(sample.x, new_sample.x)
1743 self.assertEqual(sample.y, new_sample.y)
1744 self.assertIsNot(sample, new_sample)
1745 new_sample.x = 42
1746 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1747 self.assertEqual(new_sample.x, another_new_sample.x)
1748 self.assertEqual(sample.y, another_new_sample.y)
1749
1750 def test_helper_make_dataclass(self):
1751 C = make_dataclass('C',
1752 [('x', int),
1753 ('y', int, field(default=5))],
1754 namespace={'add_one': lambda self: self.x + 1})
1755 c = C(10)
1756 self.assertEqual((c.x, c.y), (10, 5))
1757 self.assertEqual(c.add_one(), 11)
1758
1759
1760 def test_helper_make_dataclass_no_mutate_namespace(self):
1761 # Make sure a provided namespace isn't mutated.
1762 ns = {}
1763 C = make_dataclass('C',
1764 [('x', int),
1765 ('y', int, field(default=5))],
1766 namespace=ns)
1767 self.assertEqual(ns, {})
1768
1769 def test_helper_make_dataclass_base(self):
1770 class Base1:
1771 pass
1772 class Base2:
1773 pass
1774 C = make_dataclass('C',
1775 [('x', int)],
1776 bases=(Base1, Base2))
1777 c = C(2)
1778 self.assertIsInstance(c, C)
1779 self.assertIsInstance(c, Base1)
1780 self.assertIsInstance(c, Base2)
1781
1782 def test_helper_make_dataclass_base_dataclass(self):
1783 @dataclass
1784 class Base1:
1785 x: int
1786 class Base2:
1787 pass
1788 C = make_dataclass('C',
1789 [('y', int)],
1790 bases=(Base1, Base2))
1791 with self.assertRaisesRegex(TypeError, 'required positional'):
1792 c = C(2)
1793 c = C(1, 2)
1794 self.assertIsInstance(c, C)
1795 self.assertIsInstance(c, Base1)
1796 self.assertIsInstance(c, Base2)
1797
1798 self.assertEqual((c.x, c.y), (1, 2))
1799
1800 def test_helper_make_dataclass_init_var(self):
1801 def post_init(self, y):
1802 self.x *= y
1803
1804 C = make_dataclass('C',
1805 [('x', int),
1806 ('y', InitVar[int]),
1807 ],
1808 namespace={'__post_init__': post_init},
1809 )
1810 c = C(2, 3)
1811 self.assertEqual(vars(c), {'x': 6})
1812 self.assertEqual(len(fields(c)), 1)
1813
1814 def test_helper_make_dataclass_class_var(self):
1815 C = make_dataclass('C',
1816 [('x', int),
1817 ('y', ClassVar[int], 10),
1818 ('z', ClassVar[int], field(default=20)),
1819 ])
1820 c = C(1)
1821 self.assertEqual(vars(c), {'x': 1})
1822 self.assertEqual(len(fields(c)), 1)
1823 self.assertEqual(C.y, 10)
1824 self.assertEqual(C.z, 20)
1825
Eric V. Smithd80b4432018-01-06 17:09:58 -05001826 def test_helper_make_dataclass_other_params(self):
1827 C = make_dataclass('C',
1828 [('x', int),
1829 ('y', ClassVar[int], 10),
1830 ('z', ClassVar[int], field(default=20)),
1831 ],
1832 init=False)
1833 # Make sure we have a repr, but no init.
1834 self.assertNotIn('__init__', vars(C))
1835 self.assertIn('__repr__', vars(C))
1836
1837 # Make sure random other params don't work.
1838 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1839 C = make_dataclass('C',
1840 [],
1841 xxinit=False)
1842
Eric V. Smithed7d4292018-01-06 16:14:03 -05001843 def test_helper_make_dataclass_no_types(self):
1844 C = make_dataclass('Point', ['x', 'y', 'z'])
1845 c = C(1, 2, 3)
1846 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1847 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1848 'y': 'typing.Any',
1849 'z': 'typing.Any'})
1850
1851 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1852 c = C(1, 2, 3)
1853 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1854 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1855 'y': int,
1856 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001857
Eric V. Smithea8fc522018-01-27 19:07:40 -05001858
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001859class TestDocString(unittest.TestCase):
1860 def assertDocStrEqual(self, a, b):
1861 # Because 3.6 and 3.7 differ in how inspect.signature work
1862 # (see bpo #32108), for the time being just compare them with
1863 # whitespace stripped.
1864 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1865
1866 def test_existing_docstring_not_overridden(self):
1867 @dataclass
1868 class C:
1869 """Lorem ipsum"""
1870 x: int
1871
1872 self.assertEqual(C.__doc__, "Lorem ipsum")
1873
1874 def test_docstring_no_fields(self):
1875 @dataclass
1876 class C:
1877 pass
1878
1879 self.assertDocStrEqual(C.__doc__, "C()")
1880
1881 def test_docstring_one_field(self):
1882 @dataclass
1883 class C:
1884 x: int
1885
1886 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1887
1888 def test_docstring_two_fields(self):
1889 @dataclass
1890 class C:
1891 x: int
1892 y: int
1893
1894 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1895
1896 def test_docstring_three_fields(self):
1897 @dataclass
1898 class C:
1899 x: int
1900 y: int
1901 z: str
1902
1903 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1904
1905 def test_docstring_one_field_with_default(self):
1906 @dataclass
1907 class C:
1908 x: int = 3
1909
1910 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1911
1912 def test_docstring_one_field_with_default_none(self):
1913 @dataclass
1914 class C:
1915 x: Union[int, type(None)] = None
1916
1917 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1918
1919 def test_docstring_list_field(self):
1920 @dataclass
1921 class C:
1922 x: List[int]
1923
1924 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1925
1926 def test_docstring_list_field_with_default_factory(self):
1927 @dataclass
1928 class C:
1929 x: List[int] = field(default_factory=list)
1930
1931 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1932
1933 def test_docstring_deque_field(self):
1934 @dataclass
1935 class C:
1936 x: deque
1937
1938 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1939
1940 def test_docstring_deque_field_with_default_factory(self):
1941 @dataclass
1942 class C:
1943 x: deque = field(default_factory=deque)
1944
1945 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1946
1947
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        # A base class with its own __init__ must not stop @dataclass from
        # adding one. The generated __init__ does not call the base
        # __init__, so the base's attribute 'z' is never set.
        class B:
            def __init__(self):
                self.z = 100

        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False nothing is generated, so the inherited base
        # __init__ runs instead.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # init=False: no __init__ is generated. The decorator must
        # actually be applied with '@' -- a bare dataclass(init=False)
        # call is discarded and leaves C a plain class.

        # Without a user __init__, C().i reads the class-level default.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertEqual(C().i, 0)

        # A user __init__ is used as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2012
2013
2014class TestRepr(unittest.TestCase):
2015 def test_repr(self):
2016 @dataclass
2017 class B:
2018 x: int
2019
2020 @dataclass
2021 class C(B):
2022 y: int = 10
2023
2024 o = C(4)
2025 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
2026
2027 @dataclass
2028 class D(C):
2029 x: int = 20
2030 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
2031
2032 @dataclass
2033 class C:
2034 @dataclass
2035 class D:
2036 i: int
2037 @dataclass
2038 class E:
2039 pass
2040 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
2041 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
2042
2043 def test_no_repr(self):
2044 # Test a class with no __repr__ and repr=False.
2045 @dataclass(repr=False)
2046 class C:
2047 x: int
2048 self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
2049 repr(C(3)))
2050
2051 # Test a class with a __repr__ and repr=False.
2052 @dataclass(repr=False)
2053 class C:
2054 x: int
2055 def __repr__(self):
2056 return 'C-class'
2057 self.assertEqual(repr(C(3)), 'C-class')
2058
2059 def test_overwriting_repr(self):
2060 # If the class has __repr__, use it no matter the value of
2061 # repr=.
2062
2063 @dataclass
2064 class C:
2065 x: int
2066 def __repr__(self):
2067 return 'x'
2068 self.assertEqual(repr(C(0)), 'x')
2069
2070 @dataclass(repr=True)
2071 class C:
2072 x: int
2073 def __repr__(self):
2074 return 'x'
2075 self.assertEqual(repr(C(0)), 'x')
2076
2077 @dataclass(repr=False)
2078 class C:
2079 x: int
2080 def __repr__(self):
2081 return 'x'
2082 self.assertEqual(repr(C(0)), 'x')
2083
2084
class TestFrozen(unittest.TestCase):
    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__, so a
        # class that already defines either one is rejected.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int

                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int

                def __delattr__(self):
                    pass

        # With frozen=False the user __setattr__ stays, and the generated
        # __init__ goes through it.
        @dataclass(frozen=False)
        class C:
            x: int

            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        doubled = C(10)
        self.assertEqual(doubled.x, 20)
2110
2111
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: comparison falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int

        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int

            def __eq__(self, other):
                return other == 10

        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ is kept regardless of the eq= flag.

        # Default eq (True).
        @dataclass
        class C:
            x: int

            def __eq__(self, other):
                return other == 3

        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        # Explicit eq=True.
        @dataclass(eq=True)
        class C:
            x: int

            def __eq__(self, other):
                return other == 4

        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        # Explicit eq=False.
        @dataclass(eq=False)
        class C:
            x: int

            def __eq__(self, other):
                return other == 5

        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)
2157
2158
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering can derive the other comparisons from
        # a user __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int

            def __lt__(self, other):
                # Deliberately "backward", so the results below can only
                # come from this method.
                return self.x >= other

        zero = C(0)
        self.assertLess(zero, -1)
        self.assertLessEqual(zero, -1)
        self.assertGreater(zero, 1)
        self.assertGreaterEqual(zero, 1)

    def test_no_order(self):
        # order=False adds no ordering methods.
        @dataclass(order=False)
        class C:
            x: int

        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user __lt__ is kept, and the others are still not added.
        @dataclass(order=False)
        class C:
            x: int

            def __lt__(self, other):
                return False

        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any user-defined comparison and
        # suggests functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int

                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int

                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int

                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int

                def __ge__(self):
                    pass
2234
class TestHash(unittest.TestCase):
    """Tests for when @dataclass adds, keeps, or clears __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map a bool to an equivalent non-bool truth value: (3,) for
            # true and 0 for false. None passes through unchanged.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # with_hash is part of the subTest key, so the "class defines
            # __hash__" and "class doesn't" variants are reported separately.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen, with_hash=with_hash):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    if with_hash:
                        # The decorator must leave the user's __hash__
                        # (which returns 0) in place.
                        self.assertEqual(hash(C()), 0)
                    else:
                        # __hash__ is not present in our class.
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True. This uses a
                    # unittest assertion rather than a bare assert, so the
                    # table sanity check also runs under python -O.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too. This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None. This is normal Python behavior, not
        # related to dataclasses. Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument. This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # I'm not sure what test to use here. object's
                    # hash isn't based on id(), so calling hash()
                    # won't tell us much. So, just check the function
                    # used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    self.fail(f'unknown value for expected={expected!r}')
2453
Eric V. Smithea8fc522018-01-27 19:07:40 -05002454
class TestFrozen(unittest.TestCase):
    """frozen=True blocks field assignment, and frozen-ness must agree
    between a dataclass and its dataclass bases."""

    @staticmethod
    def _subclass_or_self(base, add_intermediate):
        # Either derive a plain (non-dataclass) class from base, or return
        # base unchanged, so each inheritance test covers both shapes.
        if not add_intermediate:
            return base

        class Intermediate(base):
            pass
        return Intermediate

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # The rejected assignment must not have changed the value.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        # Both the inherited field and the new field are read-only.
        for attr_name, new_value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, attr_name, new_value)
        self.assertEqual(obj.i, 0)
        self.assertEqual(obj.j, 10)

    # Test both ways: with an intermediate normal (non-dataclass)
    # class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._subclass_or_self(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = self._subclass_or_self(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                parent = self._subclass_or_self(C, intermediate_class)

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        # A plain (non-dataclass) subclass of a frozen dataclass...
        class S(D):
            pass

        inst = S(3)
        self.assertEqual(inst.x, 3)
        self.assertEqual(inst.y, 10)
        # ...may still gain brand-new attributes...
        inst.cached = True

        # ...but can't change the frozen attributes.
        for attr_name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(inst, attr_name, 5)
        self.assertEqual(inst.x, 3)
        self.assertEqual(inst.y, 10)
        self.assertEqual(inst.cached, True)
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002565
2566
class TestSlots(unittest.TestCase):
    """Dataclasses whose fields live in an explicitly declared __slots__."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed
        # to also have a default value (of type types.MemberDescriptorType).
        # Raw string: "\(" and "\)" are regex escapes, not valid string
        # escapes (a DeprecationWarning, and a SyntaxWarning from 3.12).
        with self.assertRaisesRegex(TypeError,
                r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # Derived declares no __slots__, so its instances have a __dict__
        # and a new attribute can be added (and read back).
        d.z = 10
        self.assertEqual(d.z, 10)
2607
2608
# Run every TestCase in this module when the file is executed as a script.
if __name__ == "__main__":
    unittest.main()