blob: 2745eaf6893b7499f4dc9e84764b86f2c66761eb [file] [log] [blame]
Miss Islington (bot)4ddc99d2018-03-21 14:44:23 -07001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
11from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar
12from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
# A distinct exception type that tests can raise from fixtures and then
# catch precisely.
class CustomError(Exception):
    """Marker exception raised by test fixtures (e.g. from __post_init__)."""
17
18class TestCase(unittest.TestCase):
def test_no_fields(self):
    """A dataclass with an empty body has no fields and can be instantiated."""
    @dataclass
    class C:
        pass

    # Instantiation must simply succeed; the instance is not inspected.
    o = C()
    self.assertEqual(len(fields(C)), 0)
26
def test_no_fields_but_member_variable(self):
    """An un-annotated class attribute does not become a field."""
    @dataclass
    class C:
        i = 0  # no annotation, so this is not counted by fields()

    # Instantiation must simply succeed; the instance is not inspected.
    o = C()
    self.assertEqual(len(fields(C)), 0)
34
def test_one_field_no_default(self):
    """A single required field is accepted positionally by __init__."""
    @dataclass
    class C:
        x: int

    o = C(42)
    self.assertEqual(o.x, 42)
42
def test_named_init_params(self):
    """A field can be passed to __init__ as a keyword argument."""
    @dataclass
    class C:
        x: int

    o = C(x=32)
    self.assertEqual(o.x, 32)
50
def test_two_fields_one_default(self):
    """A defaulted field may follow a required one, but a required field
    may not follow a defaulted one, including across inheritance."""
    @dataclass
    class C:
        x: int
        y: int = 0

    o = C(3)
    self.assertEqual((o.x, o.y), (3, 0))

    # Non-defaults following defaults.
    with self.assertRaisesRegex(TypeError,
                                "non-default argument 'y' follows "
                                "default argument"):
        @dataclass
        class C:
            x: int = 0
            y: int

    # A derived class adds a non-default field after a default one.
    with self.assertRaisesRegex(TypeError,
                                "non-default argument 'y' follows "
                                "default argument"):
        @dataclass
        class B:
            x: int = 0

        @dataclass
        class C(B):
            y: int

    # Override a base class field and add a default to
    # a field which didn't use to have a default.
    with self.assertRaisesRegex(TypeError,
                                "non-default argument 'y' follows "
                                "default argument"):
        @dataclass
        class B:
            x: int
            y: int

        @dataclass
        class C(B):
            # Redeclaring x with a default keeps x's position before y,
            # so the required y would follow a defaulted x.
            x: int = 0
94
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080095 def test_overwrite_hash(self):
96 # Test that declaring this class isn't an error. It should
97 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -050098 @dataclass(frozen=True)
99 class C:
100 x: int
101 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800102 return 301
103 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500104
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 # Test that declaring this class isn't an error. It should
106 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500107 @dataclass(frozen=True)
108 class C:
109 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800110 def __eq__(self, other):
111 return False
112 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500113
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800114 # But this one should generate an exception, because with
115 # unsafe_hash=True, it's an error to have a __hash__ defined.
116 with self.assertRaisesRegex(TypeError,
117 'Cannot overwrite attribute __hash__'):
118 @dataclass(unsafe_hash=True)
119 class C:
120 def __hash__(self):
121 pass
122
123 # Creating this class should not generate an exception,
124 # because even though __hash__ exists before @dataclass is
125 # called, (due to __eq__ being defined), since it's None
126 # that's okay.
127 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500128 class C:
129 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800130 def __eq__(self):
131 pass
132 # The generated hash function works as we'd expect.
133 self.assertEqual(hash(C(10)), hash((10,)))
134
135 # Creating this class should generate an exception, because
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700136 # __hash__ exists and is not None, which it would be if it
137 # had been auto-generated due to __eq__ being defined.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800138 with self.assertRaisesRegex(TypeError,
139 'Cannot overwrite attribute __hash__'):
140 @dataclass(unsafe_hash=True)
141 class C:
142 x: int
143 def __eq__(self):
144 pass
145 def __hash__(self):
146 pass
147
def test_overwrite_fields_in_derived_class(self):
    """A subclass redeclaring a base field replaces its type/default but
    keeps the position the field had in the base class."""
    # Note that x from C1 replaces x in Base, but the order remains
    # the same as defined in Base.
    @dataclass
    class Base:
        x: Any = 15.0
        y: int = 0

    @dataclass
    class C1(Base):
        z: int = 10
        x: int = 15

    o = Base()
    self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')

    o = C1()
    self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')

    o = C1(x=5)
    self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
169
def test_field_named_self(self):
    """A field may be named 'self'; the generated __init__ then gives its
    instance parameter a different name."""
    @dataclass
    class C:
        self: str
    c=C('foo')
    self.assertEqual(c.self, 'foo')

    # Make sure the first parameter is not named 'self'.
    sig = inspect.signature(C.__init__)
    first = next(iter(sig.parameters))
    self.assertNotEqual('self', first)

    # But we do use 'self' if no field named self.
    @dataclass
    class C:
        selfx: str

    # Make sure the first parameter is named 'self'.
    sig = inspect.signature(C.__init__)
    first = next(iter(sig.parameters))
    self.assertEqual('self', first)
