blob: 582cb3459f5ddf319bddd8c49578ac54fdad75e3 [file] [log] [blame]
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001from dataclasses import (
2 dataclass, field, FrozenInstanceError, fields, asdict, astuple,
Eric V. Smithe7ba0132018-01-06 12:41:53 -05003 make_dataclass, replace, InitVar, Field, MISSING, is_dataclass,
Eric V. Smithf0db54a2017-12-04 16:58:55 -05004)
5
6import pickle
7import inspect
8import unittest
9from unittest.mock import Mock
10from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
11from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050012from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013
# Just any custom exception we can catch.
class CustomError(Exception):
    """Distinct exception raised from test hooks so tests can catch exactly it."""
16
17class TestCase(unittest.TestCase):
18 def test_no_fields(self):
19 @dataclass
20 class C:
21 pass
22
23 o = C()
24 self.assertEqual(len(fields(C)), 0)
25
26 def test_one_field_no_default(self):
27 @dataclass
28 class C:
29 x: int
30
31 o = C(42)
32 self.assertEqual(o.x, 42)
33
34 def test_named_init_params(self):
35 @dataclass
36 class C:
37 x: int
38
39 o = C(x=32)
40 self.assertEqual(o.x, 32)
41
42 def test_two_fields_one_default(self):
43 @dataclass
44 class C:
45 x: int
46 y: int = 0
47
48 o = C(3)
49 self.assertEqual((o.x, o.y), (3, 0))
50
51 # Non-defaults following defaults.
52 with self.assertRaisesRegex(TypeError,
53 "non-default argument 'y' follows "
54 "default argument"):
55 @dataclass
56 class C:
57 x: int = 0
58 y: int
59
60 # A derived class adds a non-default field after a default one.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class B:
66 x: int = 0
67
68 @dataclass
69 class C(B):
70 y: int
71
72 # Override a base class field and add a default to
73 # a field which didn't use to have a default.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int
80 y: int
81
82 @dataclass
83 class C(B):
84 x: int = 0
85
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080086 def test_overwrite_hash(self):
87 # Test that declaring this class isn't an error. It should
88 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050089 @dataclass(frozen=True)
90 class C:
91 x: int
92 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080093 return 301
94 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -050095
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080096 # Test that declaring this class isn't an error. It should
97 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800101 def __eq__(self, other):
102 return False
103 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # But this one should generate an exception, because with
106 # unsafe_hash=True, it's an error to have a __hash__ defined.
107 with self.assertRaisesRegex(TypeError,
108 'Cannot overwrite attribute __hash__'):
109 @dataclass(unsafe_hash=True)
110 class C:
111 def __hash__(self):
112 pass
113
114 # Creating this class should not generate an exception,
115 # because even though __hash__ exists before @dataclass is
116 # called, (due to __eq__ being defined), since it's None
117 # that's okay.
118 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500119 class C:
120 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800121 def __eq__(self):
122 pass
123 # The generated hash function works as we'd expect.
124 self.assertEqual(hash(C(10)), hash((10,)))
125
126 # Creating this class should generate an exception, because
127 # __hash__ exists and is not None, which it would be if it had
128 # been auto-generated do due __eq__ being defined.
129 with self.assertRaisesRegex(TypeError,
130 'Cannot overwrite attribute __hash__'):
131 @dataclass(unsafe_hash=True)
132 class C:
133 x: int
134 def __eq__(self):
135 pass
136 def __hash__(self):
137 pass
138
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500139
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500140 def test_overwrite_fields_in_derived_class(self):
141 # Note that x from C1 replaces x in Base, but the order remains
142 # the same as defined in Base.
143 @dataclass
144 class Base:
145 x: Any = 15.0
146 y: int = 0
147
148 @dataclass
149 class C1(Base):
150 z: int = 10
151 x: int = 15
152
153 o = Base()
154 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
155
156 o = C1()
157 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
158
159 o = C1(x=5)
160 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
161
162 def test_field_named_self(self):
163 @dataclass
164 class C:
165 self: str
166 c=C('foo')
167 self.assertEqual(c.self, 'foo')
168
169 # Make sure the first parameter is not named 'self'.
170 sig = inspect.signature(C.__init__)
171 first = next(iter(sig.parameters))
172 self.assertNotEqual('self', first)
173
174 # But we do use 'self' if no field named self.
175 @dataclass
176 class C:
177 selfx: str
178
179 # Make sure the first parameter is named 'self'.
180 sig = inspect.signature(C.__init__)
181 first = next(iter(sig.parameters))
182 self.assertEqual('self', first)
183
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500184 def test_0_field_compare(self):
185 # Ensure that order=False is the default.
186 @dataclass
187 class C0:
188 pass
189
190 @dataclass(order=False)
191 class C1:
192 pass
193
194 for cls in [C0, C1]:
195 with self.subTest(cls=cls):
196 self.assertEqual(cls(), cls())
197 for idx, fn in enumerate([lambda a, b: a < b,
198 lambda a, b: a <= b,
199 lambda a, b: a > b,
200 lambda a, b: a >= b]):
201 with self.subTest(idx=idx):
202 with self.assertRaisesRegex(TypeError,
203 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
204 fn(cls(), cls())
205
206 @dataclass(order=True)
207 class C:
208 pass
209 self.assertLessEqual(C(), C())
210 self.assertGreaterEqual(C(), C())
211
212 def test_1_field_compare(self):
213 # Ensure that order=False is the default.
214 @dataclass
215 class C0:
216 x: int
217
218 @dataclass(order=False)
219 class C1:
220 x: int
221
222 for cls in [C0, C1]:
223 with self.subTest(cls=cls):
224 self.assertEqual(cls(1), cls(1))
225 self.assertNotEqual(cls(0), cls(1))
226 for idx, fn in enumerate([lambda a, b: a < b,
227 lambda a, b: a <= b,
228 lambda a, b: a > b,
229 lambda a, b: a >= b]):
230 with self.subTest(idx=idx):
231 with self.assertRaisesRegex(TypeError,
232 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
233 fn(cls(0), cls(0))
234
235 @dataclass(order=True)
236 class C:
237 x: int
238 self.assertLess(C(0), C(1))
239 self.assertLessEqual(C(0), C(1))
240 self.assertLessEqual(C(1), C(1))
241 self.assertGreater(C(1), C(0))
242 self.assertGreaterEqual(C(1), C(0))
243 self.assertGreaterEqual(C(1), C(1))
244
245 def test_simple_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 x: int
250 y: int
251
252 @dataclass(order=False)
253 class C1:
254 x: int
255 y: int
256
257 for cls in [C0, C1]:
258 with self.subTest(cls=cls):
259 self.assertEqual(cls(0, 0), cls(0, 0))
260 self.assertEqual(cls(1, 2), cls(1, 2))
261 self.assertNotEqual(cls(1, 0), cls(0, 0))
262 self.assertNotEqual(cls(1, 0), cls(1, 1))
263 for idx, fn in enumerate([lambda a, b: a < b,
264 lambda a, b: a <= b,
265 lambda a, b: a > b,
266 lambda a, b: a >= b]):
267 with self.subTest(idx=idx):
268 with self.assertRaisesRegex(TypeError,
269 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
270 fn(cls(0, 0), cls(0, 0))
271
272 @dataclass(order=True)
273 class C:
274 x: int
275 y: int
276
277 for idx, fn in enumerate([lambda a, b: a == b,
278 lambda a, b: a <= b,
279 lambda a, b: a >= b]):
280 with self.subTest(idx=idx):
281 self.assertTrue(fn(C(0, 0), C(0, 0)))
282
283 for idx, fn in enumerate([lambda a, b: a < b,
284 lambda a, b: a <= b,
285 lambda a, b: a != b]):
286 with self.subTest(idx=idx):
287 self.assertTrue(fn(C(0, 0), C(0, 1)))
288 self.assertTrue(fn(C(0, 1), C(1, 0)))
289 self.assertTrue(fn(C(1, 0), C(1, 1)))
290
291 for idx, fn in enumerate([lambda a, b: a > b,
292 lambda a, b: a >= b,
293 lambda a, b: a != b]):
294 with self.subTest(idx=idx):
295 self.assertTrue(fn(C(0, 1), C(0, 0)))
296 self.assertTrue(fn(C(1, 0), C(0, 1)))
297 self.assertTrue(fn(C(1, 1), C(1, 0)))
298
299 def test_compare_subclasses(self):
300 # Comparisons fail for subclasses, even if no fields
301 # are added.
302 @dataclass
303 class B:
304 i: int
305
306 @dataclass
307 class C(B):
308 pass
309
310 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
311 (lambda a, b: a != b, True)]):
312 with self.subTest(idx=idx):
313 self.assertEqual(fn(B(0), C(0)), expected)
314
315 for idx, fn in enumerate([lambda a, b: a < b,
316 lambda a, b: a <= b,
317 lambda a, b: a > b,
318 lambda a, b: a >= b]):
319 with self.subTest(idx=idx):
320 with self.assertRaisesRegex(TypeError,
321 "not supported between instances of 'B' and 'C'"):
322 fn(B(0), C(0))
323
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500324 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500325 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500326 for (eq, order, result ) in [
327 (False, False, 'neither'),
328 (False, True, 'exception'),
329 (True, False, 'eq_only'),
330 (True, True, 'both'),
331 ]:
332 with self.subTest(eq=eq, order=order):
333 if result == 'exception':
334 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
335 @dataclass(eq=eq, order=order)
336 class C:
337 pass
338 else:
339 @dataclass(eq=eq, order=order)
340 class C:
341 pass
342
343 if result == 'neither':
344 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500345 self.assertNotIn('__lt__', C.__dict__)
346 self.assertNotIn('__le__', C.__dict__)
347 self.assertNotIn('__gt__', C.__dict__)
348 self.assertNotIn('__ge__', C.__dict__)
349 elif result == 'both':
350 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500351 self.assertIn('__lt__', C.__dict__)
352 self.assertIn('__le__', C.__dict__)
353 self.assertIn('__gt__', C.__dict__)
354 self.assertIn('__ge__', C.__dict__)
355 elif result == 'eq_only':
356 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500357 self.assertNotIn('__lt__', C.__dict__)
358 self.assertNotIn('__le__', C.__dict__)
359 self.assertNotIn('__gt__', C.__dict__)
360 self.assertNotIn('__ge__', C.__dict__)
361 else:
362 assert False, f'unknown result {result!r}'
363
364 def test_field_no_default(self):
365 @dataclass
366 class C:
367 x: int = field()
368
369 self.assertEqual(C(5).x, 5)
370
371 with self.assertRaisesRegex(TypeError,
372 r"__init__\(\) missing 1 required "
373 "positional argument: 'x'"):
374 C()
375
376 def test_field_default(self):
377 default = object()
378 @dataclass
379 class C:
380 x: object = field(default=default)
381
382 self.assertIs(C.x, default)
383 c = C(10)
384 self.assertEqual(c.x, 10)
385
386 # If we delete the instance attribute, we should then see the
387 # class attribute.
388 del c.x
389 self.assertIs(c.x, default)
390
391 self.assertIs(C().x, default)
392
393 def test_not_in_repr(self):
394 @dataclass
395 class C:
396 x: int = field(repr=False)
397 with self.assertRaises(TypeError):
398 C()
399 c = C(10)
400 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
401
402 @dataclass
403 class C:
404 x: int = field(repr=False)
405 y: int
406 c = C(10, 20)
407 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
408
409 def test_not_in_compare(self):
410 @dataclass
411 class C:
412 x: int = 0
413 y: int = field(compare=False, default=4)
414
415 self.assertEqual(C(), C(0, 20))
416 self.assertEqual(C(1, 10), C(1, 20))
417 self.assertNotEqual(C(3), C(4, 10))
418 self.assertNotEqual(C(3, 10), C(4, 10))
419
420 def test_hash_field_rules(self):
421 # Test all 6 cases of:
422 # hash=True/False/None
423 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800424 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500425 (True, False, 'field' ),
426 (True, True, 'field' ),
427 (False, False, 'absent'),
428 (False, True, 'absent'),
429 (None, False, 'absent'),
430 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800431 ]:
432 with self.subTest(hash=hash_, compare=compare):
433 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500434 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800435 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500436
437 if result == 'field':
438 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800439 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500440 elif result == 'absent':
441 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800442 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500443 else:
444 assert False, f'unknown result {result!r}'
445
446 def test_init_false_no_default(self):
447 # If init=False and no default value, then the field won't be
448 # present in the instance.
449 @dataclass
450 class C:
451 x: int = field(init=False)
452
453 self.assertNotIn('x', C().__dict__)
454
455 @dataclass
456 class C:
457 x: int
458 y: int = 0
459 z: int = field(init=False)
460 t: int = 10
461
462 self.assertNotIn('z', C(0).__dict__)
463 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
464
465 def test_class_marker(self):
466 @dataclass
467 class C:
468 x: int
469 y: str = field(init=False, default=None)
470 z: str = field(repr=False)
471
472 the_fields = fields(C)
473 # the_fields is a tuple of 3 items, each value
474 # is in __annotations__.
475 self.assertIsInstance(the_fields, tuple)
476 for f in the_fields:
477 self.assertIs(type(f), Field)
478 self.assertIn(f.name, C.__annotations__)
479
480 self.assertEqual(len(the_fields), 3)
481
482 self.assertEqual(the_fields[0].name, 'x')
483 self.assertEqual(the_fields[0].type, int)
484 self.assertFalse(hasattr(C, 'x'))
485 self.assertTrue (the_fields[0].init)
486 self.assertTrue (the_fields[0].repr)
487 self.assertEqual(the_fields[1].name, 'y')
488 self.assertEqual(the_fields[1].type, str)
489 self.assertIsNone(getattr(C, 'y'))
490 self.assertFalse(the_fields[1].init)
491 self.assertTrue (the_fields[1].repr)
492 self.assertEqual(the_fields[2].name, 'z')
493 self.assertEqual(the_fields[2].type, str)
494 self.assertFalse(hasattr(C, 'z'))
495 self.assertTrue (the_fields[2].init)
496 self.assertFalse(the_fields[2].repr)
497
498 def test_field_order(self):
499 @dataclass
500 class B:
501 a: str = 'B:a'
502 b: str = 'B:b'
503 c: str = 'B:c'
504
505 @dataclass
506 class C(B):
507 b: str = 'C:b'
508
509 self.assertEqual([(f.name, f.default) for f in fields(C)],
510 [('a', 'B:a'),
511 ('b', 'C:b'),
512 ('c', 'B:c')])
513
514 @dataclass
515 class D(B):
516 c: str = 'D:c'
517
518 self.assertEqual([(f.name, f.default) for f in fields(D)],
519 [('a', 'B:a'),
520 ('b', 'B:b'),
521 ('c', 'D:c')])
522
523 @dataclass
524 class E(D):
525 a: str = 'E:a'
526 d: str = 'E:d'
527
528 self.assertEqual([(f.name, f.default) for f in fields(E)],
529 [('a', 'E:a'),
530 ('b', 'B:b'),
531 ('c', 'D:c'),
532 ('d', 'E:d')])
533
534 def test_class_attrs(self):
535 # We only have a class attribute if a default value is
536 # specified, either directly or via a field with a default.
537 default = object()
538 @dataclass
539 class C:
540 x: int
541 y: int = field(repr=False)
542 z: object = default
543 t: int = field(default=100)
544
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertFalse(hasattr(C, 'y'))
547 self.assertIs (C.z, default)
548 self.assertEqual(C.t, 100)
549
550 def test_disallowed_mutable_defaults(self):
551 # For the known types, don't allow mutable default values.
552 for typ, empty, non_empty in [(list, [], [1]),
553 (dict, {}, {0:1}),
554 (set, set(), set([1])),
555 ]:
556 with self.subTest(typ=typ):
557 # Can't use a zero-length value.
558 with self.assertRaisesRegex(ValueError,
559 f'mutable default {typ} for field '
560 'x is not allowed'):
561 @dataclass
562 class Point:
563 x: typ = empty
564
565
566 # Nor a non-zero-length value
567 with self.assertRaisesRegex(ValueError,
568 f'mutable default {typ} for field '
569 'y is not allowed'):
570 @dataclass
571 class Point:
572 y: typ = non_empty
573
574 # Check subtypes also fail.
575 class Subclass(typ): pass
576
577 with self.assertRaisesRegex(ValueError,
578 f"mutable default .*Subclass'>"
579 ' for field z is not allowed'
580 ):
581 @dataclass
582 class Point:
583 z: typ = Subclass()
584
585 # Because this is a ClassVar, it can be mutable.
586 @dataclass
587 class C:
588 z: ClassVar[typ] = typ()
589
590 # Because this is a ClassVar, it can be mutable.
591 @dataclass
592 class C:
593 x: ClassVar[typ] = Subclass()
594
595
596 def test_deliberately_mutable_defaults(self):
597 # If a mutable default isn't in the known list of
598 # (list, dict, set), then it's okay.
599 class Mutable:
600 def __init__(self):
601 self.l = []
602
603 @dataclass
604 class C:
605 x: Mutable
606
607 # These 2 instances will share this value of x.
608 lst = Mutable()
609 o1 = C(lst)
610 o2 = C(lst)
611 self.assertEqual(o1, o2)
612 o1.x.l.extend([1, 2])
613 self.assertEqual(o1, o2)
614 self.assertEqual(o1.x.l, [1, 2])
615 self.assertIs(o1.x, o2.x)
616
617 def test_no_options(self):
618 # call with dataclass()
619 @dataclass()
620 class C:
621 x: int
622
623 self.assertEqual(C(42).x, 42)
624
625 def test_not_tuple(self):
626 # Make sure we can't be compared to a tuple.
627 @dataclass
628 class Point:
629 x: int
630 y: int
631 self.assertNotEqual(Point(1, 2), (1, 2))
632
633 # And that we can't compare to another unrelated dataclass
634 @dataclass
635 class C:
636 x: int
637 y: int
638 self.assertNotEqual(Point(1, 3), C(1, 3))
639
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500640 def test_frozen(self):
641 @dataclass(frozen=True)
642 class C:
643 i: int
644
645 c = C(10)
646 self.assertEqual(c.i, 10)
647 with self.assertRaises(FrozenInstanceError):
648 c.i = 5
649 self.assertEqual(c.i, 10)
650
651 # Check that a derived class is still frozen, even if not
652 # marked so.
653 @dataclass
654 class D(C):
655 pass
656
657 d = D(20)
658 self.assertEqual(d.i, 20)
659 with self.assertRaises(FrozenInstanceError):
660 d.i = 5
661 self.assertEqual(d.i, 20)
662
663 def test_not_tuple(self):
664 # Test that some of the problems with namedtuple don't happen
665 # here.
666 @dataclass
667 class Point3D:
668 x: int
669 y: int
670 z: int
671
672 @dataclass
673 class Date:
674 year: int
675 month: int
676 day: int
677
678 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
679 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
680
681 # Make sure we can't unpack
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200682 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500683 x, y, z = Point3D(4, 5, 6)
684
Eric V. Smith7c99e932018-01-28 19:18:55 -0500685 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500686 # equal.
687 @dataclass
688 class Point3Dv1:
689 x: int = 0
690 y: int = 0
691 z: int = 0
692 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
693
694 def test_function_annotations(self):
695 # Some dummy class and instance to use as a default.
696 class F:
697 pass
698 f = F()
699
700 def validate_class(cls):
701 # First, check __annotations__, even though they're not
702 # function annotations.
703 self.assertEqual(cls.__annotations__['i'], int)
704 self.assertEqual(cls.__annotations__['j'], str)
705 self.assertEqual(cls.__annotations__['k'], F)
706 self.assertEqual(cls.__annotations__['l'], float)
707 self.assertEqual(cls.__annotations__['z'], complex)
708
709 # Verify __init__.
710
711 signature = inspect.signature(cls.__init__)
712 # Check the return type, should be None
713 self.assertIs(signature.return_annotation, None)
714
715 # Check each parameter.
716 params = iter(signature.parameters.values())
717 param = next(params)
718 # This is testing an internal name, and probably shouldn't be tested.
719 self.assertEqual(param.name, 'self')
720 param = next(params)
721 self.assertEqual(param.name, 'i')
722 self.assertIs (param.annotation, int)
723 self.assertEqual(param.default, inspect.Parameter.empty)
724 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
725 param = next(params)
726 self.assertEqual(param.name, 'j')
727 self.assertIs (param.annotation, str)
728 self.assertEqual(param.default, inspect.Parameter.empty)
729 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
730 param = next(params)
731 self.assertEqual(param.name, 'k')
732 self.assertIs (param.annotation, F)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500733 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500734 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
735 param = next(params)
736 self.assertEqual(param.name, 'l')
737 self.assertIs (param.annotation, float)
Eric V. Smith03220fd2017-12-29 13:59:58 -0500738 # Don't test for the default, since it's set to MISSING
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500739 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
740 self.assertRaises(StopIteration, next, params)
741
742
743 @dataclass
744 class C:
745 i: int
746 j: str
747 k: F = f
748 l: float=field(default=None)
749 z: complex=field(default=3+4j, init=False)
750
751 validate_class(C)
752
753 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800754 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500755 class C:
756 i: int
757 j: str
758 k: F = f
759 l: float=field(default=None)
760 z: complex=field(default=3+4j, init=False)
761
762 validate_class(C)
763
Eric V. Smith03220fd2017-12-29 13:59:58 -0500764 def test_missing_default(self):
765 # Test that MISSING works the same as a default not being
766 # specified.
767 @dataclass
768 class C:
769 x: int=field(default=MISSING)
770 with self.assertRaisesRegex(TypeError,
771 r'__init__\(\) missing 1 required '
772 'positional argument'):
773 C()
774 self.assertNotIn('x', C.__dict__)
775
776 @dataclass
777 class D:
778 x: int
779 with self.assertRaisesRegex(TypeError,
780 r'__init__\(\) missing 1 required '
781 'positional argument'):
782 D()
783 self.assertNotIn('x', D.__dict__)
784
785 def test_missing_default_factory(self):
786 # Test that MISSING works the same as a default factory not
787 # being specified (which is really the same as a default not
788 # being specified, too).
789 @dataclass
790 class C:
791 x: int=field(default_factory=MISSING)
792 with self.assertRaisesRegex(TypeError,
793 r'__init__\(\) missing 1 required '
794 'positional argument'):
795 C()
796 self.assertNotIn('x', C.__dict__)
797
798 @dataclass
799 class D:
800 x: int=field(default=MISSING, default_factory=MISSING)
801 with self.assertRaisesRegex(TypeError,
802 r'__init__\(\) missing 1 required '
803 'positional argument'):
804 D()
805 self.assertNotIn('x', D.__dict__)
806
807 def test_missing_repr(self):
808 self.assertIn('MISSING_TYPE object', repr(MISSING))
809
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500810 def test_dont_include_other_annotations(self):
811 @dataclass
812 class C:
813 i: int
814 def foo(self) -> int:
815 return 4
816 @property
817 def bar(self) -> int:
818 return 5
819 self.assertEqual(list(C.__annotations__), ['i'])
820 self.assertEqual(C(10).foo(), 4)
821 self.assertEqual(C(10).bar, 5)
822
def test_post_init(self):
    """__post_init__ runs after a generated __init__, and only if one exists."""
    # Just make sure it gets called.
    @dataclass
    class C:
        def __post_init__(inst):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    @dataclass
    class C:
        i: int = 10
        def __post_init__(inst):
            if inst.i != 10:
                return
            raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # post-init gets called but doesn't raise; this just checks that the
    # instance is passed in correctly.
    C(5)

    # Without a generated __init__, post-init won't get called, so
    # creating an instance doesn't raise.
    @dataclass(init=False)
    class C:
        def __post_init__(inst):
            raise CustomError()
    C()

    @dataclass
    class C:
        x: int = 0
        def __post_init__(inst):
            inst.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # If we're frozen, post-init can't set attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(inst):
            inst.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
