blob: 2c890a2cbe9206d17bc5da1868b60615e8cda9ed [file] [log] [blame]
Miss Islington (bot)4ddc99d2018-03-21 14:44:23 -07001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
Miss Islington (bot)d063ad82018-04-01 04:33:13 -070011from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050012from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
15# Just any custom exception we can catch.
class CustomError(Exception):
    """An arbitrary exception type that tests raise and then catch."""
17
18class TestCase(unittest.TestCase):
19 def test_no_fields(self):
20 @dataclass
21 class C:
22 pass
23
24 o = C()
25 self.assertEqual(len(fields(C)), 0)
26
Miss Islington (bot)3b4c6b12018-03-22 13:58:59 -070027 def test_no_fields_but_member_variable(self):
28 @dataclass
29 class C:
30 i = 0
31
32 o = C()
33 self.assertEqual(len(fields(C)), 0)
34
Eric V. Smithf0db54a2017-12-04 16:58:55 -050035 def test_one_field_no_default(self):
36 @dataclass
37 class C:
38 x: int
39
40 o = C(42)
41 self.assertEqual(o.x, 42)
42
43 def test_named_init_params(self):
44 @dataclass
45 class C:
46 x: int
47
48 o = C(x=32)
49 self.assertEqual(o.x, 32)
50
51 def test_two_fields_one_default(self):
52 @dataclass
53 class C:
54 x: int
55 y: int = 0
56
57 o = C(3)
58 self.assertEqual((o.x, o.y), (3, 0))
59
60 # Non-defaults following defaults.
61 with self.assertRaisesRegex(TypeError,
62 "non-default argument 'y' follows "
63 "default argument"):
64 @dataclass
65 class C:
66 x: int = 0
67 y: int
68
69 # A derived class adds a non-default field after a default one.
70 with self.assertRaisesRegex(TypeError,
71 "non-default argument 'y' follows "
72 "default argument"):
73 @dataclass
74 class B:
75 x: int = 0
76
77 @dataclass
78 class C(B):
79 y: int
80
81 # Override a base class field and add a default to
82 # a field which didn't use to have a default.
83 with self.assertRaisesRegex(TypeError,
84 "non-default argument 'y' follows "
85 "default argument"):
86 @dataclass
87 class B:
88 x: int
89 y: int
90
91 @dataclass
92 class C(B):
93 x: int = 0
94
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080095 def test_overwrite_hash(self):
96 # Test that declaring this class isn't an error. It should
97 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
101 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800102 return 301
103 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # Test that declaring this class isn't an error. It should
106 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500107 @dataclass(frozen=True)
108 class C:
109 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800110 def __eq__(self, other):
111 return False
112 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500113
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800114 # But this one should generate an exception, because with
115 # unsafe_hash=True, it's an error to have a __hash__ defined.
116 with self.assertRaisesRegex(TypeError,
117 'Cannot overwrite attribute __hash__'):
118 @dataclass(unsafe_hash=True)
119 class C:
120 def __hash__(self):
121 pass
122
123 # Creating this class should not generate an exception,
124 # because even though __hash__ exists before @dataclass is
125 # called, (due to __eq__ being defined), since it's None
126 # that's okay.
127 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500128 class C:
129 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800130 def __eq__(self):
131 pass
132 # The generated hash function works as we'd expect.
133 self.assertEqual(hash(C(10)), hash((10,)))
134
135 # Creating this class should generate an exception, because
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700136 # __hash__ exists and is not None, which it would be if it
137 # had been auto-generated due to __eq__ being defined.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800138 with self.assertRaisesRegex(TypeError,
139 'Cannot overwrite attribute __hash__'):
140 @dataclass(unsafe_hash=True)
141 class C:
142 x: int
143 def __eq__(self):
144 pass
145 def __hash__(self):
146 pass
147
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500148 def test_overwrite_fields_in_derived_class(self):
149 # Note that x from C1 replaces x in Base, but the order remains
150 # the same as defined in Base.
151 @dataclass
152 class Base:
153 x: Any = 15.0
154 y: int = 0
155
156 @dataclass
157 class C1(Base):
158 z: int = 10
159 x: int = 15
160
161 o = Base()
162 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
163
164 o = C1()
165 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
166
167 o = C1(x=5)
168 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
169
170 def test_field_named_self(self):
171 @dataclass
172 class C:
173 self: str
174 c=C('foo')
175 self.assertEqual(c.self, 'foo')
176
177 # Make sure the first parameter is not named 'self'.
178 sig = inspect.signature(C.__init__)
179 first = next(iter(sig.parameters))
180 self.assertNotEqual('self', first)
181
182 # But we do use 'self' if no field named self.
183 @dataclass
184 class C:
185 selfx: str
186
187 # Make sure the first parameter is named 'self'.
188 sig = inspect.signature(C.__init__)
189 first = next(iter(sig.parameters))
190 self.assertEqual('self', first)
191
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500192 def test_0_field_compare(self):
193 # Ensure that order=False is the default.
194 @dataclass
195 class C0:
196 pass
197
198 @dataclass(order=False)
199 class C1:
200 pass
201
202 for cls in [C0, C1]:
203 with self.subTest(cls=cls):
204 self.assertEqual(cls(), cls())
205 for idx, fn in enumerate([lambda a, b: a < b,
206 lambda a, b: a <= b,
207 lambda a, b: a > b,
208 lambda a, b: a >= b]):
209 with self.subTest(idx=idx):
210 with self.assertRaisesRegex(TypeError,
211 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
212 fn(cls(), cls())
213
214 @dataclass(order=True)
215 class C:
216 pass
217 self.assertLessEqual(C(), C())
218 self.assertGreaterEqual(C(), C())
219
220 def test_1_field_compare(self):
221 # Ensure that order=False is the default.
222 @dataclass
223 class C0:
224 x: int
225
226 @dataclass(order=False)
227 class C1:
228 x: int
229
230 for cls in [C0, C1]:
231 with self.subTest(cls=cls):
232 self.assertEqual(cls(1), cls(1))
233 self.assertNotEqual(cls(0), cls(1))
234 for idx, fn in enumerate([lambda a, b: a < b,
235 lambda a, b: a <= b,
236 lambda a, b: a > b,
237 lambda a, b: a >= b]):
238 with self.subTest(idx=idx):
239 with self.assertRaisesRegex(TypeError,
240 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
241 fn(cls(0), cls(0))
242
243 @dataclass(order=True)
244 class C:
245 x: int
246 self.assertLess(C(0), C(1))
247 self.assertLessEqual(C(0), C(1))
248 self.assertLessEqual(C(1), C(1))
249 self.assertGreater(C(1), C(0))
250 self.assertGreaterEqual(C(1), C(0))
251 self.assertGreaterEqual(C(1), C(1))
252
253 def test_simple_compare(self):
254 # Ensure that order=False is the default.
255 @dataclass
256 class C0:
257 x: int
258 y: int
259
260 @dataclass(order=False)
261 class C1:
262 x: int
263 y: int
264
265 for cls in [C0, C1]:
266 with self.subTest(cls=cls):
267 self.assertEqual(cls(0, 0), cls(0, 0))
268 self.assertEqual(cls(1, 2), cls(1, 2))
269 self.assertNotEqual(cls(1, 0), cls(0, 0))
270 self.assertNotEqual(cls(1, 0), cls(1, 1))
271 for idx, fn in enumerate([lambda a, b: a < b,
272 lambda a, b: a <= b,
273 lambda a, b: a > b,
274 lambda a, b: a >= b]):
275 with self.subTest(idx=idx):
276 with self.assertRaisesRegex(TypeError,
277 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
278 fn(cls(0, 0), cls(0, 0))
279
280 @dataclass(order=True)
281 class C:
282 x: int
283 y: int
284
285 for idx, fn in enumerate([lambda a, b: a == b,
286 lambda a, b: a <= b,
287 lambda a, b: a >= b]):
288 with self.subTest(idx=idx):
289 self.assertTrue(fn(C(0, 0), C(0, 0)))
290
291 for idx, fn in enumerate([lambda a, b: a < b,
292 lambda a, b: a <= b,
293 lambda a, b: a != b]):
294 with self.subTest(idx=idx):
295 self.assertTrue(fn(C(0, 0), C(0, 1)))
296 self.assertTrue(fn(C(0, 1), C(1, 0)))
297 self.assertTrue(fn(C(1, 0), C(1, 1)))
298
299 for idx, fn in enumerate([lambda a, b: a > b,
300 lambda a, b: a >= b,
301 lambda a, b: a != b]):
302 with self.subTest(idx=idx):
303 self.assertTrue(fn(C(0, 1), C(0, 0)))
304 self.assertTrue(fn(C(1, 0), C(0, 1)))
305 self.assertTrue(fn(C(1, 1), C(1, 0)))
306
307 def test_compare_subclasses(self):
308 # Comparisons fail for subclasses, even if no fields
309 # are added.
310 @dataclass
311 class B:
312 i: int
313
314 @dataclass
315 class C(B):
316 pass
317
318 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
319 (lambda a, b: a != b, True)]):
320 with self.subTest(idx=idx):
321 self.assertEqual(fn(B(0), C(0)), expected)
322
323 for idx, fn in enumerate([lambda a, b: a < b,
324 lambda a, b: a <= b,
325 lambda a, b: a > b,
326 lambda a, b: a >= b]):
327 with self.subTest(idx=idx):
328 with self.assertRaisesRegex(TypeError,
329 "not supported between instances of 'B' and 'C'"):
330 fn(B(0), C(0))
331
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500332 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500333 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500334 for (eq, order, result ) in [
335 (False, False, 'neither'),
336 (False, True, 'exception'),
337 (True, False, 'eq_only'),
338 (True, True, 'both'),
339 ]:
340 with self.subTest(eq=eq, order=order):
341 if result == 'exception':
342 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
343 @dataclass(eq=eq, order=order)
344 class C:
345 pass
346 else:
347 @dataclass(eq=eq, order=order)
348 class C:
349 pass
350
351 if result == 'neither':
352 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500353 self.assertNotIn('__lt__', C.__dict__)
354 self.assertNotIn('__le__', C.__dict__)
355 self.assertNotIn('__gt__', C.__dict__)
356 self.assertNotIn('__ge__', C.__dict__)
357 elif result == 'both':
358 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500359 self.assertIn('__lt__', C.__dict__)
360 self.assertIn('__le__', C.__dict__)
361 self.assertIn('__gt__', C.__dict__)
362 self.assertIn('__ge__', C.__dict__)
363 elif result == 'eq_only':
364 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500365 self.assertNotIn('__lt__', C.__dict__)
366 self.assertNotIn('__le__', C.__dict__)
367 self.assertNotIn('__gt__', C.__dict__)
368 self.assertNotIn('__ge__', C.__dict__)
369 else:
370 assert False, f'unknown result {result!r}'
371
372 def test_field_no_default(self):
373 @dataclass
374 class C:
375 x: int = field()
376
377 self.assertEqual(C(5).x, 5)
378
379 with self.assertRaisesRegex(TypeError,
380 r"__init__\(\) missing 1 required "
381 "positional argument: 'x'"):
382 C()
383
384 def test_field_default(self):
385 default = object()
386 @dataclass
387 class C:
388 x: object = field(default=default)
389
390 self.assertIs(C.x, default)
391 c = C(10)
392 self.assertEqual(c.x, 10)
393
394 # If we delete the instance attribute, we should then see the
395 # class attribute.
396 del c.x
397 self.assertIs(c.x, default)
398
399 self.assertIs(C().x, default)
400
401 def test_not_in_repr(self):
402 @dataclass
403 class C:
404 x: int = field(repr=False)
405 with self.assertRaises(TypeError):
406 C()
407 c = C(10)
408 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
409
410 @dataclass
411 class C:
412 x: int = field(repr=False)
413 y: int
414 c = C(10, 20)
415 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
416
417 def test_not_in_compare(self):
418 @dataclass
419 class C:
420 x: int = 0
421 y: int = field(compare=False, default=4)
422
423 self.assertEqual(C(), C(0, 20))
424 self.assertEqual(C(1, 10), C(1, 20))
425 self.assertNotEqual(C(3), C(4, 10))
426 self.assertNotEqual(C(3, 10), C(4, 10))
427
428 def test_hash_field_rules(self):
429 # Test all 6 cases of:
430 # hash=True/False/None
431 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800432 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500433 (True, False, 'field' ),
434 (True, True, 'field' ),
435 (False, False, 'absent'),
436 (False, True, 'absent'),
437 (None, False, 'absent'),
438 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800439 ]:
440 with self.subTest(hash=hash_, compare=compare):
441 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500442 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800443 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500444
445 if result == 'field':
446 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800447 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500448 elif result == 'absent':
449 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800450 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500451 else:
452 assert False, f'unknown result {result!r}'
453
454 def test_init_false_no_default(self):
455 # If init=False and no default value, then the field won't be
456 # present in the instance.
457 @dataclass
458 class C:
459 x: int = field(init=False)
460
461 self.assertNotIn('x', C().__dict__)
462
463 @dataclass
464 class C:
465 x: int
466 y: int = 0
467 z: int = field(init=False)
468 t: int = 10
469
470 self.assertNotIn('z', C(0).__dict__)
471 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
472
473 def test_class_marker(self):
474 @dataclass
475 class C:
476 x: int
477 y: str = field(init=False, default=None)
478 z: str = field(repr=False)
479
480 the_fields = fields(C)
481 # the_fields is a tuple of 3 items, each value
482 # is in __annotations__.
483 self.assertIsInstance(the_fields, tuple)
484 for f in the_fields:
485 self.assertIs(type(f), Field)
486 self.assertIn(f.name, C.__annotations__)
487
488 self.assertEqual(len(the_fields), 3)
489
490 self.assertEqual(the_fields[0].name, 'x')
491 self.assertEqual(the_fields[0].type, int)
492 self.assertFalse(hasattr(C, 'x'))
493 self.assertTrue (the_fields[0].init)
494 self.assertTrue (the_fields[0].repr)
495 self.assertEqual(the_fields[1].name, 'y')
496 self.assertEqual(the_fields[1].type, str)
497 self.assertIsNone(getattr(C, 'y'))
498 self.assertFalse(the_fields[1].init)
499 self.assertTrue (the_fields[1].repr)
500 self.assertEqual(the_fields[2].name, 'z')
501 self.assertEqual(the_fields[2].type, str)
502 self.assertFalse(hasattr(C, 'z'))
503 self.assertTrue (the_fields[2].init)
504 self.assertFalse(the_fields[2].repr)
505
506 def test_field_order(self):
507 @dataclass
508 class B:
509 a: str = 'B:a'
510 b: str = 'B:b'
511 c: str = 'B:c'
512
513 @dataclass
514 class C(B):
515 b: str = 'C:b'
516
517 self.assertEqual([(f.name, f.default) for f in fields(C)],
518 [('a', 'B:a'),
519 ('b', 'C:b'),
520 ('c', 'B:c')])
521
522 @dataclass
523 class D(B):
524 c: str = 'D:c'
525
526 self.assertEqual([(f.name, f.default) for f in fields(D)],
527 [('a', 'B:a'),
528 ('b', 'B:b'),
529 ('c', 'D:c')])
530
531 @dataclass
532 class E(D):
533 a: str = 'E:a'
534 d: str = 'E:d'
535
536 self.assertEqual([(f.name, f.default) for f in fields(E)],
537 [('a', 'E:a'),
538 ('b', 'B:b'),
539 ('c', 'D:c'),
540 ('d', 'E:d')])
541
542 def test_class_attrs(self):
543 # We only have a class attribute if a default value is
544 # specified, either directly or via a field with a default.
545 default = object()
546 @dataclass
547 class C:
548 x: int
549 y: int = field(repr=False)
550 z: object = default
551 t: int = field(default=100)
552
553 self.assertFalse(hasattr(C, 'x'))
554 self.assertFalse(hasattr(C, 'y'))
555 self.assertIs (C.z, default)
556 self.assertEqual(C.t, 100)
557
558 def test_disallowed_mutable_defaults(self):
559 # For the known types, don't allow mutable default values.
560 for typ, empty, non_empty in [(list, [], [1]),
561 (dict, {}, {0:1}),
562 (set, set(), set([1])),
563 ]:
564 with self.subTest(typ=typ):
565 # Can't use a zero-length value.
566 with self.assertRaisesRegex(ValueError,
567 f'mutable default {typ} for field '
568 'x is not allowed'):
569 @dataclass
570 class Point:
571 x: typ = empty
572
573
574 # Nor a non-zero-length value
575 with self.assertRaisesRegex(ValueError,
576 f'mutable default {typ} for field '
577 'y is not allowed'):
578 @dataclass
579 class Point:
580 y: typ = non_empty
581
582 # Check subtypes also fail.
583 class Subclass(typ): pass
584
585 with self.assertRaisesRegex(ValueError,
586 f"mutable default .*Subclass'>"
587 ' for field z is not allowed'
588 ):
589 @dataclass
590 class Point:
591 z: typ = Subclass()
592
593 # Because this is a ClassVar, it can be mutable.
594 @dataclass
595 class C:
596 z: ClassVar[typ] = typ()
597
598 # Because this is a ClassVar, it can be mutable.
599 @dataclass
600 class C:
601 x: ClassVar[typ] = Subclass()
602
603
604 def test_deliberately_mutable_defaults(self):
605 # If a mutable default isn't in the known list of
606 # (list, dict, set), then it's okay.
607 class Mutable:
608 def __init__(self):
609 self.l = []
610
611 @dataclass
612 class C:
613 x: Mutable
614
615 # These 2 instances will share this value of x.
616 lst = Mutable()
617 o1 = C(lst)
618 o2 = C(lst)
619 self.assertEqual(o1, o2)
620 o1.x.l.extend([1, 2])
621 self.assertEqual(o1, o2)
622 self.assertEqual(o1.x.l, [1, 2])
623 self.assertIs(o1.x, o2.x)
624
625 def test_no_options(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700626 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500627 @dataclass()
628 class C:
629 x: int
630
631 self.assertEqual(C(42).x, 42)
632
633 def test_not_tuple(self):
634 # Make sure we can't be compared to a tuple.
635 @dataclass
636 class Point:
637 x: int
638 y: int
639 self.assertNotEqual(Point(1, 2), (1, 2))
640
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700641 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500642 @dataclass
643 class C:
644 x: int
645 y: int
646 self.assertNotEqual(Point(1, 3), C(1, 3))
647
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500648 def test_not_tuple(self):
649 # Test that some of the problems with namedtuple don't happen
650 # here.
651 @dataclass
652 class Point3D:
653 x: int
654 y: int
655 z: int
656
657 @dataclass
658 class Date:
659 year: int
660 month: int
661 day: int
662
663 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
664 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
665
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700666 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200667 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500668 x, y, z = Point3D(4, 5, 6)
669
Eric V. Smith7c99e932018-01-28 19:18:55 -0500670 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500671 # equal.
672 @dataclass
673 class Point3Dv1:
674 x: int = 0
675 y: int = 0
676 z: int = 0
677 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
678
679 def test_function_annotations(self):
680 # Some dummy class and instance to use as a default.
681 class F:
682 pass
683 f = F()
684
685 def validate_class(cls):
686 # First, check __annotations__, even though they're not
687 # function annotations.
688 self.assertEqual(cls.__annotations__['i'], int)
689 self.assertEqual(cls.__annotations__['j'], str)
690 self.assertEqual(cls.__annotations__['k'], F)
691 self.assertEqual(cls.__annotations__['l'], float)
692 self.assertEqual(cls.__annotations__['z'], complex)
693
694 # Verify __init__.
695
696 signature = inspect.signature(cls.__init__)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700697 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500698 self.assertIs(signature.return_annotation, None)
699
700 # Check each parameter.
701 params = iter(signature.parameters.values())
702 param = next(params)
703 # This is testing an internal name, and probably shouldn't be tested.
704 self.assertEqual(param.name, 'self')
705 param = next(params)
706 self.assertEqual(param.name, 'i')
707 self.assertIs (param.annotation, int)
708 self.assertEqual(param.default, inspect.Parameter.empty)
709 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
710 param = next(params)
711 self.assertEqual(param.name, 'j')
712 self.assertIs (param.annotation, str)
713 self.assertEqual(param.default, inspect.Parameter.empty)
714 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
715 param = next(params)
716 self.assertEqual(param.name, 'k')
717 self.assertIs (param.annotation, F)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700718 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500719 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
720 param = next(params)
721 self.assertEqual(param.name, 'l')
722 self.assertIs (param.annotation, float)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700723 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500724 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
725 self.assertRaises(StopIteration, next, params)
726
727
728 @dataclass
729 class C:
730 i: int
731 j: str
732 k: F = f
733 l: float=field(default=None)
734 z: complex=field(default=3+4j, init=False)
735
736 validate_class(C)
737
738 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800739 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500740 class C:
741 i: int
742 j: str
743 k: F = f
744 l: float=field(default=None)
745 z: complex=field(default=3+4j, init=False)
746
747 validate_class(C)
748
Eric V. Smith03220fd2017-12-29 13:59:58 -0500749 def test_missing_default(self):
750 # Test that MISSING works the same as a default not being
751 # specified.
752 @dataclass
753 class C:
754 x: int=field(default=MISSING)
755 with self.assertRaisesRegex(TypeError,
756 r'__init__\(\) missing 1 required '
757 'positional argument'):
758 C()
759 self.assertNotIn('x', C.__dict__)
760
761 @dataclass
762 class D:
763 x: int
764 with self.assertRaisesRegex(TypeError,
765 r'__init__\(\) missing 1 required '
766 'positional argument'):
767 D()
768 self.assertNotIn('x', D.__dict__)
769
770 def test_missing_default_factory(self):
771 # Test that MISSING works the same as a default factory not
772 # being specified (which is really the same as a default not
773 # being specified, too).
774 @dataclass
775 class C:
776 x: int=field(default_factory=MISSING)
777 with self.assertRaisesRegex(TypeError,
778 r'__init__\(\) missing 1 required '
779 'positional argument'):
780 C()
781 self.assertNotIn('x', C.__dict__)
782
783 @dataclass
784 class D:
785 x: int=field(default=MISSING, default_factory=MISSING)
786 with self.assertRaisesRegex(TypeError,
787 r'__init__\(\) missing 1 required '
788 'positional argument'):
789 D()
790 self.assertNotIn('x', D.__dict__)
791
792 def test_missing_repr(self):
793 self.assertIn('MISSING_TYPE object', repr(MISSING))
794
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500795 def test_dont_include_other_annotations(self):
796 @dataclass
797 class C:
798 i: int
799 def foo(self) -> int:
800 return 4
801 @property
802 def bar(self) -> int:
803 return 5
804 self.assertEqual(list(C.__annotations__), ['i'])
805 self.assertEqual(C(10).foo(), 4)
806 self.assertEqual(C(10).bar, 5)
Miss Islington (bot)5666a552018-03-25 06:27:50 -0700807 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500808
809 def test_post_init(self):
810 # Just make sure it gets called
811 @dataclass
812 class C:
813 def __post_init__(self):
814 raise CustomError()
815 with self.assertRaises(CustomError):
816 C()
817
818 @dataclass
819 class C:
820 i: int = 10
821 def __post_init__(self):
822 if self.i == 10:
823 raise CustomError()
824 with self.assertRaises(CustomError):
825 C()
826 # post-init gets called, but doesn't raise. This is just
827 # checking that self is used correctly.
828 C(5)
829
830 # If there's not an __init__, then post-init won't get called.
831 @dataclass(init=False)
832 class C:
833 def __post_init__(self):
834 raise CustomError()
835 # Creating the class won't raise
836 C()
837
838 @dataclass
839 class C:
840 x: int = 0
841 def __post_init__(self):
842 self.x *= 2
843 self.assertEqual(C().x, 0)
844 self.assertEqual(C(2).x, 4)
845
Mike53f7a7c2017-12-14 14:04:53 +0300846 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 # attributes.
848 @dataclass(frozen=True)
849 class C:
850 x: int = 0
851 def __post_init__(self):
852 self.x *= 2
853 with self.assertRaises(FrozenInstanceError):
854 C()
855
856 def test_post_init_super(self):
857 # Make sure super() post-init isn't called by default.
858 class B:
859 def __post_init__(self):
860 raise CustomError()
861
862 @dataclass
863 class C(B):
864 def __post_init__(self):
865 self.x = 5
866
867 self.assertEqual(C().x, 5)
868
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700869 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500870 @dataclass
871 class C(B):
872 def __post_init__(self):
873 super().__post_init__()
874
875 with self.assertRaises(CustomError):
876 C()
877
878 # Make sure post-init is called, even if not defined in our
879 # class.
880 @dataclass
881 class C(B):
882 pass
883
884 with self.assertRaises(CustomError):
885 C()
886
887 def test_post_init_staticmethod(self):
888 flag = False
889 @dataclass
890 class C:
891 x: int
892 y: int
893 @staticmethod
894 def __post_init__():
895 nonlocal flag
896 flag = True
897
898 self.assertFalse(flag)
899 c = C(3, 4)
900 self.assertEqual((c.x, c.y), (3, 4))
901 self.assertTrue(flag)
902
903 def test_post_init_classmethod(self):
904 @dataclass
905 class C:
906 flag = False
907 x: int
908 y: int
909 @classmethod
910 def __post_init__(cls):
911 cls.flag = True
912
913 self.assertFalse(C.flag)
914 c = C(3, 4)
915 self.assertEqual((c.x, c.y), (3, 4))
916 self.assertTrue(C.flag)
917
918 def test_class_var(self):
919 # Make sure ClassVars are ignored in __init__, __repr__, etc.
920 @dataclass
921 class C:
922 x: int
923 y: int = 10
924 z: ClassVar[int] = 1000
925 w: ClassVar[int] = 2000
926 t: ClassVar[int] = 3000
927
928 c = C(5)
929 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700930 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
931 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500932 self.assertEqual(c.z, 1000)
933 self.assertEqual(c.w, 2000)
934 self.assertEqual(c.t, 3000)
935 C.z += 1
936 self.assertEqual(c.z, 1001)
937 c = C(20)
938 self.assertEqual((c.x, c.y), (20, 10))
939 self.assertEqual(c.z, 1001)
940 self.assertEqual(c.w, 2000)
941 self.assertEqual(c.t, 3000)
942
943 def test_class_var_no_default(self):
944 # If a ClassVar has no default value, it should not be set on the class.
945 @dataclass
946 class C:
947 x: ClassVar[int]
948
949 self.assertNotIn('x', C.__dict__)
950
951 def test_class_var_default_factory(self):
952 # It makes no sense for a ClassVar to have a default factory. When
953 # would it be called? Call it yourself, since it's class-wide.
954 with self.assertRaisesRegex(TypeError,
955 'cannot have a default factory'):
956 @dataclass
957 class C:
958 x: ClassVar[int] = field(default_factory=int)
959
960 self.assertNotIn('x', C.__dict__)
961
962 def test_class_var_with_default(self):
963 # If a ClassVar has a default value, it should be set on the class.
964 @dataclass
965 class C:
966 x: ClassVar[int] = 10
967 self.assertEqual(C.x, 10)
968
969 @dataclass
970 class C:
971 x: ClassVar[int] = field(default=10)
972 self.assertEqual(C.x, 10)
973
974 def test_class_var_frozen(self):
975 # Make sure ClassVars work even if we're frozen.
976 @dataclass(frozen=True)
977 class C:
978 x: int
979 y: int = 10
980 z: ClassVar[int] = 1000
981 w: ClassVar[int] = 2000
982 t: ClassVar[int] = 3000
983
984 c = C(5)
985 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
986 self.assertEqual(len(fields(C)), 2) # We have 2 fields
987 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
988 self.assertEqual(c.z, 1000)
989 self.assertEqual(c.w, 2000)
990 self.assertEqual(c.t, 3000)
991 # We can still modify the ClassVar, it's only instances that are
992 # frozen.
993 C.z += 1
994 self.assertEqual(c.z, 1001)
995 c = C(20)
996 self.assertEqual((c.x, c.y), (20, 10))
997 self.assertEqual(c.z, 1001)
998 self.assertEqual(c.w, 2000)
999 self.assertEqual(c.t, 3000)
1000
1001 def test_init_var_no_default(self):
1002 # If an InitVar has no default value, it should not be set on the class.
1003 @dataclass
1004 class C:
1005 x: InitVar[int]
1006
1007 self.assertNotIn('x', C.__dict__)
1008
1009 def test_init_var_default_factory(self):
1010 # It makes no sense for an InitVar to have a default factory. When
1011 # would it be called? Call it yourself, since it's class-wide.
1012 with self.assertRaisesRegex(TypeError,
1013 'cannot have a default factory'):
1014 @dataclass
1015 class C:
1016 x: InitVar[int] = field(default_factory=int)
1017
1018 self.assertNotIn('x', C.__dict__)
1019
1020 def test_init_var_with_default(self):
1021 # If an InitVar has a default value, it should be set on the class.
1022 @dataclass
1023 class C:
1024 x: InitVar[int] = 10
1025 self.assertEqual(C.x, 10)
1026
1027 @dataclass
1028 class C:
1029 x: InitVar[int] = field(default=10)
1030 self.assertEqual(C.x, 10)
1031
1032 def test_init_var(self):
1033 @dataclass
1034 class C:
1035 x: int = None
1036 init_param: InitVar[int] = None
1037
1038 def __post_init__(self, init_param):
1039 if self.x is None:
1040 self.x = init_param*2
1041
1042 c = C(init_param=10)
1043 self.assertEqual(c.x, 20)
1044
1045 def test_init_var_inheritance(self):
1046 # Note that this deliberately tests that a dataclass need not
1047 # have a __post_init__ function if it has an InitVar field.
1048 # It could just be used in a derived class, as shown here.
1049 @dataclass
1050 class Base:
1051 x: int
1052 init_base: InitVar[int]
1053
1054 # We can instantiate by passing the InitVar, even though
1055 # it's not used.
1056 b = Base(0, 10)
1057 self.assertEqual(vars(b), {'x': 0})
1058
1059 @dataclass
1060 class C(Base):
1061 y: int
1062 init_derived: InitVar[int]
1063
1064 def __post_init__(self, init_base, init_derived):
1065 self.x = self.x + init_base
1066 self.y = self.y + init_derived
1067
1068 c = C(10, 11, 50, 51)
1069 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1070
1071 def test_default_factory(self):
1072 # Test a factory that returns a new list.
1073 @dataclass
1074 class C:
1075 x: int
1076 y: list = field(default_factory=list)
1077
1078 c0 = C(3)
1079 c1 = C(3)
1080 self.assertEqual(c0.x, 3)
1081 self.assertEqual(c0.y, [])
1082 self.assertEqual(c0, c1)
1083 self.assertIsNot(c0.y, c1.y)
1084 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1085
1086 # Test a factory that returns a shared list.
1087 l = []
1088 @dataclass
1089 class C:
1090 x: int
1091 y: list = field(default_factory=lambda: l)
1092
1093 c0 = C(3)
1094 c1 = C(3)
1095 self.assertEqual(c0.x, 3)
1096 self.assertEqual(c0.y, [])
1097 self.assertEqual(c0, c1)
1098 self.assertIs(c0.y, c1.y)
1099 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1100
1101 # Test various other field flags.
1102 # repr
1103 @dataclass
1104 class C:
1105 x: list = field(default_factory=list, repr=False)
1106 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1107 self.assertEqual(C().x, [])
1108
1109 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001110 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001111 class C:
1112 x: list = field(default_factory=list, hash=False)
1113 self.assertEqual(astuple(C()), ([],))
1114 self.assertEqual(hash(C()), hash(()))
1115
1116 # init (see also test_default_factory_with_no_init)
1117 @dataclass
1118 class C:
1119 x: list = field(default_factory=list, init=False)
1120 self.assertEqual(astuple(C()), ([],))
1121
1122 # compare
1123 @dataclass
1124 class C:
1125 x: list = field(default_factory=list, compare=False)
1126 self.assertEqual(C(), C([1]))
1127
1128 def test_default_factory_with_no_init(self):
1129 # We need a factory with a side effect.
1130 factory = Mock()
1131
1132 @dataclass
1133 class C:
1134 x: list = field(default_factory=factory, init=False)
1135
1136 # Make sure the default factory is called for each new instance.
1137 C().x
1138 self.assertEqual(factory.call_count, 1)
1139 C().x
1140 self.assertEqual(factory.call_count, 2)
1141
1142 def test_default_factory_not_called_if_value_given(self):
1143 # We need a factory that we can test if it's been called.
1144 factory = Mock()
1145
1146 @dataclass
1147 class C:
1148 x: int = field(default_factory=factory)
1149
1150 # Make sure that if a field has a default factory function,
1151 # it's not called if a value is specified.
1152 C().x
1153 self.assertEqual(factory.call_count, 1)
1154 self.assertEqual(C(10).x, 10)
1155 self.assertEqual(factory.call_count, 1)
1156 C().x
1157 self.assertEqual(factory.call_count, 2)
1158
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001159 def test_default_factory_derived(self):
1160 # See bpo-32896.
1161 @dataclass
1162 class Foo:
1163 x: dict = field(default_factory=dict)
1164
1165 @dataclass
1166 class Bar(Foo):
1167 y: int = 1
1168
1169 self.assertEqual(Foo().x, {})
1170 self.assertEqual(Bar().x, {})
1171 self.assertEqual(Bar().y, 1)
1172
1173 @dataclass
1174 class Baz(Foo):
1175 pass
1176 self.assertEqual(Baz().x, {})
1177
1178 def test_intermediate_non_dataclass(self):
1179 # Test that an intermediate class that defines
1180 # annotations does not define fields.
1181
1182 @dataclass
1183 class A:
1184 x: int
1185
1186 class B(A):
1187 y: int
1188
1189 @dataclass
1190 class C(B):
1191 z: int
1192
1193 c = C(1, 3)
1194 self.assertEqual((c.x, c.z), (1, 3))
1195
1196 # .y was not initialized.
1197 with self.assertRaisesRegex(AttributeError,
1198 'object has no attribute'):
1199 c.y
1200
1201 # And if we again derive a non-dataclass, no fields are added.
1202 class D(C):
1203 t: int
1204 d = D(4, 5)
1205 self.assertEqual((d.x, d.z), (4, 5))
1206
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001207 def test_classvar_default_factory(self):
1208 # It's an error for a ClassVar to have a factory function.
1209 with self.assertRaisesRegex(TypeError,
1210 'cannot have a default factory'):
1211 @dataclass
1212 class C:
1213 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001214
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001215 def test_is_dataclass(self):
1216 class NotDataClass:
1217 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001218
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001219 self.assertFalse(is_dataclass(0))
1220 self.assertFalse(is_dataclass(int))
1221 self.assertFalse(is_dataclass(NotDataClass))
1222 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001223
1224 @dataclass
1225 class C:
1226 x: int
1227
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001228 @dataclass
1229 class D:
1230 d: C
1231 e: int
1232
1233 c = C(10)
1234 d = D(c, 4)
1235
1236 self.assertTrue(is_dataclass(C))
1237 self.assertTrue(is_dataclass(c))
1238 self.assertFalse(is_dataclass(c.x))
1239 self.assertTrue(is_dataclass(d.d))
1240 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001241
1242 def test_helper_fields_with_class_instance(self):
1243 # Check that we can call fields() on either a class or instance,
1244 # and get back the same thing.
1245 @dataclass
1246 class C:
1247 x: int
1248 y: float
1249
1250 self.assertEqual(fields(C), fields(C(0, 0.0)))
1251
1252 def test_helper_fields_exception(self):
1253 # Check that TypeError is raised if not passed a dataclass or
1254 # instance.
1255 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1256 fields(0)
1257
1258 class C: pass
1259 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1260 fields(C)
1261 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1262 fields(C())
1263
1264 def test_helper_asdict(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001265 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001266 @dataclass
1267 class C:
1268 x: int
1269 y: int
1270 c = C(1, 2)
1271
1272 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1273 self.assertEqual(asdict(c), asdict(c))
1274 self.assertIsNot(asdict(c), asdict(c))
1275 c.x = 42
1276 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1277 self.assertIs(type(asdict(c)), dict)
1278
1279 def test_helper_asdict_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001280 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001281 @dataclass
1282 class C:
1283 x: int
1284 y: int
1285 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1286 asdict(C)
1287 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1288 asdict(int)
1289
1290 def test_helper_asdict_copy_values(self):
1291 @dataclass
1292 class C:
1293 x: int
1294 y: List[int] = field(default_factory=list)
1295 initial = []
1296 c = C(1, initial)
1297 d = asdict(c)
1298 self.assertEqual(d['y'], initial)
1299 self.assertIsNot(d['y'], initial)
1300 c = C(1)
1301 d = asdict(c)
1302 d['y'].append(1)
1303 self.assertEqual(c.y, [])
1304
1305 def test_helper_asdict_nested(self):
1306 @dataclass
1307 class UserId:
1308 token: int
1309 group: int
1310 @dataclass
1311 class User:
1312 name: str
1313 id: UserId
1314 u = User('Joe', UserId(123, 1))
1315 d = asdict(u)
1316 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1317 self.assertIsNot(asdict(u), asdict(u))
1318 u.id.group = 2
1319 self.assertEqual(asdict(u), {'name': 'Joe',
1320 'id': {'token': 123, 'group': 2}})
1321
1322 def test_helper_asdict_builtin_containers(self):
1323 @dataclass
1324 class User:
1325 name: str
1326 id: int
1327 @dataclass
1328 class GroupList:
1329 id: int
1330 users: List[User]
1331 @dataclass
1332 class GroupTuple:
1333 id: int
1334 users: Tuple[User, ...]
1335 @dataclass
1336 class GroupDict:
1337 id: int
1338 users: Dict[str, User]
1339 a = User('Alice', 1)
1340 b = User('Bob', 2)
1341 gl = GroupList(0, [a, b])
1342 gt = GroupTuple(0, (a, b))
1343 gd = GroupDict(0, {'first': a, 'second': b})
1344 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1345 {'name': 'Bob', 'id': 2}]})
1346 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1347 {'name': 'Bob', 'id': 2})})
1348 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1349 'second': {'name': 'Bob', 'id': 2}}})
1350
1351 def test_helper_asdict_builtin_containers(self):
1352 @dataclass
1353 class Child:
1354 d: object
1355
1356 @dataclass
1357 class Parent:
1358 child: Child
1359
1360 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1361 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1362
1363 def test_helper_asdict_factory(self):
1364 @dataclass
1365 class C:
1366 x: int
1367 y: int
1368 c = C(1, 2)
1369 d = asdict(c, dict_factory=OrderedDict)
1370 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1371 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1372 c.x = 42
1373 d = asdict(c, dict_factory=OrderedDict)
1374 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1375 self.assertIs(type(d), OrderedDict)
1376
1377 def test_helper_astuple(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001378 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001379 @dataclass
1380 class C:
1381 x: int
1382 y: int = 0
1383 c = C(1)
1384
1385 self.assertEqual(astuple(c), (1, 0))
1386 self.assertEqual(astuple(c), astuple(c))
1387 self.assertIsNot(astuple(c), astuple(c))
1388 c.y = 42
1389 self.assertEqual(astuple(c), (1, 42))
1390 self.assertIs(type(astuple(c)), tuple)
1391
1392 def test_helper_astuple_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001393 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001394 @dataclass
1395 class C:
1396 x: int
1397 y: int
1398 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1399 astuple(C)
1400 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1401 astuple(int)
1402
1403 def test_helper_astuple_copy_values(self):
1404 @dataclass
1405 class C:
1406 x: int
1407 y: List[int] = field(default_factory=list)
1408 initial = []
1409 c = C(1, initial)
1410 t = astuple(c)
1411 self.assertEqual(t[1], initial)
1412 self.assertIsNot(t[1], initial)
1413 c = C(1)
1414 t = astuple(c)
1415 t[1].append(1)
1416 self.assertEqual(c.y, [])
1417
1418 def test_helper_astuple_nested(self):
1419 @dataclass
1420 class UserId:
1421 token: int
1422 group: int
1423 @dataclass
1424 class User:
1425 name: str
1426 id: UserId
1427 u = User('Joe', UserId(123, 1))
1428 t = astuple(u)
1429 self.assertEqual(t, ('Joe', (123, 1)))
1430 self.assertIsNot(astuple(u), astuple(u))
1431 u.id.group = 2
1432 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1433
1434 def test_helper_astuple_builtin_containers(self):
1435 @dataclass
1436 class User:
1437 name: str
1438 id: int
1439 @dataclass
1440 class GroupList:
1441 id: int
1442 users: List[User]
1443 @dataclass
1444 class GroupTuple:
1445 id: int
1446 users: Tuple[User, ...]
1447 @dataclass
1448 class GroupDict:
1449 id: int
1450 users: Dict[str, User]
1451 a = User('Alice', 1)
1452 b = User('Bob', 2)
1453 gl = GroupList(0, [a, b])
1454 gt = GroupTuple(0, (a, b))
1455 gd = GroupDict(0, {'first': a, 'second': b})
1456 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1457 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1458 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1459
1460 def test_helper_astuple_builtin_containers(self):
1461 @dataclass
1462 class Child:
1463 d: object
1464
1465 @dataclass
1466 class Parent:
1467 child: Child
1468
1469 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1470 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1471
1472 def test_helper_astuple_factory(self):
1473 @dataclass
1474 class C:
1475 x: int
1476 y: int
1477 NT = namedtuple('NT', 'x y')
1478 def nt(lst):
1479 return NT(*lst)
1480 c = C(1, 2)
1481 t = astuple(c, tuple_factory=nt)
1482 self.assertEqual(t, NT(1, 2))
1483 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1484 c.x = 42
1485 t = astuple(c, tuple_factory=nt)
1486 self.assertEqual(t, NT(42, 2))
1487 self.assertIs(type(t), NT)
1488
1489 def test_dynamic_class_creation(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001490 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001491 }
1492
1493 # Create the class.
1494 cls = type('C', (), cls_dict)
1495
1496 # Make it a dataclass.
1497 cls1 = dataclass(cls)
1498
1499 self.assertEqual(cls1, cls)
1500 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1501
1502 def test_dynamic_class_creation_using_field(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001503 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001504 'y': field(default=5),
1505 }
1506
1507 # Create the class.
1508 cls = type('C', (), cls_dict)
1509
1510 # Make it a dataclass.
1511 cls1 = dataclass(cls)
1512
1513 self.assertEqual(cls1, cls)
1514 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1515
1516 def test_init_in_order(self):
1517 @dataclass
1518 class C:
1519 a: int
1520 b: int = field()
1521 c: list = field(default_factory=list, init=False)
1522 d: list = field(default_factory=list)
1523 e: int = field(default=4, init=False)
1524 f: int = 4
1525
1526 calls = []
1527 def setattr(self, name, value):
1528 calls.append((name, value))
1529
1530 C.__setattr__ = setattr
1531 c = C(0, 1)
1532 self.assertEqual(('a', 0), calls[0])
1533 self.assertEqual(('b', 1), calls[1])
1534 self.assertEqual(('c', []), calls[2])
1535 self.assertEqual(('d', []), calls[3])
1536 self.assertNotIn(('e', 4), calls)
1537 self.assertEqual(('f', 4), calls[4])
1538
1539 def test_items_in_dicts(self):
1540 @dataclass
1541 class C:
1542 a: int
1543 b: list = field(default_factory=list, init=False)
1544 c: list = field(default_factory=list)
1545 d: int = field(default=4, init=False)
1546 e: int = 0
1547
1548 c = C(0)
1549 # Class dict
1550 self.assertNotIn('a', C.__dict__)
1551 self.assertNotIn('b', C.__dict__)
1552 self.assertNotIn('c', C.__dict__)
1553 self.assertIn('d', C.__dict__)
1554 self.assertEqual(C.d, 4)
1555 self.assertIn('e', C.__dict__)
1556 self.assertEqual(C.e, 0)
1557 # Instance dict
1558 self.assertIn('a', c.__dict__)
1559 self.assertEqual(c.a, 0)
1560 self.assertIn('b', c.__dict__)
1561 self.assertEqual(c.b, [])
1562 self.assertIn('c', c.__dict__)
1563 self.assertEqual(c.c, [])
1564 self.assertNotIn('d', c.__dict__)
1565 self.assertIn('e', c.__dict__)
1566 self.assertEqual(c.e, 0)
1567
1568 def test_alternate_classmethod_constructor(self):
1569 # Since __post_init__ can't take params, use a classmethod
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001570 # alternate constructor. This is mostly an example to show
1571 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001572 @dataclass
1573 class C:
1574 x: int
1575 @classmethod
1576 def from_file(cls, filename):
1577 # In a real example, create a new instance
1578 # and populate 'x' from contents of a file.
1579 value_in_file = 20
1580 return cls(value_in_file)
1581
1582 self.assertEqual(C.from_file('filename').x, 20)
1583
1584 def test_field_metadata_default(self):
1585 # Make sure the default metadata is read-only and of
1586 # zero length.
1587 @dataclass
1588 class C:
1589 i: int
1590
1591 self.assertFalse(fields(C)[0].metadata)
1592 self.assertEqual(len(fields(C)[0].metadata), 0)
1593 with self.assertRaisesRegex(TypeError,
1594 'does not support item assignment'):
1595 fields(C)[0].metadata['test'] = 3
1596
1597 def test_field_metadata_mapping(self):
1598 # Make sure only a mapping can be passed as metadata
1599 # zero length.
1600 with self.assertRaises(TypeError):
1601 @dataclass
1602 class C:
1603 i: int = field(metadata=0)
1604
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001605 # Make sure an empty dict works.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001606 @dataclass
1607 class C:
1608 i: int = field(metadata={})
1609 self.assertFalse(fields(C)[0].metadata)
1610 self.assertEqual(len(fields(C)[0].metadata), 0)
1611 with self.assertRaisesRegex(TypeError,
1612 'does not support item assignment'):
1613 fields(C)[0].metadata['test'] = 3
1614
1615 # Make sure a non-empty dict works.
1616 @dataclass
1617 class C:
1618 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1619 self.assertEqual(len(fields(C)[0].metadata), 3)
1620 self.assertEqual(fields(C)[0].metadata['test'], 10)
1621 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1622 self.assertEqual(fields(C)[0].metadata[3], 'three')
1623 with self.assertRaises(KeyError):
1624 # Non-existent key.
1625 fields(C)[0].metadata['baz']
1626 with self.assertRaisesRegex(TypeError,
1627 'does not support item assignment'):
1628 fields(C)[0].metadata['test'] = 3
1629
1630 def test_field_metadata_custom_mapping(self):
1631 # Try a custom mapping.
1632 class SimpleNameSpace:
1633 def __init__(self, **kw):
1634 self.__dict__.update(kw)
1635
1636 def __getitem__(self, item):
1637 if item == 'xyzzy':
1638 return 'plugh'
1639 return getattr(self, item)
1640
1641 def __len__(self):
1642 return self.__dict__.__len__()
1643
1644 @dataclass
1645 class C:
1646 i: int = field(metadata=SimpleNameSpace(a=10))
1647
1648 self.assertEqual(len(fields(C)[0].metadata), 1)
1649 self.assertEqual(fields(C)[0].metadata['a'], 10)
1650 with self.assertRaises(AttributeError):
1651 fields(C)[0].metadata['b']
1652 # Make sure we're still talking to our custom mapping.
1653 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1654
1655 def test_generic_dataclasses(self):
1656 T = TypeVar('T')
1657
1658 @dataclass
1659 class LabeledBox(Generic[T]):
1660 content: T
1661 label: str = '<unknown>'
1662
1663 box = LabeledBox(42)
1664 self.assertEqual(box.content, 42)
1665 self.assertEqual(box.label, '<unknown>')
1666
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001667 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001668 Alias = List[LabeledBox[int]]
1669
1670 def test_generic_extending(self):
1671 S = TypeVar('S')
1672 T = TypeVar('T')
1673
1674 @dataclass
1675 class Base(Generic[T, S]):
1676 x: T
1677 y: S
1678
1679 @dataclass
1680 class DataDerived(Base[int, T]):
1681 new_field: str
1682 Alias = DataDerived[str]
1683 c = Alias(0, 'test1', 'test2')
1684 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1685
1686 class NonDataDerived(Base[int, T]):
1687 def new_method(self):
1688 return self.y
1689 Alias = NonDataDerived[float]
1690 c = Alias(10, 1.0)
1691 self.assertEqual(c.new_method(), 1.0)
1692
Miss Islington (bot)d063ad82018-04-01 04:33:13 -07001693 def test_generic_dynamic(self):
1694 T = TypeVar('T')
1695
1696 @dataclass
1697 class Parent(Generic[T]):
1698 x: T
1699 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1700 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1701 self.assertIs(Child[int](1, 2).z, None)
1702 self.assertEqual(Child[int](1, 2, 3).z, 3)
1703 self.assertEqual(Child[int](1, 2, 3).other, 42)
1704 # Check that type aliases work correctly.
1705 Alias = Child[T]
1706 self.assertEqual(Alias[int](1, 2).x, 1)
1707 # Check MRO resolution.
1708 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1709
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001710 def test_helper_replace(self):
1711 @dataclass(frozen=True)
1712 class C:
1713 x: int
1714 y: int
1715
1716 c = C(1, 2)
1717 c1 = replace(c, x=3)
1718 self.assertEqual(c1.x, 3)
1719 self.assertEqual(c1.y, 2)
1720
1721 def test_helper_replace_frozen(self):
1722 @dataclass(frozen=True)
1723 class C:
1724 x: int
1725 y: int
1726 z: int = field(init=False, default=10)
1727 t: int = field(init=False, default=100)
1728
1729 c = C(1, 2)
1730 c1 = replace(c, x=3)
1731 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1732 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1733
1734
1735 with self.assertRaisesRegex(ValueError, 'init=False'):
1736 replace(c, x=3, z=20, t=50)
1737 with self.assertRaisesRegex(ValueError, 'init=False'):
1738 replace(c, z=20)
1739 replace(c, x=3, z=20, t=50)
1740
1741 # Make sure the result is still frozen.
1742 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1743 c1.x = 3
1744
1745 # Make sure we can't replace an attribute that doesn't exist,
1746 # if we're also replacing one that does exist. Test this
1747 # here, because setting attributes on frozen instances is
1748 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001749 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001750 "keyword argument 'a'"):
1751 c1 = replace(c, x=20, a=5)
1752
1753 def test_helper_replace_invalid_field_name(self):
1754 @dataclass(frozen=True)
1755 class C:
1756 x: int
1757 y: int
1758
1759 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001760 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001761 "keyword argument 'z'"):
1762 c1 = replace(c, z=3)
1763
1764 def test_helper_replace_invalid_object(self):
1765 @dataclass(frozen=True)
1766 class C:
1767 x: int
1768 y: int
1769
1770 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1771 replace(C, x=3)
1772
1773 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1774 replace(0, x=3)
1775
1776 def test_helper_replace_no_init(self):
1777 @dataclass
1778 class C:
1779 x: int
1780 y: int = field(init=False, default=10)
1781
1782 c = C(1)
1783 c.y = 20
1784
1785 # Make sure y gets the default value.
1786 c1 = replace(c, x=5)
1787 self.assertEqual((c1.x, c1.y), (5, 10))
1788
1789 # Trying to replace y is an error.
1790 with self.assertRaisesRegex(ValueError, 'init=False'):
1791 replace(c, x=2, y=30)
1792 with self.assertRaisesRegex(ValueError, 'init=False'):
1793 replace(c, y=30)
1794
1795 def test_dataclassses_pickleable(self):
1796 global P, Q, R
1797 @dataclass
1798 class P:
1799 x: int
1800 y: int = 0
1801 @dataclass
1802 class Q:
1803 x: int
1804 y: int = field(default=0, init=False)
1805 @dataclass
1806 class R:
1807 x: int
1808 y: List[int] = field(default_factory=list)
1809 q = Q(1)
1810 q.y = 2
1811 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1812 for sample in samples:
1813 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1814 with self.subTest(sample=sample, proto=proto):
1815 new_sample = pickle.loads(pickle.dumps(sample, proto))
1816 self.assertEqual(sample.x, new_sample.x)
1817 self.assertEqual(sample.y, new_sample.y)
1818 self.assertIsNot(sample, new_sample)
1819 new_sample.x = 42
1820 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1821 self.assertEqual(new_sample.x, another_new_sample.x)
1822 self.assertEqual(sample.y, another_new_sample.y)
1823
1824 def test_helper_make_dataclass(self):
1825 C = make_dataclass('C',
1826 [('x', int),
1827 ('y', int, field(default=5))],
1828 namespace={'add_one': lambda self: self.x + 1})
1829 c = C(10)
1830 self.assertEqual((c.x, c.y), (10, 5))
1831 self.assertEqual(c.add_one(), 11)
1832
1833
1834 def test_helper_make_dataclass_no_mutate_namespace(self):
1835 # Make sure a provided namespace isn't mutated.
1836 ns = {}
1837 C = make_dataclass('C',
1838 [('x', int),
1839 ('y', int, field(default=5))],
1840 namespace=ns)
1841 self.assertEqual(ns, {})
1842
1843 def test_helper_make_dataclass_base(self):
1844 class Base1:
1845 pass
1846 class Base2:
1847 pass
1848 C = make_dataclass('C',
1849 [('x', int)],
1850 bases=(Base1, Base2))
1851 c = C(2)
1852 self.assertIsInstance(c, C)
1853 self.assertIsInstance(c, Base1)
1854 self.assertIsInstance(c, Base2)
1855
1856 def test_helper_make_dataclass_base_dataclass(self):
1857 @dataclass
1858 class Base1:
1859 x: int
1860 class Base2:
1861 pass
1862 C = make_dataclass('C',
1863 [('y', int)],
1864 bases=(Base1, Base2))
1865 with self.assertRaisesRegex(TypeError, 'required positional'):
1866 c = C(2)
1867 c = C(1, 2)
1868 self.assertIsInstance(c, C)
1869 self.assertIsInstance(c, Base1)
1870 self.assertIsInstance(c, Base2)
1871
1872 self.assertEqual((c.x, c.y), (1, 2))
1873
1874 def test_helper_make_dataclass_init_var(self):
1875 def post_init(self, y):
1876 self.x *= y
1877
1878 C = make_dataclass('C',
1879 [('x', int),
1880 ('y', InitVar[int]),
1881 ],
1882 namespace={'__post_init__': post_init},
1883 )
1884 c = C(2, 3)
1885 self.assertEqual(vars(c), {'x': 6})
1886 self.assertEqual(len(fields(c)), 1)
1887
1888 def test_helper_make_dataclass_class_var(self):
1889 C = make_dataclass('C',
1890 [('x', int),
1891 ('y', ClassVar[int], 10),
1892 ('z', ClassVar[int], field(default=20)),
1893 ])
1894 c = C(1)
1895 self.assertEqual(vars(c), {'x': 1})
1896 self.assertEqual(len(fields(c)), 1)
1897 self.assertEqual(C.y, 10)
1898 self.assertEqual(C.z, 20)
1899
Eric V. Smithd80b4432018-01-06 17:09:58 -05001900 def test_helper_make_dataclass_other_params(self):
1901 C = make_dataclass('C',
1902 [('x', int),
1903 ('y', ClassVar[int], 10),
1904 ('z', ClassVar[int], field(default=20)),
1905 ],
1906 init=False)
1907 # Make sure we have a repr, but no init.
1908 self.assertNotIn('__init__', vars(C))
1909 self.assertIn('__repr__', vars(C))
1910
1911 # Make sure random other params don't work.
1912 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1913 C = make_dataclass('C',
1914 [],
1915 xxinit=False)
1916
Eric V. Smithed7d4292018-01-06 16:14:03 -05001917 def test_helper_make_dataclass_no_types(self):
1918 C = make_dataclass('Point', ['x', 'y', 'z'])
1919 c = C(1, 2, 3)
1920 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1921 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1922 'y': 'typing.Any',
1923 'z': 'typing.Any'})
1924
1925 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1926 c = C(1, 2, 3)
1927 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1928 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1929 'y': int,
1930 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001931
Eric V. Smithea8fc522018-01-27 19:07:40 -05001932
class TestFieldNoAnnotation(unittest.TestCase):
    """A field() assignment needs an annotation in the same class body."""

    def _assert_missing_annotation(self):
        # Context manager expecting the TypeError raised for field 'f'.
        return self.assertRaisesRegex(
            TypeError, "'f' is a field but has no type annotation")

    def test_field_without_annotation(self):
        with self._assert_missing_annotation():
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # Still an error: the annotation in the (dataclass) base class must
        # not be picked up for C's own field().
        with self._assert_missing_annotation():
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same as above, but the base class is not a dataclass.
        class B:
            f: int

        with self._assert_missing_annotation():
            @dataclass
            class C(B):
                f = field()