191
def test_0_field_compare(self):
    """Field-less dataclasses compare equal; ordering operators exist only
    with order=True."""
    # Ensure that order=False is the default.
    @dataclass
    class C0:
        pass

    @dataclass(order=False)
    class C1:
        pass

    for cls in [C0, C1]:
        with self.subTest(cls=cls):
            self.assertEqual(cls(), cls())
            # Every ordering operator must be unsupported.
            for idx, fn in enumerate([lambda a, b: a < b,
                                      lambda a, b: a <= b,
                                      lambda a, b: a > b,
                                      lambda a, b: a >= b]):
                with self.subTest(idx=idx):
                    with self.assertRaisesRegex(TypeError,
                                                f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
                        fn(cls(), cls())

    @dataclass(order=True)
    class C:
        pass
    self.assertLessEqual(C(), C())
    self.assertGreaterEqual(C(), C())
219
def test_1_field_compare(self):
    """One-field dataclasses compare by that field; ordering operators
    exist only with order=True."""
    # Ensure that order=False is the default.
    @dataclass
    class C0:
        x: int

    @dataclass(order=False)
    class C1:
        x: int

    for cls in [C0, C1]:
        with self.subTest(cls=cls):
            self.assertEqual(cls(1), cls(1))
            self.assertNotEqual(cls(0), cls(1))
            # Every ordering operator must be unsupported.
            for idx, fn in enumerate([lambda a, b: a < b,
                                      lambda a, b: a <= b,
                                      lambda a, b: a > b,
                                      lambda a, b: a >= b]):
                with self.subTest(idx=idx):
                    with self.assertRaisesRegex(TypeError,
                                                f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
                        fn(cls(0), cls(0))

    @dataclass(order=True)
    class C:
        x: int
    self.assertLess(C(0), C(1))
    self.assertLessEqual(C(0), C(1))
    self.assertLessEqual(C(1), C(1))
    self.assertGreater(C(1), C(0))
    self.assertGreaterEqual(C(1), C(0))
    self.assertGreaterEqual(C(1), C(1))
252
def test_simple_compare(self):
    """Two-field dataclasses compare field by field, in declaration order
    (x first, then y), when order=True."""
    # Ensure that order=False is the default.
    @dataclass
    class C0:
        x: int
        y: int

    @dataclass(order=False)
    class C1:
        x: int
        y: int

    for cls in [C0, C1]:
        with self.subTest(cls=cls):
            self.assertEqual(cls(0, 0), cls(0, 0))
            self.assertEqual(cls(1, 2), cls(1, 2))
            self.assertNotEqual(cls(1, 0), cls(0, 0))
            self.assertNotEqual(cls(1, 0), cls(1, 1))
            # Every ordering operator must be unsupported.
            for idx, fn in enumerate([lambda a, b: a < b,
                                      lambda a, b: a <= b,
                                      lambda a, b: a > b,
                                      lambda a, b: a >= b]):
                with self.subTest(idx=idx):
                    with self.assertRaisesRegex(TypeError,
                                                f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
                        fn(cls(0, 0), cls(0, 0))

    @dataclass(order=True)
    class C:
        x: int
        y: int

    # Operators that hold for equal instances.
    for idx, fn in enumerate([lambda a, b: a == b,
                              lambda a, b: a <= b,
                              lambda a, b: a >= b]):
        with self.subTest(idx=idx):
            self.assertTrue(fn(C(0, 0), C(0, 0)))

    # Operators that hold when the left operand sorts first; the middle
    # case shows x decides before y is consulted.
    for idx, fn in enumerate([lambda a, b: a < b,
                              lambda a, b: a <= b,
                              lambda a, b: a != b]):
        with self.subTest(idx=idx):
            self.assertTrue(fn(C(0, 0), C(0, 1)))
            self.assertTrue(fn(C(0, 1), C(1, 0)))
            self.assertTrue(fn(C(1, 0), C(1, 1)))

    # Mirror image: the left operand sorts last.
    for idx, fn in enumerate([lambda a, b: a > b,
                              lambda a, b: a >= b,
                              lambda a, b: a != b]):
        with self.subTest(idx=idx):
            self.assertTrue(fn(C(0, 1), C(0, 0)))
            self.assertTrue(fn(C(1, 0), C(0, 1)))
            self.assertTrue(fn(C(1, 1), C(1, 0)))
306
def test_compare_subclasses(self):
    """A base instance and a subclass instance never compare equal, and
    ordering between them is unsupported."""
    # Comparisons fail for subclasses, even if no fields
    # are added.
    @dataclass
    class B:
        i: int

    @dataclass
    class C(B):
        pass

    for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
                                          (lambda a, b: a != b, True)]):
        with self.subTest(idx=idx):
            self.assertEqual(fn(B(0), C(0)), expected)

    for idx, fn in enumerate([lambda a, b: a < b,
                              lambda a, b: a <= b,
                              lambda a, b: a > b,
                              lambda a, b: a >= b]):
        with self.subTest(idx=idx):
            with self.assertRaisesRegex(TypeError,
                                        "not supported between instances of 'B' and 'C'"):
                fn(B(0), C(0))
331
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500332 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500333 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500334 for (eq, order, result ) in [
335 (False, False, 'neither'),
336 (False, True, 'exception'),
337 (True, False, 'eq_only'),
338 (True, True, 'both'),
339 ]:
340 with self.subTest(eq=eq, order=order):
341 if result == 'exception':
342 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
343 @dataclass(eq=eq, order=order)
344 class C:
345 pass
346 else:
347 @dataclass(eq=eq, order=order)
348 class C:
349 pass
350
351 if result == 'neither':
352 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500353 self.assertNotIn('__lt__', C.__dict__)
354 self.assertNotIn('__le__', C.__dict__)
355 self.assertNotIn('__gt__', C.__dict__)
356 self.assertNotIn('__ge__', C.__dict__)
357 elif result == 'both':
358 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500359 self.assertIn('__lt__', C.__dict__)
360 self.assertIn('__le__', C.__dict__)
361 self.assertIn('__gt__', C.__dict__)
362 self.assertIn('__ge__', C.__dict__)
363 elif result == 'eq_only':
364 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500365 self.assertNotIn('__lt__', C.__dict__)
366 self.assertNotIn('__le__', C.__dict__)
367 self.assertNotIn('__gt__', C.__dict__)
368 self.assertNotIn('__ge__', C.__dict__)
369 else:
370 assert False, f'unknown result {result!r}'
371
def test_field_no_default(self):
    """field() without a default makes the field a required __init__ argument."""
    @dataclass
    class C:
        x: int = field()

    self.assertEqual(C(5).x, 5)

    with self.assertRaisesRegex(TypeError,
                                r"__init__\(\) missing 1 required "
                                "positional argument: 'x'"):
        C()
383
def test_field_default(self):
    """field(default=...) is stored as a class attribute and used by __init__
    when no value is passed."""
    default = object()
    @dataclass
    class C:
        x: object = field(default=default)

    self.assertIs(C.x, default)
    c = C(10)
    self.assertEqual(c.x, 10)

    # If we delete the instance attribute, we should then see the
    # class attribute.
    del c.x
    self.assertIs(c.x, default)

    self.assertIs(C().x, default)
400
def test_not_in_repr(self):
    """field(repr=False) hides the field from __repr__ but keeps it in __init__."""
    @dataclass
    class C:
        x: int = field(repr=False)
    # x is still a required __init__ argument.
    with self.assertRaises(TypeError):
        C()
    c = C(10)
    self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')

    @dataclass
    class C:
        x: int = field(repr=False)
        y: int
    c = C(10, 20)
    self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
416
def test_not_in_compare(self):
    """field(compare=False) excludes the field from the generated __eq__."""
    @dataclass
    class C:
        x: int = 0
        y: int = field(compare=False, default=4)

    # Differences in y alone never affect equality.
    self.assertEqual(C(), C(0, 20))
    self.assertEqual(C(1, 10), C(1, 20))
    self.assertNotEqual(C(3), C(4, 10))
    self.assertNotEqual(C(3, 10), C(4, 10))
427
428 def test_hash_field_rules(self):
429 # Test all 6 cases of:
430 # hash=True/False/None
431 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800432 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500433 (True, False, 'field' ),
434 (True, True, 'field' ),
435 (False, False, 'absent'),
436 (False, True, 'absent'),
437 (None, False, 'absent'),
438 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800439 ]:
440 with self.subTest(hash=hash_, compare=compare):
441 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500442 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800443 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500444
445 if result == 'field':
446 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800447 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500448 elif result == 'absent':
449 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800450 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500451 else:
452 assert False, f'unknown result {result!r}'
453
def test_init_false_no_default(self):
    """An init=False field with no default is never assigned on instances."""
    # If init=False and no default value, then the field won't be
    # present in the instance.
    @dataclass
    class C:
        x: int = field(init=False)

    self.assertNotIn('x', C().__dict__)

    @dataclass
    class C:
        x: int
        y: int = 0
        z: int = field(init=False)
        t: int = 10

    self.assertNotIn('z', C(0).__dict__)
    self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
472
def test_class_marker(self):
    """fields() returns a tuple of Field objects carrying each field's
    name, type and init/repr flags."""
    @dataclass
    class C:
        x: int
        y: str = field(init=False, default=None)
        z: str = field(repr=False)

    the_fields = fields(C)
    # the_fields is a tuple of 3 items, each value
    # is in __annotations__.
    self.assertIsInstance(the_fields, tuple)
    for f in the_fields:
        self.assertIs(type(f), Field)
        self.assertIn(f.name, C.__annotations__)

    self.assertEqual(len(the_fields), 3)

    self.assertEqual(the_fields[0].name, 'x')
    self.assertEqual(the_fields[0].type, int)
    # No default, so no class attribute.
    self.assertFalse(hasattr(C, 'x'))
    self.assertTrue(the_fields[0].init)
    self.assertTrue(the_fields[0].repr)
    self.assertEqual(the_fields[1].name, 'y')
    self.assertEqual(the_fields[1].type, str)
    self.assertIsNone(getattr(C, 'y'))
    self.assertFalse(the_fields[1].init)
    self.assertTrue(the_fields[1].repr)
    self.assertEqual(the_fields[2].name, 'z')
    self.assertEqual(the_fields[2].type, str)
    self.assertFalse(hasattr(C, 'z'))
    self.assertTrue(the_fields[2].init)
    self.assertFalse(the_fields[2].repr)
505
def test_field_order(self):
    """Redeclared fields keep their base-class position; new fields are appended."""
    @dataclass
    class B:
        a: str = 'B:a'
        b: str = 'B:b'
        c: str = 'B:c'

    @dataclass
    class C(B):
        b: str = 'C:b'

    self.assertEqual([(f.name, f.default) for f in fields(C)],
                     [('a', 'B:a'),
                      ('b', 'C:b'),
                      ('c', 'B:c')])

    @dataclass
    class D(B):
        c: str = 'D:c'

    self.assertEqual([(f.name, f.default) for f in fields(D)],
                     [('a', 'B:a'),
                      ('b', 'B:b'),
                      ('c', 'D:c')])

    @dataclass
    class E(D):
        a: str = 'E:a'
        d: str = 'E:d'

    self.assertEqual([(f.name, f.default) for f in fields(E)],
                     [('a', 'E:a'),
                      ('b', 'B:b'),
                      ('c', 'D:c'),
                      ('d', 'E:d')])
541
def test_class_attrs(self):
    """Only fields with a default end up as class attributes."""
    # We only have a class attribute if a default value is
    # specified, either directly or via a field with a default.
    default = object()
    @dataclass
    class C:
        x: int
        y: int = field(repr=False)
        z: object = default
        t: int = field(default=100)

    self.assertFalse(hasattr(C, 'x'))
    self.assertFalse(hasattr(C, 'y'))
    self.assertIs(C.z, default)
    self.assertEqual(C.t, 100)