869
def test_post_init_super(self):
    """A base class's __post_init__ follows normal method-resolution rules."""
    class B:
        def __post_init__(self):
            raise CustomError()

    # Overriding __post_init__ means the base one is not called implicitly.
    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(C().x, 5)

    # An explicit super() call reaches the base hook, which raises.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # Post-init is still called when it is only inherited, not defined here.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
900
901 def test_post_init_staticmethod(self):
902 flag = False
903 @dataclass
904 class C:
905 x: int
906 y: int
907 @staticmethod
908 def __post_init__():
909 nonlocal flag
910 flag = True
911
912 self.assertFalse(flag)
913 c = C(3, 4)
914 self.assertEqual((c.x, c.y), (3, 4))
915 self.assertTrue(flag)
916
917 def test_post_init_classmethod(self):
918 @dataclass
919 class C:
920 flag = False
921 x: int
922 y: int
923 @classmethod
924 def __post_init__(cls):
925 cls.flag = True
926
927 self.assertFalse(C.flag)
928 c = C(3, 4)
929 self.assertEqual((c.x, c.y), (3, 4))
930 self.assertTrue(C.flag)
931
932 def test_class_var(self):
933 # Make sure ClassVars are ignored in __init__, __repr__, etc.
934 @dataclass
935 class C:
936 x: int
937 y: int = 10
938 z: ClassVar[int] = 1000
939 w: ClassVar[int] = 2000
940 t: ClassVar[int] = 3000
941
942 c = C(5)
943 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
944 self.assertEqual(len(fields(C)), 2) # We have 2 fields
945 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
946 self.assertEqual(c.z, 1000)
947 self.assertEqual(c.w, 2000)
948 self.assertEqual(c.t, 3000)
949 C.z += 1
950 self.assertEqual(c.z, 1001)
951 c = C(20)
952 self.assertEqual((c.x, c.y), (20, 10))
953 self.assertEqual(c.z, 1001)
954 self.assertEqual(c.w, 2000)
955 self.assertEqual(c.t, 3000)
956
957 def test_class_var_no_default(self):
958 # If a ClassVar has no default value, it should not be set on the class.
959 @dataclass
960 class C:
961 x: ClassVar[int]
962
963 self.assertNotIn('x', C.__dict__)
964
965 def test_class_var_default_factory(self):
966 # It makes no sense for a ClassVar to have a default factory. When
967 # would it be called? Call it yourself, since it's class-wide.
968 with self.assertRaisesRegex(TypeError,
969 'cannot have a default factory'):
970 @dataclass
971 class C:
972 x: ClassVar[int] = field(default_factory=int)
973
974 self.assertNotIn('x', C.__dict__)
975
976 def test_class_var_with_default(self):
977 # If a ClassVar has a default value, it should be set on the class.
978 @dataclass
979 class C:
980 x: ClassVar[int] = 10
981 self.assertEqual(C.x, 10)
982
983 @dataclass
984 class C:
985 x: ClassVar[int] = field(default=10)
986 self.assertEqual(C.x, 10)
987
988 def test_class_var_frozen(self):
989 # Make sure ClassVars work even if we're frozen.
990 @dataclass(frozen=True)
991 class C:
992 x: int
993 y: int = 10
994 z: ClassVar[int] = 1000
995 w: ClassVar[int] = 2000
996 t: ClassVar[int] = 3000
997
998 c = C(5)
999 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1000 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1001 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1002 self.assertEqual(c.z, 1000)
1003 self.assertEqual(c.w, 2000)
1004 self.assertEqual(c.t, 3000)
1005 # We can still modify the ClassVar, it's only instances that are
1006 # frozen.
1007 C.z += 1
1008 self.assertEqual(c.z, 1001)
1009 c = C(20)
1010 self.assertEqual((c.x, c.y), (20, 10))
1011 self.assertEqual(c.z, 1001)
1012 self.assertEqual(c.w, 2000)
1013 self.assertEqual(c.t, 3000)
1014
1015 def test_init_var_no_default(self):
1016 # If an InitVar has no default value, it should not be set on the class.
1017 @dataclass
1018 class C:
1019 x: InitVar[int]
1020
1021 self.assertNotIn('x', C.__dict__)
1022
1023 def test_init_var_default_factory(self):
1024 # It makes no sense for an InitVar to have a default factory. When
1025 # would it be called? Call it yourself, since it's class-wide.
1026 with self.assertRaisesRegex(TypeError,
1027 'cannot have a default factory'):
1028 @dataclass
1029 class C:
1030 x: InitVar[int] = field(default_factory=int)
1031
1032 self.assertNotIn('x', C.__dict__)
1033
1034 def test_init_var_with_default(self):
1035 # If an InitVar has a default value, it should be set on the class.
1036 @dataclass
1037 class C:
1038 x: InitVar[int] = 10
1039 self.assertEqual(C.x, 10)
1040
1041 @dataclass
1042 class C:
1043 x: InitVar[int] = field(default=10)
1044 self.assertEqual(C.x, 10)
1045
1046 def test_init_var(self):
1047 @dataclass
1048 class C:
1049 x: int = None
1050 init_param: InitVar[int] = None
1051
1052 def __post_init__(self, init_param):
1053 if self.x is None:
1054 self.x = init_param*2
1055
1056 c = C(init_param=10)
1057 self.assertEqual(c.x, 20)
1058
1059 def test_init_var_inheritance(self):
1060 # Note that this deliberately tests that a dataclass need not
1061 # have a __post_init__ function if it has an InitVar field.
1062 # It could just be used in a derived class, as shown here.
1063 @dataclass
1064 class Base:
1065 x: int
1066 init_base: InitVar[int]
1067
1068 # We can instantiate by passing the InitVar, even though
1069 # it's not used.
1070 b = Base(0, 10)
1071 self.assertEqual(vars(b), {'x': 0})
1072
1073 @dataclass
1074 class C(Base):
1075 y: int
1076 init_derived: InitVar[int]
1077
1078 def __post_init__(self, init_base, init_derived):
1079 self.x = self.x + init_base
1080 self.y = self.y + init_derived
1081
1082 c = C(10, 11, 50, 51)
1083 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1084
1085 def test_default_factory(self):
1086 # Test a factory that returns a new list.
1087 @dataclass
1088 class C:
1089 x: int
1090 y: list = field(default_factory=list)
1091
1092 c0 = C(3)
1093 c1 = C(3)
1094 self.assertEqual(c0.x, 3)
1095 self.assertEqual(c0.y, [])
1096 self.assertEqual(c0, c1)
1097 self.assertIsNot(c0.y, c1.y)
1098 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1099
1100 # Test a factory that returns a shared list.
1101 l = []
1102 @dataclass
1103 class C:
1104 x: int
1105 y: list = field(default_factory=lambda: l)
1106
1107 c0 = C(3)
1108 c1 = C(3)
1109 self.assertEqual(c0.x, 3)
1110 self.assertEqual(c0.y, [])
1111 self.assertEqual(c0, c1)
1112 self.assertIs(c0.y, c1.y)
1113 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1114
1115 # Test various other field flags.
1116 # repr
1117 @dataclass
1118 class C:
1119 x: list = field(default_factory=list, repr=False)
1120 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1121 self.assertEqual(C().x, [])
1122
1123 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001124 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001125 class C:
1126 x: list = field(default_factory=list, hash=False)
1127 self.assertEqual(astuple(C()), ([],))
1128 self.assertEqual(hash(C()), hash(()))
1129
1130 # init (see also test_default_factory_with_no_init)
1131 @dataclass
1132 class C:
1133 x: list = field(default_factory=list, init=False)
1134 self.assertEqual(astuple(C()), ([],))
1135
1136 # compare
1137 @dataclass
1138 class C:
1139 x: list = field(default_factory=list, compare=False)
1140 self.assertEqual(C(), C([1]))
1141
1142 def test_default_factory_with_no_init(self):
1143 # We need a factory with a side effect.
1144 factory = Mock()
1145
1146 @dataclass
1147 class C:
1148 x: list = field(default_factory=factory, init=False)
1149
1150 # Make sure the default factory is called for each new instance.
1151 C().x
1152 self.assertEqual(factory.call_count, 1)
1153 C().x
1154 self.assertEqual(factory.call_count, 2)
1155
1156 def test_default_factory_not_called_if_value_given(self):
1157 # We need a factory that we can test if it's been called.
1158 factory = Mock()
1159
1160 @dataclass
1161 class C:
1162 x: int = field(default_factory=factory)
1163
1164 # Make sure that if a field has a default factory function,
1165 # it's not called if a value is specified.
1166 C().x
1167 self.assertEqual(factory.call_count, 1)
1168 self.assertEqual(C(10).x, 10)
1169 self.assertEqual(factory.call_count, 1)
1170 C().x
1171 self.assertEqual(factory.call_count, 2)
1172
1173 def x_test_classvar_default_factory(self):
1174 # XXX: it's an error for a ClassVar to have a factory function
1175 @dataclass
1176 class C:
1177 x: ClassVar[int] = field(default_factory=int)
1178
1179 self.assertIs(C().x, int)
1180
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001181 def test_is_dataclass(self):
1182 class NotDataClass:
1183 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001184
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001185 self.assertFalse(is_dataclass(0))
1186 self.assertFalse(is_dataclass(int))
1187 self.assertFalse(is_dataclass(NotDataClass))
1188 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001189
1190 @dataclass
1191 class C:
1192 x: int
1193
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001194 @dataclass
1195 class D:
1196 d: C
1197 e: int
1198
1199 c = C(10)
1200 d = D(c, 4)
1201
1202 self.assertTrue(is_dataclass(C))
1203 self.assertTrue(is_dataclass(c))
1204 self.assertFalse(is_dataclass(c.x))
1205 self.assertTrue(is_dataclass(d.d))
1206 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001207
1208 def test_helper_fields_with_class_instance(self):
1209 # Check that we can call fields() on either a class or instance,
1210 # and get back the same thing.
1211 @dataclass
1212 class C:
1213 x: int
1214 y: float
1215
1216 self.assertEqual(fields(C), fields(C(0, 0.0)))
1217
1218 def test_helper_fields_exception(self):
1219 # Check that TypeError is raised if not passed a dataclass or
1220 # instance.
1221 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1222 fields(0)
1223
1224 class C: pass
1225 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1226 fields(C)
1227 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1228 fields(C())
1229
1230 def test_helper_asdict(self):
1231 # Basic tests for asdict(), it should return a new dictionary
1232 @dataclass
1233 class C:
1234 x: int
1235 y: int
1236 c = C(1, 2)
1237
1238 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1239 self.assertEqual(asdict(c), asdict(c))
1240 self.assertIsNot(asdict(c), asdict(c))
1241 c.x = 42
1242 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1243 self.assertIs(type(asdict(c)), dict)
1244
1245 def test_helper_asdict_raises_on_classes(self):
1246 # asdict() should raise on a class object
1247 @dataclass
1248 class C:
1249 x: int
1250 y: int
1251 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1252 asdict(C)
1253 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1254 asdict(int)
1255
1256 def test_helper_asdict_copy_values(self):
1257 @dataclass
1258 class C:
1259 x: int
1260 y: List[int] = field(default_factory=list)
1261 initial = []
1262 c = C(1, initial)
1263 d = asdict(c)
1264 self.assertEqual(d['y'], initial)
1265 self.assertIsNot(d['y'], initial)
1266 c = C(1)
1267 d = asdict(c)
1268 d['y'].append(1)
1269 self.assertEqual(c.y, [])
1270
1271 def test_helper_asdict_nested(self):
1272 @dataclass
1273 class UserId:
1274 token: int
1275 group: int
1276 @dataclass
1277 class User:
1278 name: str
1279 id: UserId
1280 u = User('Joe', UserId(123, 1))
1281 d = asdict(u)
1282 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1283 self.assertIsNot(asdict(u), asdict(u))
1284 u.id.group = 2
1285 self.assertEqual(asdict(u), {'name': 'Joe',
1286 'id': {'token': 123, 'group': 2}})
1287
1288 def test_helper_asdict_builtin_containers(self):
1289 @dataclass
1290 class User:
1291 name: str
1292 id: int
1293 @dataclass
1294 class GroupList:
1295 id: int
1296 users: List[User]
1297 @dataclass
1298 class GroupTuple:
1299 id: int
1300 users: Tuple[User, ...]
1301 @dataclass
1302 class GroupDict:
1303 id: int
1304 users: Dict[str, User]
1305 a = User('Alice', 1)
1306 b = User('Bob', 2)
1307 gl = GroupList(0, [a, b])
1308 gt = GroupTuple(0, (a, b))
1309 gd = GroupDict(0, {'first': a, 'second': b})
1310 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1311 {'name': 'Bob', 'id': 2}]})
1312 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1313 {'name': 'Bob', 'id': 2})})
1314 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1315 'second': {'name': 'Bob', 'id': 2}}})
1316
1317 def test_helper_asdict_builtin_containers(self):
1318 @dataclass
1319 class Child:
1320 d: object
1321
1322 @dataclass
1323 class Parent:
1324 child: Child
1325
1326 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1327 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1328
1329 def test_helper_asdict_factory(self):
1330 @dataclass
1331 class C:
1332 x: int
1333 y: int
1334 c = C(1, 2)
1335 d = asdict(c, dict_factory=OrderedDict)
1336 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1337 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1338 c.x = 42
1339 d = asdict(c, dict_factory=OrderedDict)
1340 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1341 self.assertIs(type(d), OrderedDict)
1342
1343 def test_helper_astuple(self):
1344 # Basic tests for astuple(), it should return a new tuple
1345 @dataclass
1346 class C:
1347 x: int
1348 y: int = 0
1349 c = C(1)
1350
1351 self.assertEqual(astuple(c), (1, 0))
1352 self.assertEqual(astuple(c), astuple(c))
1353 self.assertIsNot(astuple(c), astuple(c))
1354 c.y = 42
1355 self.assertEqual(astuple(c), (1, 42))
1356 self.assertIs(type(astuple(c)), tuple)
1357
1358 def test_helper_astuple_raises_on_classes(self):
1359 # astuple() should raise on a class object
1360 @dataclass
1361 class C:
1362 x: int
1363 y: int
1364 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1365 astuple(C)
1366 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1367 astuple(int)
1368
1369 def test_helper_astuple_copy_values(self):
1370 @dataclass
1371 class C:
1372 x: int
1373 y: List[int] = field(default_factory=list)
1374 initial = []
1375 c = C(1, initial)
1376 t = astuple(c)
1377 self.assertEqual(t[1], initial)
1378 self.assertIsNot(t[1], initial)
1379 c = C(1)
1380 t = astuple(c)
1381 t[1].append(1)
1382 self.assertEqual(c.y, [])
1383
1384 def test_helper_astuple_nested(self):
1385 @dataclass
1386 class UserId:
1387 token: int
1388 group: int
1389 @dataclass
1390 class User:
1391 name: str
1392 id: UserId
1393 u = User('Joe', UserId(123, 1))
1394 t = astuple(u)
1395 self.assertEqual(t, ('Joe', (123, 1)))
1396 self.assertIsNot(astuple(u), astuple(u))
1397 u.id.group = 2
1398 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1399
1400 def test_helper_astuple_builtin_containers(self):
1401 @dataclass
1402 class User:
1403 name: str
1404 id: int
1405 @dataclass
1406 class GroupList:
1407 id: int
1408 users: List[User]
1409 @dataclass
1410 class GroupTuple:
1411 id: int
1412 users: Tuple[User, ...]
1413 @dataclass
1414 class GroupDict:
1415 id: int
1416 users: Dict[str, User]
1417 a = User('Alice', 1)
1418 b = User('Bob', 2)
1419 gl = GroupList(0, [a, b])
1420 gt = GroupTuple(0, (a, b))
1421 gd = GroupDict(0, {'first': a, 'second': b})
1422 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1423 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1424 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1425
1426 def test_helper_astuple_builtin_containers(self):
1427 @dataclass
1428 class Child:
1429 d: object
1430
1431 @dataclass
1432 class Parent:
1433 child: Child
1434
1435 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1436 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1437
1438 def test_helper_astuple_factory(self):
1439 @dataclass
1440 class C:
1441 x: int
1442 y: int
1443 NT = namedtuple('NT', 'x y')
1444 def nt(lst):
1445 return NT(*lst)
1446 c = C(1, 2)
1447 t = astuple(c, tuple_factory=nt)
1448 self.assertEqual(t, NT(1, 2))
1449 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1450 c.x = 42
1451 t = astuple(c, tuple_factory=nt)
1452 self.assertEqual(t, NT(42, 2))
1453 self.assertIs(type(t), NT)
1454
1455 def test_dynamic_class_creation(self):
1456 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1457 }
1458
1459 # Create the class.
1460 cls = type('C', (), cls_dict)
1461
1462 # Make it a dataclass.
1463 cls1 = dataclass(cls)
1464
1465 self.assertEqual(cls1, cls)
1466 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1467
1468 def test_dynamic_class_creation_using_field(self):
1469 cls_dict = {'__annotations__': OrderedDict(x=int, y=int),
1470 'y': field(default=5),
1471 }
1472
1473 # Create the class.
1474 cls = type('C', (), cls_dict)
1475
1476 # Make it a dataclass.
1477 cls1 = dataclass(cls)
1478
1479 self.assertEqual(cls1, cls)
1480 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1481
1482 def test_init_in_order(self):
1483 @dataclass
1484 class C:
1485 a: int
1486 b: int = field()
1487 c: list = field(default_factory=list, init=False)
1488 d: list = field(default_factory=list)
1489 e: int = field(default=4, init=False)
1490 f: int = 4
1491
1492 calls = []
1493 def setattr(self, name, value):
1494 calls.append((name, value))
1495
1496 C.__setattr__ = setattr
1497 c = C(0, 1)
1498 self.assertEqual(('a', 0), calls[0])
1499 self.assertEqual(('b', 1), calls[1])
1500 self.assertEqual(('c', []), calls[2])
1501 self.assertEqual(('d', []), calls[3])
1502 self.assertNotIn(('e', 4), calls)
1503 self.assertEqual(('f', 4), calls[4])
1504
1505 def test_items_in_dicts(self):
1506 @dataclass
1507 class C:
1508 a: int
1509 b: list = field(default_factory=list, init=False)
1510 c: list = field(default_factory=list)
1511 d: int = field(default=4, init=False)
1512 e: int = 0
1513
1514 c = C(0)
1515 # Class dict
1516 self.assertNotIn('a', C.__dict__)
1517 self.assertNotIn('b', C.__dict__)
1518 self.assertNotIn('c', C.__dict__)
1519 self.assertIn('d', C.__dict__)
1520 self.assertEqual(C.d, 4)
1521 self.assertIn('e', C.__dict__)
1522 self.assertEqual(C.e, 0)
1523 # Instance dict
1524 self.assertIn('a', c.__dict__)
1525 self.assertEqual(c.a, 0)
1526 self.assertIn('b', c.__dict__)
1527 self.assertEqual(c.b, [])
1528 self.assertIn('c', c.__dict__)
1529 self.assertEqual(c.c, [])
1530 self.assertNotIn('d', c.__dict__)
1531 self.assertIn('e', c.__dict__)
1532 self.assertEqual(c.e, 0)
1533
1534 def test_alternate_classmethod_constructor(self):
1535 # Since __post_init__ can't take params, use a classmethod
1536 # alternate constructor. This is mostly an example to show how
1537 # to use this technique.
1538 @dataclass
1539 class C:
1540 x: int
1541 @classmethod
1542 def from_file(cls, filename):
1543 # In a real example, create a new instance
1544 # and populate 'x' from contents of a file.
1545 value_in_file = 20
1546 return cls(value_in_file)
1547
1548 self.assertEqual(C.from_file('filename').x, 20)
1549
1550 def test_field_metadata_default(self):
1551 # Make sure the default metadata is read-only and of
1552 # zero length.
1553 @dataclass
1554 class C:
1555 i: int
1556
1557 self.assertFalse(fields(C)[0].metadata)
1558 self.assertEqual(len(fields(C)[0].metadata), 0)
1559 with self.assertRaisesRegex(TypeError,
1560 'does not support item assignment'):
1561 fields(C)[0].metadata['test'] = 3
1562
1563 def test_field_metadata_mapping(self):
1564 # Make sure only a mapping can be passed as metadata
1565 # zero length.
1566 with self.assertRaises(TypeError):
1567 @dataclass
1568 class C:
1569 i: int = field(metadata=0)
1570
1571 # Make sure an empty dict works
1572 @dataclass
1573 class C:
1574 i: int = field(metadata={})
1575 self.assertFalse(fields(C)[0].metadata)
1576 self.assertEqual(len(fields(C)[0].metadata), 0)
1577 with self.assertRaisesRegex(TypeError,
1578 'does not support item assignment'):
1579 fields(C)[0].metadata['test'] = 3
1580
1581 # Make sure a non-empty dict works.
1582 @dataclass
1583 class C:
1584 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1585 self.assertEqual(len(fields(C)[0].metadata), 3)
1586 self.assertEqual(fields(C)[0].metadata['test'], 10)
1587 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1588 self.assertEqual(fields(C)[0].metadata[3], 'three')
1589 with self.assertRaises(KeyError):
1590 # Non-existent key.
1591 fields(C)[0].metadata['baz']
1592 with self.assertRaisesRegex(TypeError,
1593 'does not support item assignment'):
1594 fields(C)[0].metadata['test'] = 3
1595
1596 def test_field_metadata_custom_mapping(self):
1597 # Try a custom mapping.
1598 class SimpleNameSpace:
1599 def __init__(self, **kw):
1600 self.__dict__.update(kw)
1601
1602 def __getitem__(self, item):
1603 if item == 'xyzzy':
1604 return 'plugh'
1605 return getattr(self, item)
1606
1607 def __len__(self):
1608 return self.__dict__.__len__()
1609
1610 @dataclass
1611 class C:
1612 i: int = field(metadata=SimpleNameSpace(a=10))
1613
1614 self.assertEqual(len(fields(C)[0].metadata), 1)
1615 self.assertEqual(fields(C)[0].metadata['a'], 10)
1616 with self.assertRaises(AttributeError):
1617 fields(C)[0].metadata['b']
1618 # Make sure we're still talking to our custom mapping.
1619 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1620
1621 def test_generic_dataclasses(self):
1622 T = TypeVar('T')
1623
1624 @dataclass
1625 class LabeledBox(Generic[T]):
1626 content: T
1627 label: str = '<unknown>'
1628
1629 box = LabeledBox(42)
1630 self.assertEqual(box.content, 42)
1631 self.assertEqual(box.label, '<unknown>')
1632
1633 # subscripting the resulting class should work, etc.
1634 Alias = List[LabeledBox[int]]
1635
1636 def test_generic_extending(self):
1637 S = TypeVar('S')
1638 T = TypeVar('T')
1639
1640 @dataclass
1641 class Base(Generic[T, S]):
1642 x: T
1643 y: S
1644
1645 @dataclass
1646 class DataDerived(Base[int, T]):
1647 new_field: str
1648 Alias = DataDerived[str]
1649 c = Alias(0, 'test1', 'test2')
1650 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1651
1652 class NonDataDerived(Base[int, T]):
1653 def new_method(self):
1654 return self.y
1655 Alias = NonDataDerived[float]
1656 c = Alias(10, 1.0)
1657 self.assertEqual(c.new_method(), 1.0)
1658
1659 def test_helper_replace(self):
1660 @dataclass(frozen=True)
1661 class C:
1662 x: int
1663 y: int
1664
1665 c = C(1, 2)
1666 c1 = replace(c, x=3)
1667 self.assertEqual(c1.x, 3)
1668 self.assertEqual(c1.y, 2)
1669
1670 def test_helper_replace_frozen(self):
1671 @dataclass(frozen=True)
1672 class C:
1673 x: int
1674 y: int
1675 z: int = field(init=False, default=10)
1676 t: int = field(init=False, default=100)
1677
1678 c = C(1, 2)
1679 c1 = replace(c, x=3)
1680 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1681 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1682
1683
1684 with self.assertRaisesRegex(ValueError, 'init=False'):
1685 replace(c, x=3, z=20, t=50)
1686 with self.assertRaisesRegex(ValueError, 'init=False'):
1687 replace(c, z=20)
1688 replace(c, x=3, z=20, t=50)
1689
1690 # Make sure the result is still frozen.
1691 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1692 c1.x = 3
1693
1694 # Make sure we can't replace an attribute that doesn't exist,
1695 # if we're also replacing one that does exist. Test this
1696 # here, because setting attributes on frozen instances is
1697 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001698 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001699 "keyword argument 'a'"):
1700 c1 = replace(c, x=20, a=5)
1701
1702 def test_helper_replace_invalid_field_name(self):
1703 @dataclass(frozen=True)
1704 class C:
1705 x: int
1706 y: int
1707
1708 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001709 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001710 "keyword argument 'z'"):
1711 c1 = replace(c, z=3)
1712
1713 def test_helper_replace_invalid_object(self):
1714 @dataclass(frozen=True)
1715 class C:
1716 x: int
1717 y: int
1718
1719 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1720 replace(C, x=3)
1721
1722 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1723 replace(0, x=3)
1724
1725 def test_helper_replace_no_init(self):
1726 @dataclass
1727 class C:
1728 x: int
1729 y: int = field(init=False, default=10)
1730
1731 c = C(1)
1732 c.y = 20
1733
1734 # Make sure y gets the default value.
1735 c1 = replace(c, x=5)
1736 self.assertEqual((c1.x, c1.y), (5, 10))
1737
1738 # Trying to replace y is an error.
1739 with self.assertRaisesRegex(ValueError, 'init=False'):
1740 replace(c, x=2, y=30)
1741 with self.assertRaisesRegex(ValueError, 'init=False'):
1742 replace(c, y=30)
1743
1744 def test_dataclassses_pickleable(self):
1745 global P, Q, R
1746 @dataclass
1747 class P:
1748 x: int
1749 y: int = 0
1750 @dataclass
1751 class Q:
1752 x: int
1753 y: int = field(default=0, init=False)
1754 @dataclass
1755 class R:
1756 x: int
1757 y: List[int] = field(default_factory=list)
1758 q = Q(1)
1759 q.y = 2
1760 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1761 for sample in samples:
1762 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1763 with self.subTest(sample=sample, proto=proto):
1764 new_sample = pickle.loads(pickle.dumps(sample, proto))
1765 self.assertEqual(sample.x, new_sample.x)
1766 self.assertEqual(sample.y, new_sample.y)
1767 self.assertIsNot(sample, new_sample)
1768 new_sample.x = 42
1769 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1770 self.assertEqual(new_sample.x, another_new_sample.x)
1771 self.assertEqual(sample.y, another_new_sample.y)
1772
1773 def test_helper_make_dataclass(self):
1774 C = make_dataclass('C',
1775 [('x', int),
1776 ('y', int, field(default=5))],
1777 namespace={'add_one': lambda self: self.x + 1})
1778 c = C(10)
1779 self.assertEqual((c.x, c.y), (10, 5))
1780 self.assertEqual(c.add_one(), 11)
1781
1782
1783 def test_helper_make_dataclass_no_mutate_namespace(self):
1784 # Make sure a provided namespace isn't mutated.
1785 ns = {}
1786 C = make_dataclass('C',
1787 [('x', int),
1788 ('y', int, field(default=5))],
1789 namespace=ns)
1790 self.assertEqual(ns, {})
1791
1792 def test_helper_make_dataclass_base(self):
1793 class Base1:
1794 pass
1795 class Base2:
1796 pass
1797 C = make_dataclass('C',
1798 [('x', int)],
1799 bases=(Base1, Base2))
1800 c = C(2)
1801 self.assertIsInstance(c, C)
1802 self.assertIsInstance(c, Base1)
1803 self.assertIsInstance(c, Base2)
1804
1805 def test_helper_make_dataclass_base_dataclass(self):
1806 @dataclass
1807 class Base1:
1808 x: int
1809 class Base2:
1810 pass
1811 C = make_dataclass('C',
1812 [('y', int)],
1813 bases=(Base1, Base2))
1814 with self.assertRaisesRegex(TypeError, 'required positional'):
1815 c = C(2)
1816 c = C(1, 2)
1817 self.assertIsInstance(c, C)
1818 self.assertIsInstance(c, Base1)
1819 self.assertIsInstance(c, Base2)
1820
1821 self.assertEqual((c.x, c.y), (1, 2))
1822
1823 def test_helper_make_dataclass_init_var(self):
1824 def post_init(self, y):
1825 self.x *= y
1826
1827 C = make_dataclass('C',
1828 [('x', int),
1829 ('y', InitVar[int]),
1830 ],
1831 namespace={'__post_init__': post_init},
1832 )
1833 c = C(2, 3)
1834 self.assertEqual(vars(c), {'x': 6})
1835 self.assertEqual(len(fields(c)), 1)
1836
1837 def test_helper_make_dataclass_class_var(self):
1838 C = make_dataclass('C',
1839 [('x', int),
1840 ('y', ClassVar[int], 10),
1841 ('z', ClassVar[int], field(default=20)),
1842 ])
1843 c = C(1)
1844 self.assertEqual(vars(c), {'x': 1})
1845 self.assertEqual(len(fields(c)), 1)
1846 self.assertEqual(C.y, 10)
1847 self.assertEqual(C.z, 20)
1848
Eric V. Smithd80b4432018-01-06 17:09:58 -05001849 def test_helper_make_dataclass_other_params(self):
1850 C = make_dataclass('C',
1851 [('x', int),
1852 ('y', ClassVar[int], 10),
1853 ('z', ClassVar[int], field(default=20)),
1854 ],
1855 init=False)
1856 # Make sure we have a repr, but no init.
1857 self.assertNotIn('__init__', vars(C))
1858 self.assertIn('__repr__', vars(C))
1859
1860 # Make sure random other params don't work.
1861 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1862 C = make_dataclass('C',
1863 [],
1864 xxinit=False)
1865
Eric V. Smithed7d4292018-01-06 16:14:03 -05001866 def test_helper_make_dataclass_no_types(self):
1867 C = make_dataclass('Point', ['x', 'y', 'z'])
1868 c = C(1, 2, 3)
1869 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1870 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1871 'y': 'typing.Any',
1872 'z': 'typing.Any'})
1873
1874 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1875 c = C(1, 2, 3)
1876 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1877 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1878 'y': int,
1879 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001880
Eric V. Smithea8fc522018-01-27 19:07:40 -05001881
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001882class TestDocString(unittest.TestCase):
1883 def assertDocStrEqual(self, a, b):
1884 # Because 3.6 and 3.7 differ in how inspect.signature work
1885 # (see bpo #32108), for the time being just compare them with
1886 # whitespace stripped.
1887 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1888
1889 def test_existing_docstring_not_overridden(self):
1890 @dataclass
1891 class C:
1892 """Lorem ipsum"""
1893 x: int
1894
1895 self.assertEqual(C.__doc__, "Lorem ipsum")
1896
1897 def test_docstring_no_fields(self):
1898 @dataclass
1899 class C:
1900 pass
1901
1902 self.assertDocStrEqual(C.__doc__, "C()")
1903
1904 def test_docstring_one_field(self):
1905 @dataclass
1906 class C:
1907 x: int
1908
1909 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1910
1911 def test_docstring_two_fields(self):
1912 @dataclass
1913 class C:
1914 x: int
1915 y: int
1916
1917 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1918
1919 def test_docstring_three_fields(self):
1920 @dataclass
1921 class C:
1922 x: int
1923 y: int
1924 z: str
1925
1926 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1927
1928 def test_docstring_one_field_with_default(self):
1929 @dataclass
1930 class C:
1931 x: int = 3
1932
1933 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1934
1935 def test_docstring_one_field_with_default_none(self):
1936 @dataclass
1937 class C:
1938 x: Union[int, type(None)] = None
1939
1940 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1941
1942 def test_docstring_list_field(self):
1943 @dataclass
1944 class C:
1945 x: List[int]
1946
1947 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1948
1949 def test_docstring_list_field_with_default_factory(self):
1950 @dataclass
1951 class C:
1952 x: List[int] = field(default_factory=list)
1953
1954 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1955
1956 def test_docstring_deque_field(self):
1957 @dataclass
1958 class C:
1959 x: deque
1960
1961 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1962
1963 def test_docstring_deque_field_with_default_factory(self):
1964 @dataclass
1965 class C:
1966 x: deque = field(default_factory=deque)
1967
1968 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1969
1970
class TestInit(unittest.TestCase):
    """How dataclass(init=...) interacts with user-defined and inherited __init__."""

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise. We may not override an
        # __init__ defined in our own class, but adding one is fine when
        # only the base defines __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call the base __init__.
        self.assertNotIn('z', vars(c))

        # With no generated __init__, the base __init__ is the one called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # These need the `@`: a bare `dataclass(init=False)` statement
        # builds a decorator, throws it away, and leaves C a plain class.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class defines __init__, it is used whatever init= says.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2035