1966
1967
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001968class TestDocString(unittest.TestCase):
1969 def assertDocStrEqual(self, a, b):
1970 # Because 3.6 and 3.7 differ in how inspect.signature work
1971 # (see bpo #32108), for the time being just compare them with
1972 # whitespace stripped.
1973 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1974
1975 def test_existing_docstring_not_overridden(self):
1976 @dataclass
1977 class C:
1978 """Lorem ipsum"""
1979 x: int
1980
1981 self.assertEqual(C.__doc__, "Lorem ipsum")
1982
1983 def test_docstring_no_fields(self):
1984 @dataclass
1985 class C:
1986 pass
1987
1988 self.assertDocStrEqual(C.__doc__, "C()")
1989
1990 def test_docstring_one_field(self):
1991 @dataclass
1992 class C:
1993 x: int
1994
1995 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1996
1997 def test_docstring_two_fields(self):
1998 @dataclass
1999 class C:
2000 x: int
2001 y: int
2002
2003 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
2004
2005 def test_docstring_three_fields(self):
2006 @dataclass
2007 class C:
2008 x: int
2009 y: int
2010 z: str
2011
2012 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
2013
2014 def test_docstring_one_field_with_default(self):
2015 @dataclass
2016 class C:
2017 x: int = 3
2018
2019 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2020
2021 def test_docstring_one_field_with_default_none(self):
2022 @dataclass
2023 class C:
2024 x: Union[int, type(None)] = None
2025
2026 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2027
2028 def test_docstring_list_field(self):
2029 @dataclass
2030 class C:
2031 x: List[int]
2032
2033 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2034
2035 def test_docstring_list_field_with_default_factory(self):
2036 @dataclass
2037 class C:
2038 x: List[int] = field(default_factory=list)
2039
2040 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2041
2042 def test_docstring_deque_field(self):
2043 @dataclass
2044 class C:
2045 x: deque
2046
2047 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2048
2049 def test_docstring_deque_field_with_default_factory(self):
2050 @dataclass
2051 class C:
2052 x: deque = field(default_factory=deque)
2053
2054 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2055
2056
class TestInit(unittest.TestCase):
    """How the init= option interacts with existing __init__ methods."""

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring this class must not raise: our class gets a generated
        # __init__ even though its base already defines one.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call B.__init__.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        # gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator must be applied with '@'.  A bare
        # "dataclass(init=False)" statement only builds a decorator and
        # discards it, leaving C a plain class and the checks vacuous.

        # init=False and no __init__ in the class: C() takes no arguments
        # and i is read from the class-level default.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 0)

        # init=False with a user-defined __init__: that __init__ is kept.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        # init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2121