557
def test_disallowed_mutable_defaults(self):
    """list/dict/set defaults (including subclass instances) are rejected
    for fields but allowed for ClassVars."""
    # For the known types, don't allow mutable default values.
    for typ, empty, non_empty in [(list, [], [1]),
                                  (dict, {}, {0:1}),
                                  (set, set(), set([1])),
                                  ]:
        with self.subTest(typ=typ):
            # Can't use a zero-length value.
            with self.assertRaisesRegex(ValueError,
                                        f'mutable default {typ} for field '
                                        'x is not allowed'):
                @dataclass
                class Point:
                    x: typ = empty


            # Nor a non-zero-length value
            with self.assertRaisesRegex(ValueError,
                                        f'mutable default {typ} for field '
                                        'y is not allowed'):
                @dataclass
                class Point:
                    y: typ = non_empty

            # Check subtypes also fail.
            class Subclass(typ): pass

            with self.assertRaisesRegex(ValueError,
                                        f"mutable default .*Subclass'>"
                                        ' for field z is not allowed'
                                        ):
                @dataclass
                class Point:
                    z: typ = Subclass()

            # Because this is a ClassVar, it can be mutable.
            @dataclass
            class C:
                z: ClassVar[typ] = typ()

            # Because this is a ClassVar, it can be mutable.
            @dataclass
            class C:
                x: ClassVar[typ] = Subclass()
602
603
def test_deliberately_mutable_defaults(self):
    """A mutable object of a user-defined type may be shared between instances."""
    # If a mutable default isn't in the known list of
    # (list, dict, set), then it's okay.
    # NOTE(review): x below has no default; the shared Mutable is passed
    # explicitly to __init__, so no default value is actually exercised.
    class Mutable:
        def __init__(self):
            self.l = []

    @dataclass
    class C:
        x: Mutable

    # These 2 instances will share this value of x.
    lst = Mutable()
    o1 = C(lst)
    o2 = C(lst)
    self.assertEqual(o1, o2)
    o1.x.l.extend([1, 2])
    self.assertEqual(o1, o2)
    self.assertEqual(o1.x.l, [1, 2])
    self.assertIs(o1.x, o2.x)
624
def test_no_options(self):
    """@dataclass() with empty parentheses behaves like bare @dataclass."""
    # Call with dataclass().
    @dataclass()
    class C:
        x: int

    self.assertEqual(C(42).x, 42)
632
633 def test_not_tuple(self):
634 # Make sure we can't be compared to a tuple.
635 @dataclass
636 class Point:
637 x: int
638 y: int
639 self.assertNotEqual(Point(1, 2), (1, 2))
640
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700641 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500642 @dataclass
643 class C:
644 x: int
645 y: int
646 self.assertNotEqual(Point(1, 3), C(1, 3))
647
def test_not_tuple(self):
    """Dataclass instances avoid namedtuple pitfalls: no equality with
    tuples or look-alike classes, and no unpacking."""
    # Test that some of the problems with namedtuple don't happen
    # here.
    @dataclass
    class Point3D:
        x: int
        y: int
        z: int

    @dataclass
    class Date:
        year: int
        month: int
        day: int

    self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
    self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))

    # Make sure we can't unpack.
    with self.assertRaisesRegex(TypeError, 'unpack'):
        x, y, z = Point3D(4, 5, 6)

    # Make sure another class with the same field names isn't
    # equal.
    @dataclass
    class Point3Dv1:
        x: int = 0
        y: int = 0
        z: int = 0
    self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
678
def test_function_annotations(self):
    """The generated __init__ carries each init field's annotation, returns
    None, and omits init=False fields; checked with and without a
    generated __hash__."""
    # Some dummy class and instance to use as a default.
    class F:
        pass
    f = F()

    def validate_class(cls):
        # First, check __annotations__, even though they're not
        # function annotations.
        self.assertEqual(cls.__annotations__['i'], int)
        self.assertEqual(cls.__annotations__['j'], str)
        self.assertEqual(cls.__annotations__['k'], F)
        self.assertEqual(cls.__annotations__['l'], float)
        self.assertEqual(cls.__annotations__['z'], complex)

        # Verify __init__.

        signature = inspect.signature(cls.__init__)
        # Check the return type, should be None.
        self.assertIs(signature.return_annotation, None)

        # Check each parameter.
        params = iter(signature.parameters.values())
        param = next(params)
        # This is testing an internal name, and probably shouldn't be tested.
        self.assertEqual(param.name, 'self')
        param = next(params)
        self.assertEqual(param.name, 'i')
        self.assertIs(param.annotation, int)
        self.assertEqual(param.default, inspect.Parameter.empty)
        self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        param = next(params)
        self.assertEqual(param.name, 'j')
        self.assertIs(param.annotation, str)
        self.assertEqual(param.default, inspect.Parameter.empty)
        self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        param = next(params)
        self.assertEqual(param.name, 'k')
        self.assertIs(param.annotation, F)
        # Don't test for the default, since it's set to MISSING.
        self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        param = next(params)
        self.assertEqual(param.name, 'l')
        self.assertIs(param.annotation, float)
        # Don't test for the default, since it's set to MISSING.
        self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
        # z is init=False, so there must be no further parameters.
        self.assertRaises(StopIteration, next, params)


    @dataclass
    class C:
        i: int
        j: str
        k: F = f
        l: float=field(default=None)
        z: complex=field(default=3+4j, init=False)

    validate_class(C)

    # Now repeat with __hash__.
    @dataclass(frozen=True, unsafe_hash=True)
    class C:
        i: int
        j: str
        k: F = f
        l: float=field(default=None)
        z: complex=field(default=3+4j, init=False)

    validate_class(C)
748
def test_missing_default(self):
    """field(default=MISSING) behaves exactly like giving no default."""
    # Test that MISSING works the same as a default not being
    # specified.
    @dataclass
    class C:
        x: int=field(default=MISSING)
    with self.assertRaisesRegex(TypeError,
                                r'__init__\(\) missing 1 required '
                                'positional argument'):
        C()
    self.assertNotIn('x', C.__dict__)

    # Reference case: a plain annotation with no default at all.
    @dataclass
    class D:
        x: int
    with self.assertRaisesRegex(TypeError,
                                r'__init__\(\) missing 1 required '
                                'positional argument'):
        D()
    self.assertNotIn('x', D.__dict__)
769
def test_missing_default_factory(self):
    """default_factory=MISSING (alone or with default=MISSING) behaves
    like giving no default."""
    # Test that MISSING works the same as a default factory not
    # being specified (which is really the same as a default not
    # being specified, too).
    @dataclass
    class C:
        x: int=field(default_factory=MISSING)
    with self.assertRaisesRegex(TypeError,
                                r'__init__\(\) missing 1 required '
                                'positional argument'):
        C()
    self.assertNotIn('x', C.__dict__)

    @dataclass
    class D:
        x: int=field(default=MISSING, default_factory=MISSING)
    with self.assertRaisesRegex(TypeError,
                                r'__init__\(\) missing 1 required '
                                'positional argument'):
        D()
    self.assertNotIn('x', D.__dict__)
791
def test_missing_repr(self):
    """repr(MISSING) identifies it as a MISSING_TYPE object."""
    self.assertIn('MISSING_TYPE object', repr(MISSING))
794
def test_dont_include_other_annotations(self):
    """Method and property return annotations do not leak into the class's
    __annotations__ and so do not become fields."""
    @dataclass
    class C:
        i: int
        def foo(self) -> int:
            return 4
        @property
        def bar(self) -> int:
            return 5
    self.assertEqual(list(C.__annotations__), ['i'])
    self.assertEqual(C(10).foo(), 4)
    self.assertEqual(C(10).bar, 5)
    self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500808
def test_post_init(self):
    """__post_init__ is called by the generated __init__ (and only then),
    sees field values, and cannot assign fields on frozen instances."""
    # Just make sure it gets called
    @dataclass
    class C:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        C()

    @dataclass
    class C:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        C()
    # post-init gets called, but doesn't raise. This is just
    # checking that self is used correctly.
    C(5)

    # If there's not an __init__, then post-init won't get called.
    @dataclass(init=False)
    class C:
        def __post_init__(self):
            raise CustomError()
    # Creating an instance won't raise, because __post_init__ is not called.
    C()

    @dataclass
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(C().x, 0)
    self.assertEqual(C(2).x, 4)

    # Make sure that if we're frozen, post-init can't set
    # attributes.
    @dataclass(frozen=True)
    class C:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        C()
855
def test_post_init_super(self):
    """Only the most-derived __post_init__ is called automatically; a base
    class's runs only if chained explicitly or inherited."""
    # Make sure super() post-init isn't called by default.
    class B:
        def __post_init__(self):
            raise CustomError()

    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    self.assertEqual(C().x, 5)

    # Now call super(), and it will raise.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # Make sure post-init is called, even if not defined in our
    # class.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