2036
class TestRepr(unittest.TestCase):
    """Behaviour of the generated __repr__ and of repr=False."""

    def test_repr(self):
        # The generated __repr__ shows the class __qualname__ followed by
        # every field, inherited ones first.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses report their full dotted qualname.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False with no user __repr__: the default object repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        # Build the prefix from this module's real name rather than
        # hard-coding 'test_dataclasses'. The check then also holds when
        # the file runs as __main__ or under another module path.
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # repr=False with a user __repr__: the user's method is kept.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class defines __repr__, it is used whatever repr= says.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2106
2107
class TestFrozen(unittest.TestCase):
    """frozen=True must own __setattr__/__delattr__; frozen=False leaves them alone."""

    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__, so a
        # class that already defines either one is rejected.
        for method in ('__setattr__', '__delattr__'):
            namespace = {'__annotations__': {'x': int},
                         method: lambda self: None}
            with self.assertRaisesRegex(TypeError,
                                        'Cannot overwrite attribute ' + method):
                dataclass(frozen=True)(type('C', (), namespace))

        # With frozen=False the user's __setattr__ is kept, and the
        # generated __init__ goes through it.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)
2133
2134
class TestEq(unittest.TestCase):
    """How eq= interacts with a user-defined __eq__."""

    def test_no_eq(self):
        # eq=False and no user __eq__: comparison falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        only = C(3)
        self.assertEqual(only, only)

        # eq=False with a user __eq__: that method is used as-is.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ is always kept, whatever eq= says
        # (a bare @dataclass means eq=True).
        for eq_flag, magic in ((True, 3), (True, 4), (False, 5)):
            @dataclass(eq=eq_flag)
            class C:
                x: int
                def __eq__(self, other):
                    return other == magic
            self.assertEqual(C(1), magic)
            self.assertNotEqual(C(1), 1)