2122
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The generated __repr__ uses the class's __qualname__ and lists
        # fields in definition order, inherited fields first.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redefining an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses report their full qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # Test a class with no __repr__ and repr=False: object.__repr__ is
        # used.  The expected prefix is built from this module's __name__
        # instead of a hard-coded 'test_dataclasses', so the check also
        # passes when the file runs as __main__ or under another name.
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # Test a class with a __repr__ and repr=False.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class has __repr__, use it no matter the value of
        # repr=.

        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2192
2193
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: equality falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class-supplied __eq__ wins whether eq= is left at its
        # default, set to True, or set to False.  Each class is checked
        # in the same iteration that defines it, so the closure over
        # `target` always sees the intended value.
        for target, decorate in ((3, dataclass),
                                 (4, dataclass(eq=True)),
                                 (5, dataclass(eq=False))):
            @decorate
            class C:
                x: int
                def __eq__(self, other):
                    return other == target
            self.assertEqual(C(1), target)
            self.assertNotEqual(C(1), 1)
2239
2240
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # With order=False (the default) no ordering methods are added.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user-defined __lt__ is kept and actually used for "<".
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIn('__lt__', C.__dict__)
        # Without the user's __lt__ this comparison would raise TypeError.
        self.assertFalse(C(0) < C(1))
        # Make sure the other ordering methods aren't added.
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any user-defined ordering method
        # and points the user at functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self, other):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self, other):
                    pass