886
def test_post_init_staticmethod(self):
    """A staticmethod __post_init__ is still invoked by __init__."""
    flag = False
    @dataclass
    class C:
        x: int
        y: int
        @staticmethod
        def __post_init__():
            # Record the call in the enclosing test's local.
            nonlocal flag
            flag = True

    self.assertFalse(flag)
    c = C(3, 4)
    self.assertEqual((c.x, c.y), (3, 4))
    self.assertTrue(flag)
902
def test_post_init_classmethod(self):
    """A classmethod __post_init__ is still invoked by __init__."""
    @dataclass
    class C:
        flag = False  # un-annotated, so a plain class attribute
        x: int
        y: int
        @classmethod
        def __post_init__(cls):
            cls.flag = True

    self.assertFalse(C.flag)
    c = C(3, 4)
    self.assertEqual((c.x, c.y), (3, 4))
    self.assertTrue(C.flag)
917
def test_class_var(self):
    """ClassVar annotations are not fields; their values stay shared on the class."""
    # Make sure ClassVars are ignored in __init__, __repr__, etc.
    @dataclass
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    c = C(5)
    self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)         # We have 2 fields.
    self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars.
    self.assertEqual(c.z, 1000)
    self.assertEqual(c.w, 2000)
    self.assertEqual(c.t, 3000)
    # Changing the class attribute is visible through existing instances.
    C.z += 1
    self.assertEqual(c.z, 1001)
    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    self.assertEqual(c.z, 1001)
    self.assertEqual(c.w, 2000)
    self.assertEqual(c.t, 3000)
942
def test_class_var_no_default(self):
    """A ClassVar with no value creates no class attribute."""
    # If a ClassVar has no default value, it should not be set on the class.
    @dataclass
    class C:
        x: ClassVar[int]

    self.assertNotIn('x', C.__dict__)
950
951 def test_class_var_default_factory(self):
952 # It makes no sense for a ClassVar to have a default factory. When
953 # would it be called? Call it yourself, since it's class-wide.
954 with self.assertRaisesRegex(TypeError,
955 'cannot have a default factory'):
956 @dataclass
957 class C:
958 x: ClassVar[int] = field(default_factory=int)
959
960 self.assertNotIn('x', C.__dict__)
961
def test_class_var_with_default(self):
    """A ClassVar's value, given directly or via field(default=...), is set
    on the class."""
    # If a ClassVar has a default value, it should be set on the class.
    @dataclass
    class C:
        x: ClassVar[int] = 10
    self.assertEqual(C.x, 10)

    @dataclass
    class C:
        x: ClassVar[int] = field(default=10)
    self.assertEqual(C.x, 10)
973
def test_class_var_frozen(self):
    """frozen=True affects only instances; ClassVars on the class remain
    assignable."""
    # Make sure ClassVars work even if we're frozen.
    @dataclass(frozen=True)
    class C:
        x: int
        y: int = 10
        z: ClassVar[int] = 1000
        w: ClassVar[int] = 2000
        t: ClassVar[int] = 3000

    c = C(5)
    self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
    self.assertEqual(len(fields(C)), 2)         # We have 2 fields
    self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
    self.assertEqual(c.z, 1000)
    self.assertEqual(c.w, 2000)
    self.assertEqual(c.t, 3000)
    # We can still modify the ClassVar, it's only instances that are
    # frozen.
    C.z += 1
    self.assertEqual(c.z, 1001)
    c = C(20)
    self.assertEqual((c.x, c.y), (20, 10))
    self.assertEqual(c.z, 1001)
    self.assertEqual(c.w, 2000)
    self.assertEqual(c.t, 3000)
1000
def test_init_var_no_default(self):
    """An InitVar with no value creates no class attribute."""
    # If an InitVar has no default value, it should not be set on the class.
    @dataclass
    class C:
        x: InitVar[int]

    self.assertNotIn('x', C.__dict__)
1008
1009 def test_init_var_default_factory(self):
1010 # It makes no sense for an InitVar to have a default factory. When
1011 # would it be called? Call it yourself, since it's class-wide.
1012 with self.assertRaisesRegex(TypeError,
1013 'cannot have a default factory'):
1014 @dataclass
1015 class C:
1016 x: InitVar[int] = field(default_factory=int)
1017
1018 self.assertNotIn('x', C.__dict__)
1019
def test_init_var_with_default(self):
    """An InitVar's default, given directly or via field(default=...), is set
    on the class."""
    # If an InitVar has a default value, it should be set on the class.
    @dataclass
    class C:
        x: InitVar[int] = 10
    self.assertEqual(C.x, 10)

    @dataclass
    class C:
        x: InitVar[int] = field(default=10)
    self.assertEqual(C.x, 10)
1031
def test_init_var(self):
    """An InitVar value passed to __init__ is forwarded to __post_init__."""
    @dataclass
    class C:
        x: int = None
        init_param: InitVar[int] = None

        def __post_init__(self, init_param):
            if self.x is None:
                self.x = init_param*2

    c = C(init_param=10)
    self.assertEqual(c.x, 20)
1044
def test_init_var_inheritance(self):
    """InitVars from base and derived classes are all passed, in field order,
    to the derived __post_init__ and are not stored on the instance."""
    # Note that this deliberately tests that a dataclass need not
    # have a __post_init__ function if it has an InitVar field.
    # It could just be used in a derived class, as shown here.
    @dataclass
    class Base:
        x: int
        init_base: InitVar[int]

    # We can instantiate by passing the InitVar, even though
    # it's not used.
    b = Base(0, 10)
    self.assertEqual(vars(b), {'x': 0})

    @dataclass
    class C(Base):
        y: int
        init_derived: InitVar[int]

        def __post_init__(self, init_base, init_derived):
            self.x = self.x + init_base
            self.y = self.y + init_derived

    # Positional order: x, init_base, y, init_derived.
    c = C(10, 11, 50, 51)
    self.assertEqual(vars(c), {'x': 21, 'y': 101})