2180
2181
class TestOrdering(unittest.TestCase):
    """How order= interacts with user comparisons and functools.total_ordering."""

    def test_functools_total_ordering(self):
        # total_ordering derives the other comparisons from our __lt__.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately inverted: the assertions below can only pass
                # if this method is the one actually being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        comparisons = ('__lt__', '__le__', '__gt__', '__ge__')

        # order=False adds no comparison methods.
        @dataclass(order=False)
        class C:
            x: int
        for name in comparisons:
            self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ does not cause order=False to add any of
        # the remaining comparison methods.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in comparisons[1:]:
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any comparison the class already
        # defines, and its error message points at total_ordering.
        for method in ('__lt__', '__le__', '__gt__', '__ge__'):
            namespace = {'__annotations__': {'x': int},
                         method: lambda self: None}
            with self.assertRaisesRegex(TypeError,
                                        'Cannot overwrite attribute ' + method +
                                        '.*using functools.total_ordering'):
                dataclass(order=True)(type('C', (), namespace))
2257
class TestHash(unittest.TestCase):
    """Tests for how @dataclass decides whether and how to set __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # with_hash is included in the subTest parameters so that a
            # failure for the explicit-__hash__ variant of a case can be
            # told apart from the plain variant of the same case.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen, with_hash=with_hash):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.  C has
                    # no fields, so it must hash like an empty tuple.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])
                    self.assertEqual(hash(C()), hash(()))

                elif result == '':
                    # The decorator leaves __hash__ alone.
                    if with_hash:
                        # The user's own __hash__ must survive untouched.
                        self.assertIn('__hash__', C.__dict__)
                        self.assertEqual(hash(C()), 0)
                    else:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None, so instances are unhashable.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])
                    with self.assertRaises(TypeError):
                        hash(C())

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.  Use a
                    # unittest assertion so the check survives python -O.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # self.fail, unlike a bare assert, is not stripped by -O.
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False,  '',     ''),
                (False, False, True,   '',     ''),
                (False, True,  False,  'none', ''),
                (False, True,  True,   'fn',   ''),
                (True,  False, False,  'fn',   'exception'),
                (True,  False, True,   'fn',   'exception'),
                (True,  True,  False,  'fn',   'exception'),
                (True,  True,  True,   'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        #  and set to None.  This is normal Python behavior, not
        #  related to dataclasses.  Make sure we don't interfere with
        #  that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        #  unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        #  specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no unsafe_hash= argument.  This exists to
        #  make sure that if the @dataclass parameter name is changed
        #  or the non-default hashing behavior changes, the default
        #  hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        #  specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
                (None,  None,  object, 'unhashable'),
                (None,  None,  Base,   'unhashable'),
                (None,  False, object, 'object'),
                (None,  False, Base,   'base'),
                (None,  True,  object, 'unhashable'),
                (None,  True,  Base,   'unhashable'),
                (False, None,  object, 'unhashable'),
                (False, None,  Base,   'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base,   'base'),
                (False, True,  object, 'unhashable'),
                (False, True,  Base,   'unhashable'),
                (True,  None,  object, 'tuple'),
                (True,  None,  Base,   'tuple'),
                (True,  False, object, 'object'),
                (True,  False, Base,   'base'),
                (True,  True,  object, 'tuple'),
                (True,  True,  Base,   'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's default hash depends on the instance's
                    #  identity, so there is no fixed value to compare
                    #  against.  Instead, check the function used is
                    #  object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # self.fail, unlike a bare assert, is not stripped by -O.
                    self.fail(f'unknown value for expected={expected!r}')
2476
Eric V. Smithea8fc522018-01-27 19:07:40 -05002477
if __name__ == "__main__":
    # Executed directly as a script: let unittest collect and run the
    # TestCase classes defined in this module.
    unittest.main()