2316
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # The decorator does not touch __hash__.
                    if with_hash:
                        # The class's own __hash__ (which returns 0)
                        # must still be the one in effect.
                        self.assertEqual(hash(C()), 0)
                    else:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    assert with_hash
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    assert False, f'unknown result {result!r}'

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's own __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object.__hash__ is identity-based, so there is no
                    # fixed value to compare hash() against.  Instead,
                    # check that the inherited function is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    assert False, f'unknown value for expected={expected!r}'
2535
Eric V. Smithea8fc522018-01-27 19:07:40 -05002536
class TestFrozen(unittest.TestCase):
    def test_frozen(self):
        # Assigning to a field of a frozen instance raises and leaves the
        # stored value untouched.
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        # Fields from both the frozen base and the frozen subclass are
        # read-only on a subclass instance.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        inst = D(0, 10)
        for name, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(inst, name, value)
        self.assertEqual(inst.i, 0)
        self.assertEqual(inst.j, 10)

    def test_inherit_nonfrozen_from_frozen(self):
        # Checked both with an intermediate normal (non-dataclass) class
        # and without one.
        for use_middle in (True, False):
            with self.subTest(intermediate_class=use_middle):
                @dataclass(frozen=True)
                class C:
                    i: int

                if use_middle:
                    class Middle(C):
                        pass
                    parent = Middle
                else:
                    parent = C

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        # Checked both with an intermediate normal (non-dataclass) class
        # and without one.
        for use_middle in (True, False):
            with self.subTest(intermediate_class=use_middle):
                @dataclass
                class C:
                    i: int

                if use_middle:
                    class Middle(C):
                        pass
                    parent = Middle
                else:
                    parent = C

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from ordinary classes, directly
        # or through an intermediate class.
        for use_middle in (True, False):
            with self.subTest(intermediate_class=use_middle):
                class C:
                    pass

                if use_middle:
                    class Middle(C):
                        pass
                    parent = Middle
                else:
                    parent = C

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                inst = D(10)
                with self.assertRaises(FrozenInstanceError):
                    inst.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.  A plain subclass of a frozen dataclass may set
        # new attributes, but the inherited fields stay read-only.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        s.cached = True

        # But can't change the frozen attributes.
        for name in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, name, 5)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen=True installs __setattr__ and __delattr__, so it refuses
        # to replace user-defined versions of either.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False the user's __setattr__ is used, including
        # from the generated __init__.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # Hashing succeeds when the field value is hashable ...
        hash(C(3))

        # ... and fails when it is not.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2685
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002686
class TestSlots(unittest.TestCase):
    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # The slot's member descriptor on the class must not be taken
        # for a default value, so 'x' stays a required argument.
        with self.assertRaisesRegex(TypeError,
                r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # Instances can be created and the slotted field reassigned.
        obj = C(10)
        self.assertEqual(obj.x, 10)
        obj.x = 5
        self.assertEqual(obj.x, 5)

        # Nothing outside __slots__ can be assigned.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            obj.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        inst = Derived(1, 2)
        self.assertEqual((inst.x, inst.y), (1, 2))

        # Derived has no __slots__ of its own, so arbitrary new
        # attributes can be added to its instances.
        inst.z = 10
2728
class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # See bpo-33141.

        # Create a descriptor.
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is not None:
                    return 1
                return self

        # This is the case of just normal descriptor behavior, no
        # dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Now test with a default value and init=False, which is the
        # only time this is really meaningful.  If not using
        # init=False, then the descriptor will be overwritten, anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # D defines __set_name__ but no __get__/__set__, so it is not a
        # descriptor; its __set_name__ must still be called.

        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        d = D()
        # Create an attribute on the instance, not type.
        d.__set_name__ = Mock()

        # Make sure d.__set_name__ is not called: __set_name__ is looked
        # up on the type, not on the instance.
        @dataclass
        class C:
            i: int=field(default=d, init=False)

        self.assertEqual(d.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        D.__set_name__ = Mock()

        # Make sure D.__set_name__ is called.
        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
        # It receives the default value, the owning class, and the
        # field name.  After decoration C.i holds that default value.
        D.__set_name__.assert_called_once_with(C.i, C, 'i')
Miss Islington (bot)c6147ac2018-03-26 10:55:13 -07002799
Miss Islington (bot)3d41f482018-03-19 18:31:22 -07002800
if __name__ == '__main__':
    # Executed as a script: let unittest discover and run the test
    # cases defined in this module.
    unittest.main()