1070
def test_default_factory(self):
    """default_factory is called per instance and combines with the
    repr/hash/init/compare field flags."""
    # Test a factory that returns a new list.
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=list)

    c0 = C(3)
    c1 = C(3)
    self.assertEqual(c0.x, 3)
    self.assertEqual(c0.y, [])
    self.assertEqual(c0, c1)
    # Each instance gets its own list.
    self.assertIsNot(c0.y, c1.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test a factory that returns a shared list.
    l = []
    @dataclass
    class C:
        x: int
        y: list = field(default_factory=lambda: l)

    c0 = C(3)
    c1 = C(3)
    self.assertEqual(c0.x, 3)
    self.assertEqual(c0.y, [])
    self.assertEqual(c0, c1)
    self.assertIs(c0.y, c1.y)
    self.assertEqual(astuple(C(5, [1])), (5, [1]))

    # Test various other field flags.
    # repr
    @dataclass
    class C:
        x: list = field(default_factory=list, repr=False)
    self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
    self.assertEqual(C().x, [])

    # hash
    @dataclass(unsafe_hash=True)
    class C:
        x: list = field(default_factory=list, hash=False)
    self.assertEqual(astuple(C()), ([],))
    self.assertEqual(hash(C()), hash(()))

    # init (see also test_default_factory_with_no_init)
    @dataclass
    class C:
        x: list = field(default_factory=list, init=False)
    self.assertEqual(astuple(C()), ([],))

    # compare
    @dataclass
    class C:
        x: list = field(default_factory=list, compare=False)
    self.assertEqual(C(), C([1]))
1127
1128 def test_default_factory_with_no_init(self):
1129 # We need a factory with a side effect.
1130 factory = Mock()
1131
1132 @dataclass
1133 class C:
1134 x: list = field(default_factory=factory, init=False)
1135
1136 # Make sure the default factory is called for each new instance.
1137 C().x
1138 self.assertEqual(factory.call_count, 1)
1139 C().x
1140 self.assertEqual(factory.call_count, 2)
1141
1142 def test_default_factory_not_called_if_value_given(self):
1143 # We need a factory that we can test if it's been called.
1144 factory = Mock()
1145
1146 @dataclass
1147 class C:
1148 x: int = field(default_factory=factory)
1149
1150 # Make sure that if a field has a default factory function,
1151 # it's not called if a value is specified.
1152 C().x
1153 self.assertEqual(factory.call_count, 1)
1154 self.assertEqual(C(10).x, 10)
1155 self.assertEqual(factory.call_count, 1)
1156 C().x
1157 self.assertEqual(factory.call_count, 2)
1158
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001159 def test_default_factory_derived(self):
1160 # See bpo-32896.
1161 @dataclass
1162 class Foo:
1163 x: dict = field(default_factory=dict)
1164
1165 @dataclass
1166 class Bar(Foo):
1167 y: int = 1
1168
1169 self.assertEqual(Foo().x, {})
1170 self.assertEqual(Bar().x, {})
1171 self.assertEqual(Bar().y, 1)
1172
1173 @dataclass
1174 class Baz(Foo):
1175 pass
1176 self.assertEqual(Baz().x, {})
1177
1178 def test_intermediate_non_dataclass(self):
1179 # Test that an intermediate class that defines
1180 # annotations does not define fields.
1181
1182 @dataclass
1183 class A:
1184 x: int
1185
1186 class B(A):
1187 y: int
1188
1189 @dataclass
1190 class C(B):
1191 z: int
1192
1193 c = C(1, 3)
1194 self.assertEqual((c.x, c.z), (1, 3))
1195
1196 # .y was not initialized.
1197 with self.assertRaisesRegex(AttributeError,
1198 'object has no attribute'):
1199 c.y
1200
1201 # And if we again derive a non-dataclass, no fields are added.
1202 class D(C):
1203 t: int
1204 d = D(4, 5)
1205 self.assertEqual((d.x, d.z), (4, 5))
1206
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001207 def test_classvar_default_factory(self):
1208 # It's an error for a ClassVar to have a factory function.
1209 with self.assertRaisesRegex(TypeError,
1210 'cannot have a default factory'):
1211 @dataclass
1212 class C:
1213 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001214
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001215 def test_is_dataclass(self):
1216 class NotDataClass:
1217 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001218
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001219 self.assertFalse(is_dataclass(0))
1220 self.assertFalse(is_dataclass(int))
1221 self.assertFalse(is_dataclass(NotDataClass))
1222 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001223
1224 @dataclass
1225 class C:
1226 x: int
1227
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001228 @dataclass
1229 class D:
1230 d: C
1231 e: int
1232
1233 c = C(10)
1234 d = D(c, 4)
1235
1236 self.assertTrue(is_dataclass(C))
1237 self.assertTrue(is_dataclass(c))
1238 self.assertFalse(is_dataclass(c.x))
1239 self.assertTrue(is_dataclass(d.d))
1240 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001241
1242 def test_helper_fields_with_class_instance(self):
1243 # Check that we can call fields() on either a class or instance,
1244 # and get back the same thing.
1245 @dataclass
1246 class C:
1247 x: int
1248 y: float
1249
1250 self.assertEqual(fields(C), fields(C(0, 0.0)))
1251
1252 def test_helper_fields_exception(self):
1253 # Check that TypeError is raised if not passed a dataclass or
1254 # instance.
1255 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1256 fields(0)
1257
1258 class C: pass
1259 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1260 fields(C)
1261 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1262 fields(C())
1263
1264 def test_helper_asdict(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001265 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001266 @dataclass
1267 class C:
1268 x: int
1269 y: int
1270 c = C(1, 2)
1271
1272 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1273 self.assertEqual(asdict(c), asdict(c))
1274 self.assertIsNot(asdict(c), asdict(c))
1275 c.x = 42
1276 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1277 self.assertIs(type(asdict(c)), dict)
1278
1279 def test_helper_asdict_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001280 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001281 @dataclass
1282 class C:
1283 x: int
1284 y: int
1285 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1286 asdict(C)
1287 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1288 asdict(int)
1289
1290 def test_helper_asdict_copy_values(self):
1291 @dataclass
1292 class C:
1293 x: int
1294 y: List[int] = field(default_factory=list)
1295 initial = []
1296 c = C(1, initial)
1297 d = asdict(c)
1298 self.assertEqual(d['y'], initial)
1299 self.assertIsNot(d['y'], initial)
1300 c = C(1)
1301 d = asdict(c)
1302 d['y'].append(1)
1303 self.assertEqual(c.y, [])
1304
1305 def test_helper_asdict_nested(self):
1306 @dataclass
1307 class UserId:
1308 token: int
1309 group: int
1310 @dataclass
1311 class User:
1312 name: str
1313 id: UserId
1314 u = User('Joe', UserId(123, 1))
1315 d = asdict(u)
1316 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1317 self.assertIsNot(asdict(u), asdict(u))
1318 u.id.group = 2
1319 self.assertEqual(asdict(u), {'name': 'Joe',
1320 'id': {'token': 123, 'group': 2}})
1321
1322 def test_helper_asdict_builtin_containers(self):
1323 @dataclass
1324 class User:
1325 name: str
1326 id: int
1327 @dataclass
1328 class GroupList:
1329 id: int
1330 users: List[User]
1331 @dataclass
1332 class GroupTuple:
1333 id: int
1334 users: Tuple[User, ...]
1335 @dataclass
1336 class GroupDict:
1337 id: int
1338 users: Dict[str, User]
1339 a = User('Alice', 1)
1340 b = User('Bob', 2)
1341 gl = GroupList(0, [a, b])
1342 gt = GroupTuple(0, (a, b))
1343 gd = GroupDict(0, {'first': a, 'second': b})
1344 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1345 {'name': 'Bob', 'id': 2}]})
1346 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1347 {'name': 'Bob', 'id': 2})})
1348 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1349 'second': {'name': 'Bob', 'id': 2}}})
1350
1351 def test_helper_asdict_builtin_containers(self):
1352 @dataclass
1353 class Child:
1354 d: object
1355
1356 @dataclass
1357 class Parent:
1358 child: Child
1359
1360 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1361 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1362
1363 def test_helper_asdict_factory(self):
1364 @dataclass
1365 class C:
1366 x: int
1367 y: int
1368 c = C(1, 2)
1369 d = asdict(c, dict_factory=OrderedDict)
1370 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1371 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1372 c.x = 42
1373 d = asdict(c, dict_factory=OrderedDict)
1374 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1375 self.assertIs(type(d), OrderedDict)
1376
1377 def test_helper_astuple(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001378 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001379 @dataclass
1380 class C:
1381 x: int
1382 y: int = 0
1383 c = C(1)
1384
1385 self.assertEqual(astuple(c), (1, 0))
1386 self.assertEqual(astuple(c), astuple(c))
1387 self.assertIsNot(astuple(c), astuple(c))
1388 c.y = 42
1389 self.assertEqual(astuple(c), (1, 42))
1390 self.assertIs(type(astuple(c)), tuple)
1391
1392 def test_helper_astuple_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001393 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001394 @dataclass
1395 class C:
1396 x: int
1397 y: int
1398 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1399 astuple(C)
1400 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1401 astuple(int)
1402
1403 def test_helper_astuple_copy_values(self):
1404 @dataclass
1405 class C:
1406 x: int
1407 y: List[int] = field(default_factory=list)
1408 initial = []
1409 c = C(1, initial)
1410 t = astuple(c)
1411 self.assertEqual(t[1], initial)
1412 self.assertIsNot(t[1], initial)
1413 c = C(1)
1414 t = astuple(c)
1415 t[1].append(1)
1416 self.assertEqual(c.y, [])
1417
1418 def test_helper_astuple_nested(self):
1419 @dataclass
1420 class UserId:
1421 token: int
1422 group: int
1423 @dataclass
1424 class User:
1425 name: str
1426 id: UserId
1427 u = User('Joe', UserId(123, 1))
1428 t = astuple(u)
1429 self.assertEqual(t, ('Joe', (123, 1)))
1430 self.assertIsNot(astuple(u), astuple(u))
1431 u.id.group = 2
1432 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1433
1434 def test_helper_astuple_builtin_containers(self):
1435 @dataclass
1436 class User:
1437 name: str
1438 id: int
1439 @dataclass
1440 class GroupList:
1441 id: int
1442 users: List[User]
1443 @dataclass
1444 class GroupTuple:
1445 id: int
1446 users: Tuple[User, ...]
1447 @dataclass
1448 class GroupDict:
1449 id: int
1450 users: Dict[str, User]
1451 a = User('Alice', 1)
1452 b = User('Bob', 2)
1453 gl = GroupList(0, [a, b])
1454 gt = GroupTuple(0, (a, b))
1455 gd = GroupDict(0, {'first': a, 'second': b})
1456 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1457 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1458 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1459
1460 def test_helper_astuple_builtin_containers(self):
1461 @dataclass
1462 class Child:
1463 d: object
1464
1465 @dataclass
1466 class Parent:
1467 child: Child
1468
1469 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1470 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1471
1472 def test_helper_astuple_factory(self):
1473 @dataclass
1474 class C:
1475 x: int
1476 y: int
1477 NT = namedtuple('NT', 'x y')
1478 def nt(lst):
1479 return NT(*lst)
1480 c = C(1, 2)
1481 t = astuple(c, tuple_factory=nt)
1482 self.assertEqual(t, NT(1, 2))
1483 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1484 c.x = 42
1485 t = astuple(c, tuple_factory=nt)
1486 self.assertEqual(t, NT(42, 2))
1487 self.assertIs(type(t), NT)
1488
1489 def test_dynamic_class_creation(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001490 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001491 }
1492
1493 # Create the class.
1494 cls = type('C', (), cls_dict)
1495
1496 # Make it a dataclass.
1497 cls1 = dataclass(cls)
1498
1499 self.assertEqual(cls1, cls)
1500 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1501
1502 def test_dynamic_class_creation_using_field(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001503 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001504 'y': field(default=5),
1505 }
1506
1507 # Create the class.
1508 cls = type('C', (), cls_dict)
1509
1510 # Make it a dataclass.
1511 cls1 = dataclass(cls)
1512
1513 self.assertEqual(cls1, cls)
1514 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1515
1516 def test_init_in_order(self):
1517 @dataclass
1518 class C:
1519 a: int
1520 b: int = field()
1521 c: list = field(default_factory=list, init=False)
1522 d: list = field(default_factory=list)
1523 e: int = field(default=4, init=False)
1524 f: int = 4
1525
1526 calls = []
1527 def setattr(self, name, value):
1528 calls.append((name, value))
1529
1530 C.__setattr__ = setattr
1531 c = C(0, 1)
1532 self.assertEqual(('a', 0), calls[0])
1533 self.assertEqual(('b', 1), calls[1])
1534 self.assertEqual(('c', []), calls[2])
1535 self.assertEqual(('d', []), calls[3])
1536 self.assertNotIn(('e', 4), calls)
1537 self.assertEqual(('f', 4), calls[4])
1538
1539 def test_items_in_dicts(self):
1540 @dataclass
1541 class C:
1542 a: int
1543 b: list = field(default_factory=list, init=False)
1544 c: list = field(default_factory=list)
1545 d: int = field(default=4, init=False)
1546 e: int = 0
1547
1548 c = C(0)
1549 # Class dict
1550 self.assertNotIn('a', C.__dict__)
1551 self.assertNotIn('b', C.__dict__)
1552 self.assertNotIn('c', C.__dict__)
1553 self.assertIn('d', C.__dict__)
1554 self.assertEqual(C.d, 4)
1555 self.assertIn('e', C.__dict__)
1556 self.assertEqual(C.e, 0)
1557 # Instance dict
1558 self.assertIn('a', c.__dict__)
1559 self.assertEqual(c.a, 0)
1560 self.assertIn('b', c.__dict__)
1561 self.assertEqual(c.b, [])
1562 self.assertIn('c', c.__dict__)
1563 self.assertEqual(c.c, [])
1564 self.assertNotIn('d', c.__dict__)
1565 self.assertIn('e', c.__dict__)
1566 self.assertEqual(c.e, 0)
1567
1568 def test_alternate_classmethod_constructor(self):
1569 # Since __post_init__ can't take params, use a classmethod
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001570 # alternate constructor. This is mostly an example to show
1571 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001572 @dataclass
1573 class C:
1574 x: int
1575 @classmethod
1576 def from_file(cls, filename):
1577 # In a real example, create a new instance
1578 # and populate 'x' from contents of a file.
1579 value_in_file = 20
1580 return cls(value_in_file)
1581
1582 self.assertEqual(C.from_file('filename').x, 20)
1583
1584 def test_field_metadata_default(self):
1585 # Make sure the default metadata is read-only and of
1586 # zero length.
1587 @dataclass
1588 class C:
1589 i: int
1590
1591 self.assertFalse(fields(C)[0].metadata)
1592 self.assertEqual(len(fields(C)[0].metadata), 0)
1593 with self.assertRaisesRegex(TypeError,
1594 'does not support item assignment'):
1595 fields(C)[0].metadata['test'] = 3
1596
1597 def test_field_metadata_mapping(self):
1598 # Make sure only a mapping can be passed as metadata
1599 # zero length.
1600 with self.assertRaises(TypeError):
1601 @dataclass
1602 class C:
1603 i: int = field(metadata=0)
1604
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001605 # Make sure an empty dict works.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001606 @dataclass
1607 class C:
1608 i: int = field(metadata={})
1609 self.assertFalse(fields(C)[0].metadata)
1610 self.assertEqual(len(fields(C)[0].metadata), 0)
1611 with self.assertRaisesRegex(TypeError,
1612 'does not support item assignment'):
1613 fields(C)[0].metadata['test'] = 3
1614
1615 # Make sure a non-empty dict works.
1616 @dataclass
1617 class C:
1618 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1619 self.assertEqual(len(fields(C)[0].metadata), 3)
1620 self.assertEqual(fields(C)[0].metadata['test'], 10)
1621 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1622 self.assertEqual(fields(C)[0].metadata[3], 'three')
1623 with self.assertRaises(KeyError):
1624 # Non-existent key.
1625 fields(C)[0].metadata['baz']
1626 with self.assertRaisesRegex(TypeError,
1627 'does not support item assignment'):
1628 fields(C)[0].metadata['test'] = 3
1629
1630 def test_field_metadata_custom_mapping(self):
1631 # Try a custom mapping.
1632 class SimpleNameSpace:
1633 def __init__(self, **kw):
1634 self.__dict__.update(kw)
1635
1636 def __getitem__(self, item):
1637 if item == 'xyzzy':
1638 return 'plugh'
1639 return getattr(self, item)
1640
1641 def __len__(self):
1642 return self.__dict__.__len__()
1643
1644 @dataclass
1645 class C:
1646 i: int = field(metadata=SimpleNameSpace(a=10))
1647
1648 self.assertEqual(len(fields(C)[0].metadata), 1)
1649 self.assertEqual(fields(C)[0].metadata['a'], 10)
1650 with self.assertRaises(AttributeError):
1651 fields(C)[0].metadata['b']
1652 # Make sure we're still talking to our custom mapping.
1653 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1654
1655 def test_generic_dataclasses(self):
1656 T = TypeVar('T')
1657
1658 @dataclass
1659 class LabeledBox(Generic[T]):
1660 content: T
1661 label: str = '<unknown>'
1662
1663 box = LabeledBox(42)
1664 self.assertEqual(box.content, 42)
1665 self.assertEqual(box.label, '<unknown>')
1666
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001667 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001668 Alias = List[LabeledBox[int]]
1669
1670 def test_generic_extending(self):
1671 S = TypeVar('S')
1672 T = TypeVar('T')
1673
1674 @dataclass
1675 class Base(Generic[T, S]):
1676 x: T
1677 y: S
1678
1679 @dataclass
1680 class DataDerived(Base[int, T]):
1681 new_field: str
1682 Alias = DataDerived[str]
1683 c = Alias(0, 'test1', 'test2')
1684 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1685
1686 class NonDataDerived(Base[int, T]):
1687 def new_method(self):
1688 return self.y
1689 Alias = NonDataDerived[float]
1690 c = Alias(10, 1.0)
1691 self.assertEqual(c.new_method(), 1.0)
1692
1693 def test_helper_replace(self):
1694 @dataclass(frozen=True)
1695 class C:
1696 x: int
1697 y: int
1698
1699 c = C(1, 2)
1700 c1 = replace(c, x=3)
1701 self.assertEqual(c1.x, 3)
1702 self.assertEqual(c1.y, 2)
1703
1704 def test_helper_replace_frozen(self):
1705 @dataclass(frozen=True)
1706 class C:
1707 x: int
1708 y: int
1709 z: int = field(init=False, default=10)
1710 t: int = field(init=False, default=100)
1711
1712 c = C(1, 2)
1713 c1 = replace(c, x=3)
1714 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1715 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1716
1717
1718 with self.assertRaisesRegex(ValueError, 'init=False'):
1719 replace(c, x=3, z=20, t=50)
1720 with self.assertRaisesRegex(ValueError, 'init=False'):
1721 replace(c, z=20)
1722 replace(c, x=3, z=20, t=50)
1723
1724 # Make sure the result is still frozen.
1725 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1726 c1.x = 3
1727
1728 # Make sure we can't replace an attribute that doesn't exist,
1729 # if we're also replacing one that does exist. Test this
1730 # here, because setting attributes on frozen instances is
1731 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001732 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001733 "keyword argument 'a'"):
1734 c1 = replace(c, x=20, a=5)
1735
1736 def test_helper_replace_invalid_field_name(self):
1737 @dataclass(frozen=True)
1738 class C:
1739 x: int
1740 y: int
1741
1742 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001743 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001744 "keyword argument 'z'"):
1745 c1 = replace(c, z=3)
1746
1747 def test_helper_replace_invalid_object(self):
1748 @dataclass(frozen=True)
1749 class C:
1750 x: int
1751 y: int
1752
1753 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1754 replace(C, x=3)
1755
1756 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1757 replace(0, x=3)
1758
1759 def test_helper_replace_no_init(self):
1760 @dataclass
1761 class C:
1762 x: int
1763 y: int = field(init=False, default=10)
1764
1765 c = C(1)
1766 c.y = 20
1767
1768 # Make sure y gets the default value.
1769 c1 = replace(c, x=5)
1770 self.assertEqual((c1.x, c1.y), (5, 10))
1771
1772 # Trying to replace y is an error.
1773 with self.assertRaisesRegex(ValueError, 'init=False'):
1774 replace(c, x=2, y=30)
1775 with self.assertRaisesRegex(ValueError, 'init=False'):
1776 replace(c, y=30)
1777
1778 def test_dataclassses_pickleable(self):
1779 global P, Q, R
1780 @dataclass
1781 class P:
1782 x: int
1783 y: int = 0
1784 @dataclass
1785 class Q:
1786 x: int
1787 y: int = field(default=0, init=False)
1788 @dataclass
1789 class R:
1790 x: int
1791 y: List[int] = field(default_factory=list)
1792 q = Q(1)
1793 q.y = 2
1794 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1795 for sample in samples:
1796 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1797 with self.subTest(sample=sample, proto=proto):
1798 new_sample = pickle.loads(pickle.dumps(sample, proto))
1799 self.assertEqual(sample.x, new_sample.x)
1800 self.assertEqual(sample.y, new_sample.y)
1801 self.assertIsNot(sample, new_sample)
1802 new_sample.x = 42
1803 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1804 self.assertEqual(new_sample.x, another_new_sample.x)
1805 self.assertEqual(sample.y, another_new_sample.y)
1806
1807 def test_helper_make_dataclass(self):
1808 C = make_dataclass('C',
1809 [('x', int),
1810 ('y', int, field(default=5))],
1811 namespace={'add_one': lambda self: self.x + 1})
1812 c = C(10)
1813 self.assertEqual((c.x, c.y), (10, 5))
1814 self.assertEqual(c.add_one(), 11)
1815
1816
1817 def test_helper_make_dataclass_no_mutate_namespace(self):
1818 # Make sure a provided namespace isn't mutated.
1819 ns = {}
1820 C = make_dataclass('C',
1821 [('x', int),
1822 ('y', int, field(default=5))],
1823 namespace=ns)
1824 self.assertEqual(ns, {})
1825
1826 def test_helper_make_dataclass_base(self):
1827 class Base1:
1828 pass
1829 class Base2:
1830 pass
1831 C = make_dataclass('C',
1832 [('x', int)],
1833 bases=(Base1, Base2))
1834 c = C(2)
1835 self.assertIsInstance(c, C)
1836 self.assertIsInstance(c, Base1)
1837 self.assertIsInstance(c, Base2)
1838
1839 def test_helper_make_dataclass_base_dataclass(self):
1840 @dataclass
1841 class Base1:
1842 x: int
1843 class Base2:
1844 pass
1845 C = make_dataclass('C',
1846 [('y', int)],
1847 bases=(Base1, Base2))
1848 with self.assertRaisesRegex(TypeError, 'required positional'):
1849 c = C(2)
1850 c = C(1, 2)
1851 self.assertIsInstance(c, C)
1852 self.assertIsInstance(c, Base1)
1853 self.assertIsInstance(c, Base2)
1854
1855 self.assertEqual((c.x, c.y), (1, 2))
1856
1857 def test_helper_make_dataclass_init_var(self):
1858 def post_init(self, y):
1859 self.x *= y
1860
1861 C = make_dataclass('C',
1862 [('x', int),
1863 ('y', InitVar[int]),
1864 ],
1865 namespace={'__post_init__': post_init},
1866 )
1867 c = C(2, 3)
1868 self.assertEqual(vars(c), {'x': 6})
1869 self.assertEqual(len(fields(c)), 1)
1870
1871 def test_helper_make_dataclass_class_var(self):
1872 C = make_dataclass('C',
1873 [('x', int),
1874 ('y', ClassVar[int], 10),
1875 ('z', ClassVar[int], field(default=20)),
1876 ])
1877 c = C(1)
1878 self.assertEqual(vars(c), {'x': 1})
1879 self.assertEqual(len(fields(c)), 1)
1880 self.assertEqual(C.y, 10)
1881 self.assertEqual(C.z, 20)
1882
Eric V. Smithd80b4432018-01-06 17:09:58 -05001883 def test_helper_make_dataclass_other_params(self):
1884 C = make_dataclass('C',
1885 [('x', int),
1886 ('y', ClassVar[int], 10),
1887 ('z', ClassVar[int], field(default=20)),
1888 ],
1889 init=False)
1890 # Make sure we have a repr, but no init.
1891 self.assertNotIn('__init__', vars(C))
1892 self.assertIn('__repr__', vars(C))
1893
1894 # Make sure random other params don't work.
1895 with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
1896 C = make_dataclass('C',
1897 [],
1898 xxinit=False)
1899
Eric V. Smithed7d4292018-01-06 16:14:03 -05001900 def test_helper_make_dataclass_no_types(self):
1901 C = make_dataclass('Point', ['x', 'y', 'z'])
1902 c = C(1, 2, 3)
1903 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1904 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1905 'y': 'typing.Any',
1906 'z': 'typing.Any'})
1907
1908 C = make_dataclass('Point', ['x', ('y', int), 'z'])
1909 c = C(1, 2, 3)
1910 self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
1911 self.assertEqual(C.__annotations__, {'x': 'typing.Any',
1912 'y': int,
1913 'z': 'typing.Any'})
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001914
Eric V. Smithea8fc522018-01-27 19:07:40 -05001915
class TestFieldNoAnnotation(unittest.TestCase):
    # A field() in a class body needs a matching annotation in that same
    # class; an annotation inherited from a base does not count.

    def _assert_rejected_as_unannotated(self, cls):
        # Decorating `cls` must fail because its `f` lacks an annotation.
        with self.assertRaisesRegex(TypeError,
                                    "'f' is a field but has no type annotation"):
            dataclass(cls)

    def test_field_without_annotation(self):
        class C:
            f = field()

        self._assert_rejected_as_unannotated(C)

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # The dataclass base's annotation must not be picked up for C.
        class C(B):
            f = field()

        self._assert_rejected_as_unannotated(C)

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same situation, with a base that is not a dataclass.
        class B:
            f: int

        class C(B):
            f = field()

        self._assert_rejected_as_unannotated(C)
1949
1950
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001951class TestDocString(unittest.TestCase):
1952 def assertDocStrEqual(self, a, b):
1953 # Because 3.6 and 3.7 differ in how inspect.signature work
1954 # (see bpo #32108), for the time being just compare them with
1955 # whitespace stripped.
1956 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1957
1958 def test_existing_docstring_not_overridden(self):
1959 @dataclass
1960 class C:
1961 """Lorem ipsum"""
1962 x: int
1963
1964 self.assertEqual(C.__doc__, "Lorem ipsum")
1965
1966 def test_docstring_no_fields(self):
1967 @dataclass
1968 class C:
1969 pass
1970
1971 self.assertDocStrEqual(C.__doc__, "C()")
1972
1973 def test_docstring_one_field(self):
1974 @dataclass
1975 class C:
1976 x: int
1977
1978 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1979
1980 def test_docstring_two_fields(self):
1981 @dataclass
1982 class C:
1983 x: int
1984 y: int
1985
1986 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1987
1988 def test_docstring_three_fields(self):
1989 @dataclass
1990 class C:
1991 x: int
1992 y: int
1993 z: str
1994
1995 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1996
1997 def test_docstring_one_field_with_default(self):
1998 @dataclass
1999 class C:
2000 x: int = 3
2001
2002 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
2003
2004 def test_docstring_one_field_with_default_none(self):
2005 @dataclass
2006 class C:
2007 x: Union[int, type(None)] = None
2008
2009 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
2010
2011 def test_docstring_list_field(self):
2012 @dataclass
2013 class C:
2014 x: List[int]
2015
2016 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
2017
2018 def test_docstring_list_field_with_default_factory(self):
2019 @dataclass
2020 class C:
2021 x: List[int] = field(default_factory=list)
2022
2023 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
2024
2025 def test_docstring_deque_field(self):
2026 @dataclass
2027 class C:
2028 x: deque
2029
2030 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
2031
2032 def test_docstring_deque_field_with_default_factory(self):
2033 @dataclass
2034 class C:
2035 x: deque = field(default_factory=deque)
2036
2037 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
2038
2039
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # A dataclass may add its own __init__ even though its base
        # defines one; the generated __init__ does not call the base's.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False no __init__ is generated, so the inherited
        # base __init__ runs instead.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # init=False adds no __init__ to the class; the field default is
        # still readable as a class attribute.  (The decorator must be
        # applied with '@': a bare `dataclass(init=False)` expression
        # statement would leave C an ordinary class.)
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', vars(C))
        self.assertEqual(C().i, 0)

        # A user-written __init__ is kept as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class defines __init__, it is used regardless of the
        # value of init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2104
2105
2106class TestRepr(unittest.TestCase):
2107 def test_repr(self):
2108 @dataclass
2109 class B:
2110 x: int
2111
2112 @dataclass
2113 class C(B):
2114 y: int = 10
2115
2116 o = C(4)
2117 self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')
2118
2119 @dataclass
2120 class D(C):
2121 x: int = 20
2122 self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')
2123
2124 @dataclass
2125 class C:
2126 @dataclass
2127 class D:
2128 i: int
2129 @dataclass
2130 class E:
2131 pass
2132 self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
2133 self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')
2134
2135 def test_no_repr(self):
2136 # Test a class with no __repr__ and repr=False.
2137 @dataclass(repr=False)
2138 class C:
2139 x: int
2140 self.assertIn('test_dataclasses.TestRepr.test_no_repr.<locals>.C object at',
2141 repr(C(3)))
2142
2143 # Test a class with a __repr__ and repr=False.
2144 @dataclass(repr=False)
2145 class C:
2146 x: int
2147 def __repr__(self):
2148 return 'C-class'
2149 self.assertEqual(repr(C(3)), 'C-class')
2150
2151 def test_overwriting_repr(self):
2152 # If the class has __repr__, use it no matter the value of
2153 # repr=.
2154
2155 @dataclass
2156 class C:
2157 x: int
2158 def __repr__(self):
2159 return 'x'
2160 self.assertEqual(repr(C(0)), 'x')
2161
2162 @dataclass(repr=True)
2163 class C:
2164 x: int
2165 def __repr__(self):
2166 return 'x'
2167 self.assertEqual(repr(C(0)), 'x')
2168
2169 @dataclass(repr=False)
2170 class C:
2171 x: int
2172 def __repr__(self):
2173 return 'x'
2174 self.assertEqual(repr(C(0)), 'x')
2175
2176
class TestFrozen(unittest.TestCase):
    def test_overwriting_frozen(self):
        # frozen=True installs its own __setattr__ and __delattr__, so a
        # class that already defines either one cannot be frozen.
        class DefinesSetattr:
            x: int
            def __setattr__(self):
                pass

        class DefinesDelattr:
            x: int
            def __delattr__(self):
                pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            dataclass(frozen=True)(DefinesSetattr)
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            dataclass(frozen=True)(DefinesDelattr)

        # Without frozen=True a user __setattr__ stays in place and is
        # what the generated __init__ goes through.
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)
2202
2203
class TestEq(unittest.TestCase):
    """Tests for eq= and its interaction with a user-defined __eq__."""

    def test_no_eq(self):
        # eq=False and no __eq__ of our own: equal field values do not make
        # two instances equal, but an instance still equals itself.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with our own __eq__: that method is the one used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A class-defined __eq__ wins whatever eq= says (omitted, True or
        # False).  Each variant compares equal only to its own marker value.
        variants = (
            (dataclass, 3),
            (dataclass(eq=True), 4),
            (dataclass(eq=False), 5),
        )
        for decorate, marker in variants:
            class C:
                x: int
                def __eq__(self, other):
                    return other == marker
            C = decorate(C)
            self.assertEqual(C(1), marker)
            self.assertNotEqual(C(1), 1)
2249
2250
class TestOrdering(unittest.TestCase):
    """Tests for order= and its interaction with user-defined comparisons."""

    def test_functools_total_ordering(self):
        # Test that functools.total_ordering works with this class.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Perform the test "backward", just to make
                # sure this is being called.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # Test that no ordering functions are added by default.
        @dataclass(order=False)
        class C:
            x: int
        # Make sure no order methods are added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__lt__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

        # Test that a user-defined __lt__ is kept and is what `<` calls.  It
        # answers False even where comparing the fields would give True.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        self.assertIn('__lt__', C.__dict__)
        self.assertIs(C(0) < C(1), False)
        # Make sure other methods aren't added.
        self.assertNotIn('__le__', C.__dict__)
        self.assertNotIn('__ge__', C.__dict__)
        self.assertNotIn('__gt__', C.__dict__)

    def test_overwriting_order(self):
        # order=True refuses to replace any user-defined comparison method,
        # and the error message suggests functools.total_ordering instead.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2326
class TestHash(unittest.TestCase):
    """Tests for when dataclass() generates, clears, or leaves __hash__."""

    def test_unsafe_hash(self):
        # unsafe_hash=True hashes the tuple of field values.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # dataclass() must not add or replace __hash__.
                    if with_hash:
                        # The user-defined __hash__ (returns 0) is kept.
                        self.assertEqual(hash(C()), 0)
                    else:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.  Use a test
                    # assertion rather than `assert`, which -O strips.
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # self.fail (not `assert False`) so a bad table entry is
                    # reported even when running with -O.
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                    (False,       False, False,  '',                  ''),
                    (False,       False, True,   '',                  ''),
                    (False,       True,  False,  'none',              ''),
                    (False,       True,  True,   'fn',                ''),
                    (True,        False, False,  'fn',                'exception'),
                    (True,        False, True,   'fn',                'exception'),
                    (True,        True,  False,  'fn',                'exception'),
                    (True,        True,  True,   'fn',                'exception'),
                    ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True,  res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True,  res_defined_hash)


    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo=32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the classes __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's default hash is identity-based, so its value
                    # can't be predicted here and calling hash() won't tell
                    # us much.  Just check the function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # self.fail (not `assert False`) so a bad table entry is
                    # reported even when running with -O.
                    self.fail(f'unknown value for expected={expected!r}')
2545
Eric V. Smithea8fc522018-01-27 19:07:40 -05002546
class TestFrozen(unittest.TestCase):
    """frozen=True instances and how frozenness interacts with inheritance."""

    @staticmethod
    def _maybe_wrapped(base, intermediate_class):
        """Return *base*, or a plain (non-dataclass) subclass of it.

        The inheritance tests run both ways: deriving directly from *base*,
        and deriving through an intermediate normal class.
        """
        if intermediate_class:
            class I(base):
                pass
            return I
        return base

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        # Assigning to a field raises and leaves the stored value alone.
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        # Inherited and newly added fields are equally read-only.
        for attr, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(d, attr, value)
        self.assertEqual((d.i, d.j), (0, 10))

    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._maybe_wrapped(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = self._maybe_wrapped(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                @dataclass(frozen=True)
                class D(self._maybe_wrapped(C, intermediate_class)):
                    i: int

                d = D(10)
                with self.assertRaises(FrozenInstanceError):
                    d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual((s.x, s.y), (3, 10))
        # A plain subclass may add new attributes...
        s.cached = True

        # ...but the inherited frozen fields still can't be changed.
        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(s, attr, 5)
        self.assertEqual((s.x, s.y, s.cached), (3, 10, True))
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002657
2658
class TestSlots(unittest.TestCase):
    """Dataclasses whose classes declare __slots__ themselves."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: the slot's member descriptor on the class must
        # not be mistaken for a field default, so x stays required.
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # x can be set at construction time and reassigned afterwards.
        inst = C(10)
        self.assertEqual(inst.x, 10)
        inst.x = 5
        self.assertEqual(inst.x, 5)

        # Names outside __slots__ are rejected.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            inst.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__, so its instances accept attributes
        # that are not fields.
        derived.z = 10
2700
class TestDescriptors(unittest.TestCase):
    """__set_name__ and descriptor defaults on dataclass fields."""

    def test_set_name(self):
        # See bpo-33141.

        # Create a descriptor.
        class D:
            def __set_name__(self, owner, name):
                self.name = name
            def __get__(self, instance, owner):
                if instance is not None:
                    return 1
                return self

        # This is the case of just normal descriptor behavior, no
        # dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'c')

        # Now test with a default value and init=False, which is the
        # only time this is really meaningful.  If not using
        # init=False, then the descriptor will be overwritten, anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'c')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # D defines __set_name__ but none of __get__, __set__ or
        # __delete__, so it is *not* a descriptor.
        class D:
            def __set_name__(self, owner, name):
                self.name = name

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'c')
        # Not being a descriptor, the default is returned unchanged when
        # read through an instance.
        self.assertIs(C().c, C.c)
2742
Miss Islington (bot)3d41f482018-03-19 18:31:22 -07002743
if __name__ == '__main__':
    # Discover and run every TestCase in this module (module='__main__' is
    # unittest.main's default, spelled out here).
    unittest.main(module='__main__')