blob: 7c39b79142b29498195d6c2fed4b88f76f2de853 [file] [log] [blame]
Miss Islington (bot)4ddc99d2018-03-21 14:44:23 -07001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
9import unittest
10from unittest.mock import Mock
Miss Islington (bot)d063ad82018-04-01 04:33:13 -070011from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050012from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050013from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050014
Miss Islington (bot)c73268a2018-05-15 21:22:13 -070015import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
16import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
17
Eric V. Smithf0db54a2017-12-04 16:58:55 -050018# Just any custom exception we can catch.
19class CustomError(Exception): pass
20
21class TestCase(unittest.TestCase):
22 def test_no_fields(self):
23 @dataclass
24 class C:
25 pass
26
27 o = C()
28 self.assertEqual(len(fields(C)), 0)
29
Miss Islington (bot)3b4c6b12018-03-22 13:58:59 -070030 def test_no_fields_but_member_variable(self):
31 @dataclass
32 class C:
33 i = 0
34
35 o = C()
36 self.assertEqual(len(fields(C)), 0)
37
Eric V. Smithf0db54a2017-12-04 16:58:55 -050038 def test_one_field_no_default(self):
39 @dataclass
40 class C:
41 x: int
42
43 o = C(42)
44 self.assertEqual(o.x, 42)
45
46 def test_named_init_params(self):
47 @dataclass
48 class C:
49 x: int
50
51 o = C(x=32)
52 self.assertEqual(o.x, 32)
53
54 def test_two_fields_one_default(self):
55 @dataclass
56 class C:
57 x: int
58 y: int = 0
59
60 o = C(3)
61 self.assertEqual((o.x, o.y), (3, 0))
62
63 # Non-defaults following defaults.
64 with self.assertRaisesRegex(TypeError,
65 "non-default argument 'y' follows "
66 "default argument"):
67 @dataclass
68 class C:
69 x: int = 0
70 y: int
71
72 # A derived class adds a non-default field after a default one.
73 with self.assertRaisesRegex(TypeError,
74 "non-default argument 'y' follows "
75 "default argument"):
76 @dataclass
77 class B:
78 x: int = 0
79
80 @dataclass
81 class C(B):
82 y: int
83
84 # Override a base class field and add a default to
85 # a field which didn't use to have a default.
86 with self.assertRaisesRegex(TypeError,
87 "non-default argument 'y' follows "
88 "default argument"):
89 @dataclass
90 class B:
91 x: int
92 y: int
93
94 @dataclass
95 class C(B):
96 x: int = 0
97
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080098 def test_overwrite_hash(self):
99 # Test that declaring this class isn't an error. It should
100 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500101 @dataclass(frozen=True)
102 class C:
103 x: int
104 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800105 return 301
106 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500107
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800108 # Test that declaring this class isn't an error. It should
109 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500110 @dataclass(frozen=True)
111 class C:
112 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800113 def __eq__(self, other):
114 return False
115 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500116
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800117 # But this one should generate an exception, because with
118 # unsafe_hash=True, it's an error to have a __hash__ defined.
119 with self.assertRaisesRegex(TypeError,
120 'Cannot overwrite attribute __hash__'):
121 @dataclass(unsafe_hash=True)
122 class C:
123 def __hash__(self):
124 pass
125
126 # Creating this class should not generate an exception,
127 # because even though __hash__ exists before @dataclass is
128 # called, (due to __eq__ being defined), since it's None
129 # that's okay.
130 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500131 class C:
132 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800133 def __eq__(self):
134 pass
135 # The generated hash function works as we'd expect.
136 self.assertEqual(hash(C(10)), hash((10,)))
137
138 # Creating this class should generate an exception, because
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700139 # __hash__ exists and is not None, which it would be if it
140 # had been auto-generated due to __eq__ being defined.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800141 with self.assertRaisesRegex(TypeError,
142 'Cannot overwrite attribute __hash__'):
143 @dataclass(unsafe_hash=True)
144 class C:
145 x: int
146 def __eq__(self):
147 pass
148 def __hash__(self):
149 pass
150
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500151 def test_overwrite_fields_in_derived_class(self):
152 # Note that x from C1 replaces x in Base, but the order remains
153 # the same as defined in Base.
154 @dataclass
155 class Base:
156 x: Any = 15.0
157 y: int = 0
158
159 @dataclass
160 class C1(Base):
161 z: int = 10
162 x: int = 15
163
164 o = Base()
165 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
166
167 o = C1()
168 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
169
170 o = C1(x=5)
171 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
172
173 def test_field_named_self(self):
174 @dataclass
175 class C:
176 self: str
177 c=C('foo')
178 self.assertEqual(c.self, 'foo')
179
180 # Make sure the first parameter is not named 'self'.
181 sig = inspect.signature(C.__init__)
182 first = next(iter(sig.parameters))
183 self.assertNotEqual('self', first)
184
185 # But we do use 'self' if no field named self.
186 @dataclass
187 class C:
188 selfx: str
189
190 # Make sure the first parameter is named 'self'.
191 sig = inspect.signature(C.__init__)
192 first = next(iter(sig.parameters))
193 self.assertEqual('self', first)
194
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500195 def test_0_field_compare(self):
196 # Ensure that order=False is the default.
197 @dataclass
198 class C0:
199 pass
200
201 @dataclass(order=False)
202 class C1:
203 pass
204
205 for cls in [C0, C1]:
206 with self.subTest(cls=cls):
207 self.assertEqual(cls(), cls())
208 for idx, fn in enumerate([lambda a, b: a < b,
209 lambda a, b: a <= b,
210 lambda a, b: a > b,
211 lambda a, b: a >= b]):
212 with self.subTest(idx=idx):
213 with self.assertRaisesRegex(TypeError,
214 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
215 fn(cls(), cls())
216
217 @dataclass(order=True)
218 class C:
219 pass
220 self.assertLessEqual(C(), C())
221 self.assertGreaterEqual(C(), C())
222
223 def test_1_field_compare(self):
224 # Ensure that order=False is the default.
225 @dataclass
226 class C0:
227 x: int
228
229 @dataclass(order=False)
230 class C1:
231 x: int
232
233 for cls in [C0, C1]:
234 with self.subTest(cls=cls):
235 self.assertEqual(cls(1), cls(1))
236 self.assertNotEqual(cls(0), cls(1))
237 for idx, fn in enumerate([lambda a, b: a < b,
238 lambda a, b: a <= b,
239 lambda a, b: a > b,
240 lambda a, b: a >= b]):
241 with self.subTest(idx=idx):
242 with self.assertRaisesRegex(TypeError,
243 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
244 fn(cls(0), cls(0))
245
246 @dataclass(order=True)
247 class C:
248 x: int
249 self.assertLess(C(0), C(1))
250 self.assertLessEqual(C(0), C(1))
251 self.assertLessEqual(C(1), C(1))
252 self.assertGreater(C(1), C(0))
253 self.assertGreaterEqual(C(1), C(0))
254 self.assertGreaterEqual(C(1), C(1))
255
256 def test_simple_compare(self):
257 # Ensure that order=False is the default.
258 @dataclass
259 class C0:
260 x: int
261 y: int
262
263 @dataclass(order=False)
264 class C1:
265 x: int
266 y: int
267
268 for cls in [C0, C1]:
269 with self.subTest(cls=cls):
270 self.assertEqual(cls(0, 0), cls(0, 0))
271 self.assertEqual(cls(1, 2), cls(1, 2))
272 self.assertNotEqual(cls(1, 0), cls(0, 0))
273 self.assertNotEqual(cls(1, 0), cls(1, 1))
274 for idx, fn in enumerate([lambda a, b: a < b,
275 lambda a, b: a <= b,
276 lambda a, b: a > b,
277 lambda a, b: a >= b]):
278 with self.subTest(idx=idx):
279 with self.assertRaisesRegex(TypeError,
280 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
281 fn(cls(0, 0), cls(0, 0))
282
283 @dataclass(order=True)
284 class C:
285 x: int
286 y: int
287
288 for idx, fn in enumerate([lambda a, b: a == b,
289 lambda a, b: a <= b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 self.assertTrue(fn(C(0, 0), C(0, 0)))
293
294 for idx, fn in enumerate([lambda a, b: a < b,
295 lambda a, b: a <= b,
296 lambda a, b: a != b]):
297 with self.subTest(idx=idx):
298 self.assertTrue(fn(C(0, 0), C(0, 1)))
299 self.assertTrue(fn(C(0, 1), C(1, 0)))
300 self.assertTrue(fn(C(1, 0), C(1, 1)))
301
302 for idx, fn in enumerate([lambda a, b: a > b,
303 lambda a, b: a >= b,
304 lambda a, b: a != b]):
305 with self.subTest(idx=idx):
306 self.assertTrue(fn(C(0, 1), C(0, 0)))
307 self.assertTrue(fn(C(1, 0), C(0, 1)))
308 self.assertTrue(fn(C(1, 1), C(1, 0)))
309
310 def test_compare_subclasses(self):
311 # Comparisons fail for subclasses, even if no fields
312 # are added.
313 @dataclass
314 class B:
315 i: int
316
317 @dataclass
318 class C(B):
319 pass
320
321 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
322 (lambda a, b: a != b, True)]):
323 with self.subTest(idx=idx):
324 self.assertEqual(fn(B(0), C(0)), expected)
325
326 for idx, fn in enumerate([lambda a, b: a < b,
327 lambda a, b: a <= b,
328 lambda a, b: a > b,
329 lambda a, b: a >= b]):
330 with self.subTest(idx=idx):
331 with self.assertRaisesRegex(TypeError,
332 "not supported between instances of 'B' and 'C'"):
333 fn(B(0), C(0))
334
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500335 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500336 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500337 for (eq, order, result ) in [
338 (False, False, 'neither'),
339 (False, True, 'exception'),
340 (True, False, 'eq_only'),
341 (True, True, 'both'),
342 ]:
343 with self.subTest(eq=eq, order=order):
344 if result == 'exception':
345 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
346 @dataclass(eq=eq, order=order)
347 class C:
348 pass
349 else:
350 @dataclass(eq=eq, order=order)
351 class C:
352 pass
353
354 if result == 'neither':
355 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500356 self.assertNotIn('__lt__', C.__dict__)
357 self.assertNotIn('__le__', C.__dict__)
358 self.assertNotIn('__gt__', C.__dict__)
359 self.assertNotIn('__ge__', C.__dict__)
360 elif result == 'both':
361 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500362 self.assertIn('__lt__', C.__dict__)
363 self.assertIn('__le__', C.__dict__)
364 self.assertIn('__gt__', C.__dict__)
365 self.assertIn('__ge__', C.__dict__)
366 elif result == 'eq_only':
367 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500368 self.assertNotIn('__lt__', C.__dict__)
369 self.assertNotIn('__le__', C.__dict__)
370 self.assertNotIn('__gt__', C.__dict__)
371 self.assertNotIn('__ge__', C.__dict__)
372 else:
373 assert False, f'unknown result {result!r}'
374
375 def test_field_no_default(self):
376 @dataclass
377 class C:
378 x: int = field()
379
380 self.assertEqual(C(5).x, 5)
381
382 with self.assertRaisesRegex(TypeError,
383 r"__init__\(\) missing 1 required "
384 "positional argument: 'x'"):
385 C()
386
387 def test_field_default(self):
388 default = object()
389 @dataclass
390 class C:
391 x: object = field(default=default)
392
393 self.assertIs(C.x, default)
394 c = C(10)
395 self.assertEqual(c.x, 10)
396
397 # If we delete the instance attribute, we should then see the
398 # class attribute.
399 del c.x
400 self.assertIs(c.x, default)
401
402 self.assertIs(C().x, default)
403
404 def test_not_in_repr(self):
405 @dataclass
406 class C:
407 x: int = field(repr=False)
408 with self.assertRaises(TypeError):
409 C()
410 c = C(10)
411 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
412
413 @dataclass
414 class C:
415 x: int = field(repr=False)
416 y: int
417 c = C(10, 20)
418 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
419
420 def test_not_in_compare(self):
421 @dataclass
422 class C:
423 x: int = 0
424 y: int = field(compare=False, default=4)
425
426 self.assertEqual(C(), C(0, 20))
427 self.assertEqual(C(1, 10), C(1, 20))
428 self.assertNotEqual(C(3), C(4, 10))
429 self.assertNotEqual(C(3, 10), C(4, 10))
430
431 def test_hash_field_rules(self):
432 # Test all 6 cases of:
433 # hash=True/False/None
434 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800435 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500436 (True, False, 'field' ),
437 (True, True, 'field' ),
438 (False, False, 'absent'),
439 (False, True, 'absent'),
440 (None, False, 'absent'),
441 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800442 ]:
443 with self.subTest(hash=hash_, compare=compare):
444 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500445 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800446 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500447
448 if result == 'field':
449 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800450 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500451 elif result == 'absent':
452 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800453 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500454 else:
455 assert False, f'unknown result {result!r}'
456
457 def test_init_false_no_default(self):
458 # If init=False and no default value, then the field won't be
459 # present in the instance.
460 @dataclass
461 class C:
462 x: int = field(init=False)
463
464 self.assertNotIn('x', C().__dict__)
465
466 @dataclass
467 class C:
468 x: int
469 y: int = 0
470 z: int = field(init=False)
471 t: int = 10
472
473 self.assertNotIn('z', C(0).__dict__)
474 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
475
476 def test_class_marker(self):
477 @dataclass
478 class C:
479 x: int
480 y: str = field(init=False, default=None)
481 z: str = field(repr=False)
482
483 the_fields = fields(C)
484 # the_fields is a tuple of 3 items, each value
485 # is in __annotations__.
486 self.assertIsInstance(the_fields, tuple)
487 for f in the_fields:
488 self.assertIs(type(f), Field)
489 self.assertIn(f.name, C.__annotations__)
490
491 self.assertEqual(len(the_fields), 3)
492
493 self.assertEqual(the_fields[0].name, 'x')
494 self.assertEqual(the_fields[0].type, int)
495 self.assertFalse(hasattr(C, 'x'))
496 self.assertTrue (the_fields[0].init)
497 self.assertTrue (the_fields[0].repr)
498 self.assertEqual(the_fields[1].name, 'y')
499 self.assertEqual(the_fields[1].type, str)
500 self.assertIsNone(getattr(C, 'y'))
501 self.assertFalse(the_fields[1].init)
502 self.assertTrue (the_fields[1].repr)
503 self.assertEqual(the_fields[2].name, 'z')
504 self.assertEqual(the_fields[2].type, str)
505 self.assertFalse(hasattr(C, 'z'))
506 self.assertTrue (the_fields[2].init)
507 self.assertFalse(the_fields[2].repr)
508
509 def test_field_order(self):
510 @dataclass
511 class B:
512 a: str = 'B:a'
513 b: str = 'B:b'
514 c: str = 'B:c'
515
516 @dataclass
517 class C(B):
518 b: str = 'C:b'
519
520 self.assertEqual([(f.name, f.default) for f in fields(C)],
521 [('a', 'B:a'),
522 ('b', 'C:b'),
523 ('c', 'B:c')])
524
525 @dataclass
526 class D(B):
527 c: str = 'D:c'
528
529 self.assertEqual([(f.name, f.default) for f in fields(D)],
530 [('a', 'B:a'),
531 ('b', 'B:b'),
532 ('c', 'D:c')])
533
534 @dataclass
535 class E(D):
536 a: str = 'E:a'
537 d: str = 'E:d'
538
539 self.assertEqual([(f.name, f.default) for f in fields(E)],
540 [('a', 'E:a'),
541 ('b', 'B:b'),
542 ('c', 'D:c'),
543 ('d', 'E:d')])
544
545 def test_class_attrs(self):
546 # We only have a class attribute if a default value is
547 # specified, either directly or via a field with a default.
548 default = object()
549 @dataclass
550 class C:
551 x: int
552 y: int = field(repr=False)
553 z: object = default
554 t: int = field(default=100)
555
556 self.assertFalse(hasattr(C, 'x'))
557 self.assertFalse(hasattr(C, 'y'))
558 self.assertIs (C.z, default)
559 self.assertEqual(C.t, 100)
560
561 def test_disallowed_mutable_defaults(self):
562 # For the known types, don't allow mutable default values.
563 for typ, empty, non_empty in [(list, [], [1]),
564 (dict, {}, {0:1}),
565 (set, set(), set([1])),
566 ]:
567 with self.subTest(typ=typ):
568 # Can't use a zero-length value.
569 with self.assertRaisesRegex(ValueError,
570 f'mutable default {typ} for field '
571 'x is not allowed'):
572 @dataclass
573 class Point:
574 x: typ = empty
575
576
577 # Nor a non-zero-length value
578 with self.assertRaisesRegex(ValueError,
579 f'mutable default {typ} for field '
580 'y is not allowed'):
581 @dataclass
582 class Point:
583 y: typ = non_empty
584
585 # Check subtypes also fail.
586 class Subclass(typ): pass
587
588 with self.assertRaisesRegex(ValueError,
589 f"mutable default .*Subclass'>"
590 ' for field z is not allowed'
591 ):
592 @dataclass
593 class Point:
594 z: typ = Subclass()
595
596 # Because this is a ClassVar, it can be mutable.
597 @dataclass
598 class C:
599 z: ClassVar[typ] = typ()
600
601 # Because this is a ClassVar, it can be mutable.
602 @dataclass
603 class C:
604 x: ClassVar[typ] = Subclass()
605
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500606 def test_deliberately_mutable_defaults(self):
607 # If a mutable default isn't in the known list of
608 # (list, dict, set), then it's okay.
609 class Mutable:
610 def __init__(self):
611 self.l = []
612
613 @dataclass
614 class C:
615 x: Mutable
616
617 # These 2 instances will share this value of x.
618 lst = Mutable()
619 o1 = C(lst)
620 o2 = C(lst)
621 self.assertEqual(o1, o2)
622 o1.x.l.extend([1, 2])
623 self.assertEqual(o1, o2)
624 self.assertEqual(o1.x.l, [1, 2])
625 self.assertIs(o1.x, o2.x)
626
627 def test_no_options(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700628 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500629 @dataclass()
630 class C:
631 x: int
632
633 self.assertEqual(C(42).x, 42)
634
635 def test_not_tuple(self):
636 # Make sure we can't be compared to a tuple.
637 @dataclass
638 class Point:
639 x: int
640 y: int
641 self.assertNotEqual(Point(1, 2), (1, 2))
642
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700643 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500644 @dataclass
645 class C:
646 x: int
647 y: int
648 self.assertNotEqual(Point(1, 3), C(1, 3))
649
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500650 def test_not_tuple(self):
651 # Test that some of the problems with namedtuple don't happen
652 # here.
653 @dataclass
654 class Point3D:
655 x: int
656 y: int
657 z: int
658
659 @dataclass
660 class Date:
661 year: int
662 month: int
663 day: int
664
665 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
666 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
667
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700668 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200669 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500670 x, y, z = Point3D(4, 5, 6)
671
Eric V. Smith7c99e932018-01-28 19:18:55 -0500672 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500673 # equal.
674 @dataclass
675 class Point3Dv1:
676 x: int = 0
677 y: int = 0
678 z: int = 0
679 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
680
681 def test_function_annotations(self):
682 # Some dummy class and instance to use as a default.
683 class F:
684 pass
685 f = F()
686
687 def validate_class(cls):
688 # First, check __annotations__, even though they're not
689 # function annotations.
690 self.assertEqual(cls.__annotations__['i'], int)
691 self.assertEqual(cls.__annotations__['j'], str)
692 self.assertEqual(cls.__annotations__['k'], F)
693 self.assertEqual(cls.__annotations__['l'], float)
694 self.assertEqual(cls.__annotations__['z'], complex)
695
696 # Verify __init__.
697
698 signature = inspect.signature(cls.__init__)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700699 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500700 self.assertIs(signature.return_annotation, None)
701
702 # Check each parameter.
703 params = iter(signature.parameters.values())
704 param = next(params)
705 # This is testing an internal name, and probably shouldn't be tested.
706 self.assertEqual(param.name, 'self')
707 param = next(params)
708 self.assertEqual(param.name, 'i')
709 self.assertIs (param.annotation, int)
710 self.assertEqual(param.default, inspect.Parameter.empty)
711 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
712 param = next(params)
713 self.assertEqual(param.name, 'j')
714 self.assertIs (param.annotation, str)
715 self.assertEqual(param.default, inspect.Parameter.empty)
716 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
717 param = next(params)
718 self.assertEqual(param.name, 'k')
719 self.assertIs (param.annotation, F)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700720 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500721 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
722 param = next(params)
723 self.assertEqual(param.name, 'l')
724 self.assertIs (param.annotation, float)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700725 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500726 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
727 self.assertRaises(StopIteration, next, params)
728
729
730 @dataclass
731 class C:
732 i: int
733 j: str
734 k: F = f
735 l: float=field(default=None)
736 z: complex=field(default=3+4j, init=False)
737
738 validate_class(C)
739
740 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800741 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500742 class C:
743 i: int
744 j: str
745 k: F = f
746 l: float=field(default=None)
747 z: complex=field(default=3+4j, init=False)
748
749 validate_class(C)
750
Eric V. Smith03220fd2017-12-29 13:59:58 -0500751 def test_missing_default(self):
752 # Test that MISSING works the same as a default not being
753 # specified.
754 @dataclass
755 class C:
756 x: int=field(default=MISSING)
757 with self.assertRaisesRegex(TypeError,
758 r'__init__\(\) missing 1 required '
759 'positional argument'):
760 C()
761 self.assertNotIn('x', C.__dict__)
762
763 @dataclass
764 class D:
765 x: int
766 with self.assertRaisesRegex(TypeError,
767 r'__init__\(\) missing 1 required '
768 'positional argument'):
769 D()
770 self.assertNotIn('x', D.__dict__)
771
772 def test_missing_default_factory(self):
773 # Test that MISSING works the same as a default factory not
774 # being specified (which is really the same as a default not
775 # being specified, too).
776 @dataclass
777 class C:
778 x: int=field(default_factory=MISSING)
779 with self.assertRaisesRegex(TypeError,
780 r'__init__\(\) missing 1 required '
781 'positional argument'):
782 C()
783 self.assertNotIn('x', C.__dict__)
784
785 @dataclass
786 class D:
787 x: int=field(default=MISSING, default_factory=MISSING)
788 with self.assertRaisesRegex(TypeError,
789 r'__init__\(\) missing 1 required '
790 'positional argument'):
791 D()
792 self.assertNotIn('x', D.__dict__)
793
794 def test_missing_repr(self):
795 self.assertIn('MISSING_TYPE object', repr(MISSING))
796
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500797 def test_dont_include_other_annotations(self):
798 @dataclass
799 class C:
800 i: int
801 def foo(self) -> int:
802 return 4
803 @property
804 def bar(self) -> int:
805 return 5
806 self.assertEqual(list(C.__annotations__), ['i'])
807 self.assertEqual(C(10).foo(), 4)
808 self.assertEqual(C(10).bar, 5)
Miss Islington (bot)5666a552018-03-25 06:27:50 -0700809 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500810
811 def test_post_init(self):
812 # Just make sure it gets called
813 @dataclass
814 class C:
815 def __post_init__(self):
816 raise CustomError()
817 with self.assertRaises(CustomError):
818 C()
819
820 @dataclass
821 class C:
822 i: int = 10
823 def __post_init__(self):
824 if self.i == 10:
825 raise CustomError()
826 with self.assertRaises(CustomError):
827 C()
828 # post-init gets called, but doesn't raise. This is just
829 # checking that self is used correctly.
830 C(5)
831
832 # If there's not an __init__, then post-init won't get called.
833 @dataclass(init=False)
834 class C:
835 def __post_init__(self):
836 raise CustomError()
837 # Creating the class won't raise
838 C()
839
840 @dataclass
841 class C:
842 x: int = 0
843 def __post_init__(self):
844 self.x *= 2
845 self.assertEqual(C().x, 0)
846 self.assertEqual(C(2).x, 4)
847
Mike53f7a7c2017-12-14 14:04:53 +0300848 # Make sure that if we're frozen, post-init can't set
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500849 # attributes.
850 @dataclass(frozen=True)
851 class C:
852 x: int = 0
853 def __post_init__(self):
854 self.x *= 2
855 with self.assertRaises(FrozenInstanceError):
856 C()
857
858 def test_post_init_super(self):
859 # Make sure super() post-init isn't called by default.
860 class B:
861 def __post_init__(self):
862 raise CustomError()
863
864 @dataclass
865 class C(B):
866 def __post_init__(self):
867 self.x = 5
868
869 self.assertEqual(C().x, 5)
870
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700871 # Now call super(), and it will raise.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500872 @dataclass
873 class C(B):
874 def __post_init__(self):
875 super().__post_init__()
876
877 with self.assertRaises(CustomError):
878 C()
879
880 # Make sure post-init is called, even if not defined in our
881 # class.
882 @dataclass
883 class C(B):
884 pass
885
886 with self.assertRaises(CustomError):
887 C()
888
889 def test_post_init_staticmethod(self):
890 flag = False
891 @dataclass
892 class C:
893 x: int
894 y: int
895 @staticmethod
896 def __post_init__():
897 nonlocal flag
898 flag = True
899
900 self.assertFalse(flag)
901 c = C(3, 4)
902 self.assertEqual((c.x, c.y), (3, 4))
903 self.assertTrue(flag)
904
905 def test_post_init_classmethod(self):
906 @dataclass
907 class C:
908 flag = False
909 x: int
910 y: int
911 @classmethod
912 def __post_init__(cls):
913 cls.flag = True
914
915 self.assertFalse(C.flag)
916 c = C(3, 4)
917 self.assertEqual((c.x, c.y), (3, 4))
918 self.assertTrue(C.flag)
919
920 def test_class_var(self):
921 # Make sure ClassVars are ignored in __init__, __repr__, etc.
922 @dataclass
923 class C:
924 x: int
925 y: int = 10
926 z: ClassVar[int] = 1000
927 w: ClassVar[int] = 2000
928 t: ClassVar[int] = 3000
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700929 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500930
931 c = C(5)
932 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700933 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700934 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500935 self.assertEqual(c.z, 1000)
936 self.assertEqual(c.w, 2000)
937 self.assertEqual(c.t, 3000)
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700938 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500939 C.z += 1
940 self.assertEqual(c.z, 1001)
941 c = C(20)
942 self.assertEqual((c.x, c.y), (20, 10))
943 self.assertEqual(c.z, 1001)
944 self.assertEqual(c.w, 2000)
945 self.assertEqual(c.t, 3000)
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700946 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500947
948 def test_class_var_no_default(self):
949 # If a ClassVar has no default value, it should not be set on the class.
950 @dataclass
951 class C:
952 x: ClassVar[int]
953
954 self.assertNotIn('x', C.__dict__)
955
956 def test_class_var_default_factory(self):
957 # It makes no sense for a ClassVar to have a default factory. When
958 # would it be called? Call it yourself, since it's class-wide.
959 with self.assertRaisesRegex(TypeError,
960 'cannot have a default factory'):
961 @dataclass
962 class C:
963 x: ClassVar[int] = field(default_factory=int)
964
965 self.assertNotIn('x', C.__dict__)
966
967 def test_class_var_with_default(self):
968 # If a ClassVar has a default value, it should be set on the class.
969 @dataclass
970 class C:
971 x: ClassVar[int] = 10
972 self.assertEqual(C.x, 10)
973
974 @dataclass
975 class C:
976 x: ClassVar[int] = field(default=10)
977 self.assertEqual(C.x, 10)
978
979 def test_class_var_frozen(self):
980 # Make sure ClassVars work even if we're frozen.
981 @dataclass(frozen=True)
982 class C:
983 x: int
984 y: int = 10
985 z: ClassVar[int] = 1000
986 w: ClassVar[int] = 2000
987 t: ClassVar[int] = 3000
988
989 c = C(5)
990 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
991 self.assertEqual(len(fields(C)), 2) # We have 2 fields
992 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
993 self.assertEqual(c.z, 1000)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
996 # We can still modify the ClassVar, it's only instances that are
997 # frozen.
998 C.z += 1
999 self.assertEqual(c.z, 1001)
1000 c = C(20)
1001 self.assertEqual((c.x, c.y), (20, 10))
1002 self.assertEqual(c.z, 1001)
1003 self.assertEqual(c.w, 2000)
1004 self.assertEqual(c.t, 3000)
1005
1006 def test_init_var_no_default(self):
1007 # If an InitVar has no default value, it should not be set on the class.
1008 @dataclass
1009 class C:
1010 x: InitVar[int]
1011
1012 self.assertNotIn('x', C.__dict__)
1013
1014 def test_init_var_default_factory(self):
1015 # It makes no sense for an InitVar to have a default factory. When
1016 # would it be called? Call it yourself, since it's class-wide.
1017 with self.assertRaisesRegex(TypeError,
1018 'cannot have a default factory'):
1019 @dataclass
1020 class C:
1021 x: InitVar[int] = field(default_factory=int)
1022
1023 self.assertNotIn('x', C.__dict__)
1024
1025 def test_init_var_with_default(self):
1026 # If an InitVar has a default value, it should be set on the class.
1027 @dataclass
1028 class C:
1029 x: InitVar[int] = 10
1030 self.assertEqual(C.x, 10)
1031
1032 @dataclass
1033 class C:
1034 x: InitVar[int] = field(default=10)
1035 self.assertEqual(C.x, 10)
1036
1037 def test_init_var(self):
1038 @dataclass
1039 class C:
1040 x: int = None
1041 init_param: InitVar[int] = None
1042
1043 def __post_init__(self, init_param):
1044 if self.x is None:
1045 self.x = init_param*2
1046
1047 c = C(init_param=10)
1048 self.assertEqual(c.x, 20)
1049
1050 def test_init_var_inheritance(self):
1051 # Note that this deliberately tests that a dataclass need not
1052 # have a __post_init__ function if it has an InitVar field.
1053 # It could just be used in a derived class, as shown here.
1054 @dataclass
1055 class Base:
1056 x: int
1057 init_base: InitVar[int]
1058
1059 # We can instantiate by passing the InitVar, even though
1060 # it's not used.
1061 b = Base(0, 10)
1062 self.assertEqual(vars(b), {'x': 0})
1063
1064 @dataclass
1065 class C(Base):
1066 y: int
1067 init_derived: InitVar[int]
1068
1069 def __post_init__(self, init_base, init_derived):
1070 self.x = self.x + init_base
1071 self.y = self.y + init_derived
1072
1073 c = C(10, 11, 50, 51)
1074 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1075
1076 def test_default_factory(self):
1077 # Test a factory that returns a new list.
1078 @dataclass
1079 class C:
1080 x: int
1081 y: list = field(default_factory=list)
1082
1083 c0 = C(3)
1084 c1 = C(3)
1085 self.assertEqual(c0.x, 3)
1086 self.assertEqual(c0.y, [])
1087 self.assertEqual(c0, c1)
1088 self.assertIsNot(c0.y, c1.y)
1089 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1090
1091 # Test a factory that returns a shared list.
1092 l = []
1093 @dataclass
1094 class C:
1095 x: int
1096 y: list = field(default_factory=lambda: l)
1097
1098 c0 = C(3)
1099 c1 = C(3)
1100 self.assertEqual(c0.x, 3)
1101 self.assertEqual(c0.y, [])
1102 self.assertEqual(c0, c1)
1103 self.assertIs(c0.y, c1.y)
1104 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1105
1106 # Test various other field flags.
1107 # repr
1108 @dataclass
1109 class C:
1110 x: list = field(default_factory=list, repr=False)
1111 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1112 self.assertEqual(C().x, [])
1113
1114 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001115 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001116 class C:
1117 x: list = field(default_factory=list, hash=False)
1118 self.assertEqual(astuple(C()), ([],))
1119 self.assertEqual(hash(C()), hash(()))
1120
1121 # init (see also test_default_factory_with_no_init)
1122 @dataclass
1123 class C:
1124 x: list = field(default_factory=list, init=False)
1125 self.assertEqual(astuple(C()), ([],))
1126
1127 # compare
1128 @dataclass
1129 class C:
1130 x: list = field(default_factory=list, compare=False)
1131 self.assertEqual(C(), C([1]))
1132
1133 def test_default_factory_with_no_init(self):
1134 # We need a factory with a side effect.
1135 factory = Mock()
1136
1137 @dataclass
1138 class C:
1139 x: list = field(default_factory=factory, init=False)
1140
1141 # Make sure the default factory is called for each new instance.
1142 C().x
1143 self.assertEqual(factory.call_count, 1)
1144 C().x
1145 self.assertEqual(factory.call_count, 2)
1146
1147 def test_default_factory_not_called_if_value_given(self):
1148 # We need a factory that we can test if it's been called.
1149 factory = Mock()
1150
1151 @dataclass
1152 class C:
1153 x: int = field(default_factory=factory)
1154
1155 # Make sure that if a field has a default factory function,
1156 # it's not called if a value is specified.
1157 C().x
1158 self.assertEqual(factory.call_count, 1)
1159 self.assertEqual(C(10).x, 10)
1160 self.assertEqual(factory.call_count, 1)
1161 C().x
1162 self.assertEqual(factory.call_count, 2)
1163
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001164 def test_default_factory_derived(self):
1165 # See bpo-32896.
1166 @dataclass
1167 class Foo:
1168 x: dict = field(default_factory=dict)
1169
1170 @dataclass
1171 class Bar(Foo):
1172 y: int = 1
1173
1174 self.assertEqual(Foo().x, {})
1175 self.assertEqual(Bar().x, {})
1176 self.assertEqual(Bar().y, 1)
1177
1178 @dataclass
1179 class Baz(Foo):
1180 pass
1181 self.assertEqual(Baz().x, {})
1182
1183 def test_intermediate_non_dataclass(self):
1184 # Test that an intermediate class that defines
1185 # annotations does not define fields.
1186
1187 @dataclass
1188 class A:
1189 x: int
1190
1191 class B(A):
1192 y: int
1193
1194 @dataclass
1195 class C(B):
1196 z: int
1197
1198 c = C(1, 3)
1199 self.assertEqual((c.x, c.z), (1, 3))
1200
1201 # .y was not initialized.
1202 with self.assertRaisesRegex(AttributeError,
1203 'object has no attribute'):
1204 c.y
1205
1206 # And if we again derive a non-dataclass, no fields are added.
1207 class D(C):
1208 t: int
1209 d = D(4, 5)
1210 self.assertEqual((d.x, d.z), (4, 5))
1211
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001212 def test_classvar_default_factory(self):
1213 # It's an error for a ClassVar to have a factory function.
1214 with self.assertRaisesRegex(TypeError,
1215 'cannot have a default factory'):
1216 @dataclass
1217 class C:
1218 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001219
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001220 def test_is_dataclass(self):
1221 class NotDataClass:
1222 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001223
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001224 self.assertFalse(is_dataclass(0))
1225 self.assertFalse(is_dataclass(int))
1226 self.assertFalse(is_dataclass(NotDataClass))
1227 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001228
1229 @dataclass
1230 class C:
1231 x: int
1232
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001233 @dataclass
1234 class D:
1235 d: C
1236 e: int
1237
1238 c = C(10)
1239 d = D(c, 4)
1240
1241 self.assertTrue(is_dataclass(C))
1242 self.assertTrue(is_dataclass(c))
1243 self.assertFalse(is_dataclass(c.x))
1244 self.assertTrue(is_dataclass(d.d))
1245 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001246
1247 def test_helper_fields_with_class_instance(self):
1248 # Check that we can call fields() on either a class or instance,
1249 # and get back the same thing.
1250 @dataclass
1251 class C:
1252 x: int
1253 y: float
1254
1255 self.assertEqual(fields(C), fields(C(0, 0.0)))
1256
1257 def test_helper_fields_exception(self):
1258 # Check that TypeError is raised if not passed a dataclass or
1259 # instance.
1260 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1261 fields(0)
1262
1263 class C: pass
1264 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1265 fields(C)
1266 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1267 fields(C())
1268
1269 def test_helper_asdict(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001270 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001271 @dataclass
1272 class C:
1273 x: int
1274 y: int
1275 c = C(1, 2)
1276
1277 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1278 self.assertEqual(asdict(c), asdict(c))
1279 self.assertIsNot(asdict(c), asdict(c))
1280 c.x = 42
1281 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1282 self.assertIs(type(asdict(c)), dict)
1283
1284 def test_helper_asdict_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001285 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001286 @dataclass
1287 class C:
1288 x: int
1289 y: int
1290 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1291 asdict(C)
1292 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1293 asdict(int)
1294
1295 def test_helper_asdict_copy_values(self):
1296 @dataclass
1297 class C:
1298 x: int
1299 y: List[int] = field(default_factory=list)
1300 initial = []
1301 c = C(1, initial)
1302 d = asdict(c)
1303 self.assertEqual(d['y'], initial)
1304 self.assertIsNot(d['y'], initial)
1305 c = C(1)
1306 d = asdict(c)
1307 d['y'].append(1)
1308 self.assertEqual(c.y, [])
1309
1310 def test_helper_asdict_nested(self):
1311 @dataclass
1312 class UserId:
1313 token: int
1314 group: int
1315 @dataclass
1316 class User:
1317 name: str
1318 id: UserId
1319 u = User('Joe', UserId(123, 1))
1320 d = asdict(u)
1321 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1322 self.assertIsNot(asdict(u), asdict(u))
1323 u.id.group = 2
1324 self.assertEqual(asdict(u), {'name': 'Joe',
1325 'id': {'token': 123, 'group': 2}})
1326
1327 def test_helper_asdict_builtin_containers(self):
1328 @dataclass
1329 class User:
1330 name: str
1331 id: int
1332 @dataclass
1333 class GroupList:
1334 id: int
1335 users: List[User]
1336 @dataclass
1337 class GroupTuple:
1338 id: int
1339 users: Tuple[User, ...]
1340 @dataclass
1341 class GroupDict:
1342 id: int
1343 users: Dict[str, User]
1344 a = User('Alice', 1)
1345 b = User('Bob', 2)
1346 gl = GroupList(0, [a, b])
1347 gt = GroupTuple(0, (a, b))
1348 gd = GroupDict(0, {'first': a, 'second': b})
1349 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1350 {'name': 'Bob', 'id': 2}]})
1351 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1352 {'name': 'Bob', 'id': 2})})
1353 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1354 'second': {'name': 'Bob', 'id': 2}}})
1355
1356 def test_helper_asdict_builtin_containers(self):
1357 @dataclass
1358 class Child:
1359 d: object
1360
1361 @dataclass
1362 class Parent:
1363 child: Child
1364
1365 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1366 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1367
1368 def test_helper_asdict_factory(self):
1369 @dataclass
1370 class C:
1371 x: int
1372 y: int
1373 c = C(1, 2)
1374 d = asdict(c, dict_factory=OrderedDict)
1375 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1376 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1377 c.x = 42
1378 d = asdict(c, dict_factory=OrderedDict)
1379 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1380 self.assertIs(type(d), OrderedDict)
1381
1382 def test_helper_astuple(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001383 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001384 @dataclass
1385 class C:
1386 x: int
1387 y: int = 0
1388 c = C(1)
1389
1390 self.assertEqual(astuple(c), (1, 0))
1391 self.assertEqual(astuple(c), astuple(c))
1392 self.assertIsNot(astuple(c), astuple(c))
1393 c.y = 42
1394 self.assertEqual(astuple(c), (1, 42))
1395 self.assertIs(type(astuple(c)), tuple)
1396
1397 def test_helper_astuple_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001398 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001399 @dataclass
1400 class C:
1401 x: int
1402 y: int
1403 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1404 astuple(C)
1405 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1406 astuple(int)
1407
1408 def test_helper_astuple_copy_values(self):
1409 @dataclass
1410 class C:
1411 x: int
1412 y: List[int] = field(default_factory=list)
1413 initial = []
1414 c = C(1, initial)
1415 t = astuple(c)
1416 self.assertEqual(t[1], initial)
1417 self.assertIsNot(t[1], initial)
1418 c = C(1)
1419 t = astuple(c)
1420 t[1].append(1)
1421 self.assertEqual(c.y, [])
1422
1423 def test_helper_astuple_nested(self):
1424 @dataclass
1425 class UserId:
1426 token: int
1427 group: int
1428 @dataclass
1429 class User:
1430 name: str
1431 id: UserId
1432 u = User('Joe', UserId(123, 1))
1433 t = astuple(u)
1434 self.assertEqual(t, ('Joe', (123, 1)))
1435 self.assertIsNot(astuple(u), astuple(u))
1436 u.id.group = 2
1437 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1438
1439 def test_helper_astuple_builtin_containers(self):
1440 @dataclass
1441 class User:
1442 name: str
1443 id: int
1444 @dataclass
1445 class GroupList:
1446 id: int
1447 users: List[User]
1448 @dataclass
1449 class GroupTuple:
1450 id: int
1451 users: Tuple[User, ...]
1452 @dataclass
1453 class GroupDict:
1454 id: int
1455 users: Dict[str, User]
1456 a = User('Alice', 1)
1457 b = User('Bob', 2)
1458 gl = GroupList(0, [a, b])
1459 gt = GroupTuple(0, (a, b))
1460 gd = GroupDict(0, {'first': a, 'second': b})
1461 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1462 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1463 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1464
1465 def test_helper_astuple_builtin_containers(self):
1466 @dataclass
1467 class Child:
1468 d: object
1469
1470 @dataclass
1471 class Parent:
1472 child: Child
1473
1474 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1475 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1476
1477 def test_helper_astuple_factory(self):
1478 @dataclass
1479 class C:
1480 x: int
1481 y: int
1482 NT = namedtuple('NT', 'x y')
1483 def nt(lst):
1484 return NT(*lst)
1485 c = C(1, 2)
1486 t = astuple(c, tuple_factory=nt)
1487 self.assertEqual(t, NT(1, 2))
1488 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1489 c.x = 42
1490 t = astuple(c, tuple_factory=nt)
1491 self.assertEqual(t, NT(42, 2))
1492 self.assertIs(type(t), NT)
1493
1494 def test_dynamic_class_creation(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001495 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001496 }
1497
1498 # Create the class.
1499 cls = type('C', (), cls_dict)
1500
1501 # Make it a dataclass.
1502 cls1 = dataclass(cls)
1503
1504 self.assertEqual(cls1, cls)
1505 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1506
1507 def test_dynamic_class_creation_using_field(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001508 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001509 'y': field(default=5),
1510 }
1511
1512 # Create the class.
1513 cls = type('C', (), cls_dict)
1514
1515 # Make it a dataclass.
1516 cls1 = dataclass(cls)
1517
1518 self.assertEqual(cls1, cls)
1519 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1520
1521 def test_init_in_order(self):
1522 @dataclass
1523 class C:
1524 a: int
1525 b: int = field()
1526 c: list = field(default_factory=list, init=False)
1527 d: list = field(default_factory=list)
1528 e: int = field(default=4, init=False)
1529 f: int = 4
1530
1531 calls = []
1532 def setattr(self, name, value):
1533 calls.append((name, value))
1534
1535 C.__setattr__ = setattr
1536 c = C(0, 1)
1537 self.assertEqual(('a', 0), calls[0])
1538 self.assertEqual(('b', 1), calls[1])
1539 self.assertEqual(('c', []), calls[2])
1540 self.assertEqual(('d', []), calls[3])
1541 self.assertNotIn(('e', 4), calls)
1542 self.assertEqual(('f', 4), calls[4])
1543
1544 def test_items_in_dicts(self):
1545 @dataclass
1546 class C:
1547 a: int
1548 b: list = field(default_factory=list, init=False)
1549 c: list = field(default_factory=list)
1550 d: int = field(default=4, init=False)
1551 e: int = 0
1552
1553 c = C(0)
1554 # Class dict
1555 self.assertNotIn('a', C.__dict__)
1556 self.assertNotIn('b', C.__dict__)
1557 self.assertNotIn('c', C.__dict__)
1558 self.assertIn('d', C.__dict__)
1559 self.assertEqual(C.d, 4)
1560 self.assertIn('e', C.__dict__)
1561 self.assertEqual(C.e, 0)
1562 # Instance dict
1563 self.assertIn('a', c.__dict__)
1564 self.assertEqual(c.a, 0)
1565 self.assertIn('b', c.__dict__)
1566 self.assertEqual(c.b, [])
1567 self.assertIn('c', c.__dict__)
1568 self.assertEqual(c.c, [])
1569 self.assertNotIn('d', c.__dict__)
1570 self.assertIn('e', c.__dict__)
1571 self.assertEqual(c.e, 0)
1572
1573 def test_alternate_classmethod_constructor(self):
1574 # Since __post_init__ can't take params, use a classmethod
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001575 # alternate constructor. This is mostly an example to show
1576 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001577 @dataclass
1578 class C:
1579 x: int
1580 @classmethod
1581 def from_file(cls, filename):
1582 # In a real example, create a new instance
1583 # and populate 'x' from contents of a file.
1584 value_in_file = 20
1585 return cls(value_in_file)
1586
1587 self.assertEqual(C.from_file('filename').x, 20)
1588
1589 def test_field_metadata_default(self):
1590 # Make sure the default metadata is read-only and of
1591 # zero length.
1592 @dataclass
1593 class C:
1594 i: int
1595
1596 self.assertFalse(fields(C)[0].metadata)
1597 self.assertEqual(len(fields(C)[0].metadata), 0)
1598 with self.assertRaisesRegex(TypeError,
1599 'does not support item assignment'):
1600 fields(C)[0].metadata['test'] = 3
1601
1602 def test_field_metadata_mapping(self):
1603 # Make sure only a mapping can be passed as metadata
1604 # zero length.
1605 with self.assertRaises(TypeError):
1606 @dataclass
1607 class C:
1608 i: int = field(metadata=0)
1609
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001610 # Make sure an empty dict works.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001611 @dataclass
1612 class C:
1613 i: int = field(metadata={})
1614 self.assertFalse(fields(C)[0].metadata)
1615 self.assertEqual(len(fields(C)[0].metadata), 0)
1616 with self.assertRaisesRegex(TypeError,
1617 'does not support item assignment'):
1618 fields(C)[0].metadata['test'] = 3
1619
1620 # Make sure a non-empty dict works.
1621 @dataclass
1622 class C:
1623 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1624 self.assertEqual(len(fields(C)[0].metadata), 3)
1625 self.assertEqual(fields(C)[0].metadata['test'], 10)
1626 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1627 self.assertEqual(fields(C)[0].metadata[3], 'three')
1628 with self.assertRaises(KeyError):
1629 # Non-existent key.
1630 fields(C)[0].metadata['baz']
1631 with self.assertRaisesRegex(TypeError,
1632 'does not support item assignment'):
1633 fields(C)[0].metadata['test'] = 3
1634
1635 def test_field_metadata_custom_mapping(self):
1636 # Try a custom mapping.
1637 class SimpleNameSpace:
1638 def __init__(self, **kw):
1639 self.__dict__.update(kw)
1640
1641 def __getitem__(self, item):
1642 if item == 'xyzzy':
1643 return 'plugh'
1644 return getattr(self, item)
1645
1646 def __len__(self):
1647 return self.__dict__.__len__()
1648
1649 @dataclass
1650 class C:
1651 i: int = field(metadata=SimpleNameSpace(a=10))
1652
1653 self.assertEqual(len(fields(C)[0].metadata), 1)
1654 self.assertEqual(fields(C)[0].metadata['a'], 10)
1655 with self.assertRaises(AttributeError):
1656 fields(C)[0].metadata['b']
1657 # Make sure we're still talking to our custom mapping.
1658 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1659
1660 def test_generic_dataclasses(self):
1661 T = TypeVar('T')
1662
1663 @dataclass
1664 class LabeledBox(Generic[T]):
1665 content: T
1666 label: str = '<unknown>'
1667
1668 box = LabeledBox(42)
1669 self.assertEqual(box.content, 42)
1670 self.assertEqual(box.label, '<unknown>')
1671
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001672 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001673 Alias = List[LabeledBox[int]]
1674
1675 def test_generic_extending(self):
1676 S = TypeVar('S')
1677 T = TypeVar('T')
1678
1679 @dataclass
1680 class Base(Generic[T, S]):
1681 x: T
1682 y: S
1683
1684 @dataclass
1685 class DataDerived(Base[int, T]):
1686 new_field: str
1687 Alias = DataDerived[str]
1688 c = Alias(0, 'test1', 'test2')
1689 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1690
1691 class NonDataDerived(Base[int, T]):
1692 def new_method(self):
1693 return self.y
1694 Alias = NonDataDerived[float]
1695 c = Alias(10, 1.0)
1696 self.assertEqual(c.new_method(), 1.0)
1697
Miss Islington (bot)d063ad82018-04-01 04:33:13 -07001698 def test_generic_dynamic(self):
1699 T = TypeVar('T')
1700
1701 @dataclass
1702 class Parent(Generic[T]):
1703 x: T
1704 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1705 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1706 self.assertIs(Child[int](1, 2).z, None)
1707 self.assertEqual(Child[int](1, 2, 3).z, 3)
1708 self.assertEqual(Child[int](1, 2, 3).other, 42)
1709 # Check that type aliases work correctly.
1710 Alias = Child[T]
1711 self.assertEqual(Alias[int](1, 2).x, 1)
1712 # Check MRO resolution.
1713 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1714
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001715 def test_helper_replace(self):
1716 @dataclass(frozen=True)
1717 class C:
1718 x: int
1719 y: int
1720
1721 c = C(1, 2)
1722 c1 = replace(c, x=3)
1723 self.assertEqual(c1.x, 3)
1724 self.assertEqual(c1.y, 2)
1725
1726 def test_helper_replace_frozen(self):
1727 @dataclass(frozen=True)
1728 class C:
1729 x: int
1730 y: int
1731 z: int = field(init=False, default=10)
1732 t: int = field(init=False, default=100)
1733
1734 c = C(1, 2)
1735 c1 = replace(c, x=3)
1736 self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
1737 self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))
1738
1739
1740 with self.assertRaisesRegex(ValueError, 'init=False'):
1741 replace(c, x=3, z=20, t=50)
1742 with self.assertRaisesRegex(ValueError, 'init=False'):
1743 replace(c, z=20)
1744 replace(c, x=3, z=20, t=50)
1745
1746 # Make sure the result is still frozen.
1747 with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
1748 c1.x = 3
1749
1750 # Make sure we can't replace an attribute that doesn't exist,
1751 # if we're also replacing one that does exist. Test this
1752 # here, because setting attributes on frozen instances is
1753 # handled slightly differently from non-frozen ones.
Eric V. Smith24e77f92017-12-06 14:00:34 -05001754 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001755 "keyword argument 'a'"):
1756 c1 = replace(c, x=20, a=5)
1757
1758 def test_helper_replace_invalid_field_name(self):
1759 @dataclass(frozen=True)
1760 class C:
1761 x: int
1762 y: int
1763
1764 c = C(1, 2)
Eric V. Smith24e77f92017-12-06 14:00:34 -05001765 with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001766 "keyword argument 'z'"):
1767 c1 = replace(c, z=3)
1768
1769 def test_helper_replace_invalid_object(self):
1770 @dataclass(frozen=True)
1771 class C:
1772 x: int
1773 y: int
1774
1775 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1776 replace(C, x=3)
1777
1778 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1779 replace(0, x=3)
1780
1781 def test_helper_replace_no_init(self):
1782 @dataclass
1783 class C:
1784 x: int
1785 y: int = field(init=False, default=10)
1786
1787 c = C(1)
1788 c.y = 20
1789
1790 # Make sure y gets the default value.
1791 c1 = replace(c, x=5)
1792 self.assertEqual((c1.x, c1.y), (5, 10))
1793
1794 # Trying to replace y is an error.
1795 with self.assertRaisesRegex(ValueError, 'init=False'):
1796 replace(c, x=2, y=30)
1797 with self.assertRaisesRegex(ValueError, 'init=False'):
1798 replace(c, y=30)
1799
1800 def test_dataclassses_pickleable(self):
1801 global P, Q, R
1802 @dataclass
1803 class P:
1804 x: int
1805 y: int = 0
1806 @dataclass
1807 class Q:
1808 x: int
1809 y: int = field(default=0, init=False)
1810 @dataclass
1811 class R:
1812 x: int
1813 y: List[int] = field(default_factory=list)
1814 q = Q(1)
1815 q.y = 2
1816 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1817 for sample in samples:
1818 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1819 with self.subTest(sample=sample, proto=proto):
1820 new_sample = pickle.loads(pickle.dumps(sample, proto))
1821 self.assertEqual(sample.x, new_sample.x)
1822 self.assertEqual(sample.y, new_sample.y)
1823 self.assertIsNot(sample, new_sample)
1824 new_sample.x = 42
1825 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1826 self.assertEqual(new_sample.x, another_new_sample.x)
1827 self.assertEqual(sample.y, another_new_sample.y)
1828
Eric V. Smithea8fc522018-01-27 19:07:40 -05001829
class TestFieldNoAnnotation(unittest.TestCase):
    def test_field_without_annotation(self):
        # field() assigned to a name with no annotation at all is rejected.
        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        # The base class annotation must not be borrowed for the
        # subclass's own, unannotated field() assignment.
        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # Same check as above, but with a plain (non-dataclass) base.
        class B:
            f: int

        expected = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected):
            @dataclass
            class C(B):
                f = field()
1863
1864
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001865class TestDocString(unittest.TestCase):
1866 def assertDocStrEqual(self, a, b):
1867 # Because 3.6 and 3.7 differ in how inspect.signature work
1868 # (see bpo #32108), for the time being just compare them with
1869 # whitespace stripped.
1870 self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))
1871
1872 def test_existing_docstring_not_overridden(self):
1873 @dataclass
1874 class C:
1875 """Lorem ipsum"""
1876 x: int
1877
1878 self.assertEqual(C.__doc__, "Lorem ipsum")
1879
1880 def test_docstring_no_fields(self):
1881 @dataclass
1882 class C:
1883 pass
1884
1885 self.assertDocStrEqual(C.__doc__, "C()")
1886
1887 def test_docstring_one_field(self):
1888 @dataclass
1889 class C:
1890 x: int
1891
1892 self.assertDocStrEqual(C.__doc__, "C(x:int)")
1893
1894 def test_docstring_two_fields(self):
1895 @dataclass
1896 class C:
1897 x: int
1898 y: int
1899
1900 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")
1901
1902 def test_docstring_three_fields(self):
1903 @dataclass
1904 class C:
1905 x: int
1906 y: int
1907 z: str
1908
1909 self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")
1910
1911 def test_docstring_one_field_with_default(self):
1912 @dataclass
1913 class C:
1914 x: int = 3
1915
1916 self.assertDocStrEqual(C.__doc__, "C(x:int=3)")
1917
1918 def test_docstring_one_field_with_default_none(self):
1919 @dataclass
1920 class C:
1921 x: Union[int, type(None)] = None
1922
1923 self.assertDocStrEqual(C.__doc__, "C(x:Union[int, NoneType]=None)")
1924
1925 def test_docstring_list_field(self):
1926 @dataclass
1927 class C:
1928 x: List[int]
1929
1930 self.assertDocStrEqual(C.__doc__, "C(x:List[int])")
1931
1932 def test_docstring_list_field_with_default_factory(self):
1933 @dataclass
1934 class C:
1935 x: List[int] = field(default_factory=list)
1936
1937 self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")
1938
1939 def test_docstring_deque_field(self):
1940 @dataclass
1941 class C:
1942 x: deque
1943
1944 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")
1945
1946 def test_docstring_deque_field_with_default_factory(self):
1947 @dataclass
1948 class C:
1949 x: deque = field(default_factory=deque)
1950
1951 self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1952
1953
class TestInit(unittest.TestCase):
    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Declaring a dataclass on top of a base that defines __init__ must
        # not raise.  We can't override __init__ in our class, but adding one
        # is fine when the base has its own.  The generated __init__ does
        # not call B.__init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        self.assertNotIn('z', vars(c))

        # With init=False nothing is generated, so B.__init__ is inherited
        # and runs.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # dataclass(init=False) must be applied with '@'.  Calling it as a
        # bare expression only builds a decorator and leaves C untouched,
        # which would make this test vacuous.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertTrue(is_dataclass(C))
        self.assertNotIn('__init__', C.__dict__)
        self.assertEqual(C().i, 0)

        # A user-defined __init__ is kept as-is.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertTrue(is_dataclass(C))
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class defines __init__, it is used whatever the value of
        # init= is.
        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
2018
2019
class TestRepr(unittest.TestCase):
    def test_repr(self):
        # The generated repr uses the qualified class name and lists every
        # field, inherited ones first.
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        o = C(4)
        self.assertEqual(repr(o), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redeclaring an inherited field keeps its original position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Nested dataclasses show their full qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        self.assertEqual(repr(C.D(0)), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(C.E()), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no user __repr__: object.__repr__ is used, which
        # starts with the defining module's name.  That name is taken from
        # __name__ so the check holds however this module is imported or
        # run (for example as __main__).
        @dataclass(repr=False)
        class C:
            x: int
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      repr(C(3)))

        # repr=False with a user __repr__: the user's method is kept.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # If the class defines __repr__, it is used whatever the value of
        # repr= is.
        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2089
2090
class TestEq(unittest.TestCase):
    def test_no_eq(self):
        # eq=False and no user __eq__: equality falls back to identity.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        same = C(3)
        self.assertEqual(same, same)

        # eq=False with a user __eq__: the user's method is kept.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # A user-defined __eq__ always wins, whatever eq= says.  Each
        # variant compares equal only to its own sentinel value.
        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        self.assertEqual(C(1), 3)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        self.assertEqual(C(1), 4)
        self.assertNotEqual(C(1), 1)

        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        self.assertEqual(C(1), 5)
        self.assertNotEqual(C(1), 1)
2136
2137
class TestOrdering(unittest.TestCase):
    def test_functools_total_ordering(self):
        # functools.total_ordering can derive the remaining comparisons
        # from a user __lt__ on a dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately "backward", to prove this method is the one
                # being used.
                return self.x >= other

        self.assertLess(C(0), -1)
        self.assertLessEqual(C(0), -1)
        self.assertGreater(C(0), 1)
        self.assertGreaterEqual(C(0), 1)

    def test_no_order(self):
        # order=False (the default) adds no ordering methods at all.
        @dataclass(order=False)
        class C:
            x: int
        for name in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

        # A user __lt__ is kept, and no siblings are generated for it.
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        for name in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(name, C.__dict__)

    def test_overwriting_order(self):
        # With order=True, @dataclass refuses to replace a user-defined
        # comparison method and suggests functools.total_ordering instead.
        def expected(name):
            return f'Cannot overwrite attribute {name}.*using functools.total_ordering'

        with self.assertRaisesRegex(TypeError, expected('__lt__')):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        with self.assertRaisesRegex(TypeError, expected('__le__')):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        with self.assertRaisesRegex(TypeError, expected('__gt__')):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        with self.assertRaisesRegex(TypeError, expected('__ge__')):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2213
class TestHash(unittest.TestCase):
    """When and how dataclass generates, keeps, or removes __hash__."""

    def test_unsafe_hash(self):
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map to something else that's True, but not a bool.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def check(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a class with the given decorator flags (optionally with
            # its own __hash__) and verify the expected outcome.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # dataclass added no __hash__ of its own.  Only
                    # checkable without a user __hash__, since a user
                    # __hash__ is legitimately present in the class dict.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # The table only expects this with a user __hash__.
                    # Use a TestCase assertion, not a bare assert, so the
                    # check survives python -O.
                    self.assertTrue(with_hash,
                                    'exception expected without a user __hash__')
                    with self.assertRaisesRegex(TypeError,
                                                'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    # self.fail (unlike assert False) is not stripped by -O,
                    # so a mistyped table entry cannot pass silently.
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            check(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            check(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            check(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen),
                  False, res_no_defined_hash)
            check(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen),
                  True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq,    base,   expected       in [
            (None,  None,  object, 'unhashable'),
            (None,  None,  Base,   'unhashable'),
            (None,  False, object, 'object'),
            (None,  False, Base,   'base'),
            (None,  True,  object, 'unhashable'),
            (None,  True,  Base,   'unhashable'),
            (False, None,  object, 'unhashable'),
            (False, None,  Base,   'unhashable'),
            (False, False, object, 'object'),
            (False, False, Base,   'base'),
            (False, True,  object, 'unhashable'),
            (False, True,  Base,   'unhashable'),
            (True,  None,  object, 'tuple'),
            (True,  None,  Base,   'tuple'),
            (True,  False, object, 'object'),
            (True,  False, Base,   'base'),
            (True,  True,  object, 'tuple'),
            (True,  True,  Base,   'tuple'),
            ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # object's hash is identity-based, so there is no
                    # fixed value to compare hash() against.  Instead,
                    # check that the function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    # Not stripped under -O, unlike assert False.
                    self.fail(f'unknown value for expected={expected!r}')
2432
Eric V. Smithea8fc522018-01-27 19:07:40 -05002433
class TestFrozen(unittest.TestCase):
    """frozen=True dataclasses, on their own and under inheritance."""

    @staticmethod
    def _derive(base, intermediate_class):
        # Inheritance tests run both ways: with an intermediate normal
        # (non-dataclass) subclass of *base*, and with *base* directly.
        if not intermediate_class:
            return base

        class I(base):
            pass

        return I

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            i: int

        obj = C(10)
        self.assertEqual(obj.i, 10)
        with self.assertRaises(FrozenInstanceError):
            obj.i = 5
        # The rejected assignment left the value untouched.
        self.assertEqual(obj.i, 10)

    def test_inherit(self):
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        obj = D(0, 10)
        # Both the inherited field and the new one are read-only.
        for attr, value in (('i', 5), ('j', 6)):
            with self.assertRaises(FrozenInstanceError):
                setattr(obj, attr, value)
        self.assertEqual((obj.i, obj.j), (0, 10))

    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                parent = self._derive(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(parent):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                parent = self._derive(C, intermediate_class)

                with self.assertRaisesRegex(TypeError,
                        'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(parent):
                        pass

    def test_inherit_from_normal_class(self):
        for intermediate_class in (True, False):
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                parent = self._derive(C, intermediate_class)

                @dataclass(frozen=True)
                class D(parent):
                    i: int

                obj = D(10)
                with self.assertRaises(FrozenInstanceError):
                    obj.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        inst = S(3)
        self.assertEqual((inst.x, inst.y), (3, 10))
        # A plain subclass may grow new attributes ...
        inst.cached = True

        # ... but the frozen fields stay frozen.
        for attr in ('x', 'y'):
            with self.assertRaises(FrozenInstanceError):
                setattr(inst, attr, 5)
        self.assertEqual((inst.x, inst.y), (3, 10))
        self.assertEqual(inst.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int

                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int

                def __delattr__(self):
                    pass

        # Without frozen, a user __setattr__ is simply used.
        @dataclass(frozen=False)
        class C:
            x: int

            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2

        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # An immutable field value hashes fine (no exception).
        hash(C(3))

        # A mutable field value makes hashing fail.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))
2582
Miss Islington (bot)a93e3dc2018-02-26 17:59:55 -08002583
class TestSlots(unittest.TestCase):
    """Dataclasses that declare __slots__ by hand."""

    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # Regression check: the slot's member descriptor used to be taken
        # for a default value, so x was wrongly optional.
        missing_x = r"__init__\(\) missing 1 required positional argument: 'x'"
        with self.assertRaisesRegex(TypeError, missing_x):
            C()

        inst = C(10)
        self.assertEqual(inst.x, 10)
        inst.x = 5
        self.assertEqual(inst.x, 5)

        # Only the declared slot is assignable.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            inst.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        derived = Derived(1, 2)
        self.assertEqual((derived.x, derived.y), (1, 2))

        # Derived declares no __slots__, so arbitrary attributes can be added.
        derived.z = 10
2625
class TestDescriptors(unittest.TestCase):
    """__set_name__ handling for field defaults (bpo-33141, bpo-33175)."""

    def test_set_name(self):
        # See bpo-33141.

        class D:
            # Descriptor recording the attribute name it was bound to,
            # suffixed with 'x' so the test can tell __set_name__ ran.
            def __set_name__(self, owner, name):
                self.name = name + 'x'

            def __get__(self, instance, owner):
                return self if instance is None else 1

        # Ordinary descriptor behaviour: no dataclass code is involved in
        # initializing the descriptor.
        @dataclass
        class C:
            c: int = D()
        self.assertEqual(C.c.name, 'cx')

        # Default plus init=False -- the case where this really matters;
        # without init=False the descriptor would be overwritten anyway.
        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.

        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int = field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        default = D()
        # __set_name__ lives on the instance, not on its type ...
        default.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=default, init=False)

        # ... so it must never be called.
        self.assertEqual(default.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        # __set_name__ lives on the type, so it must be called exactly once.
        D.__set_name__ = Mock()

        @dataclass
        class C:
            i: int = field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)
Miss Islington (bot)c6147ac2018-03-26 10:55:13 -07002696
Miss Islington (bot)3d41f482018-03-19 18:31:22 -07002697
class TestStringAnnotations(unittest.TestCase):
    """Recognition of ClassVar/InitVar when annotations are strings."""

    def test_classvar(self):
        # Some expressions recognized as ClassVar really aren't.  But
        #  if you're using string annotations, it's not an exact
        #  science.
        # These tests assume that both "import typing" and "from
        #  typing import ClassVar" have been run in this file.
        # Every entry below must end in a comma: adjacent string literals
        #  are concatenated implicitly, which would silently merge two
        #  cases into one.
        for typestr in ('ClassVar[int]',
                        'ClassVar [int]',
                        ' ClassVar [int]',
                        'ClassVar',
                        ' ClassVar ',
                        'typing.ClassVar[int]',
                        'typing.ClassVar[str]',
                        ' typing.ClassVar[str]',
                        'typing .ClassVar[str]',
                        'typing. ClassVar[str]',
                        'typing.ClassVar [str]',
                        'typing.ClassVar [ str]',

                        # Not syntactically valid, but these will
                        #  be treated as ClassVars.
                        'typing.ClassVar.[int]',
                        'typing.ClassVar+',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is a ClassVar, so C() takes no args.
                C()

                # And it won't appear in the class's dict because it doesn't
                # have a default.
                self.assertNotIn('x', C.__dict__)

    def test_isnt_classvar(self):
        for typestr in ('CV',
                        't.ClassVar',
                        't.ClassVar[int]',
                        'typing..ClassVar[int]',
                        'Classvar',
                        'Classvar[int]',
                        'typing.ClassVarx[int]',
                        'typong.ClassVar[int]',
                        'dataclasses.ClassVar[int]',
                        'typingxClassVar[str]',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is not a ClassVar, so C() takes one arg.
                self.assertEqual(C(10).x, 10)

    def test_initvar(self):
        # These tests assume that both "import dataclasses" and "from
        #  dataclasses import *" have been run in this file.
        # As above, every entry must end in a comma to avoid implicit
        #  string concatenation.
        for typestr in ('InitVar[int]',
                        'InitVar [int]',
                        ' InitVar [int]',
                        'InitVar',
                        ' InitVar ',
                        'dataclasses.InitVar[int]',
                        'dataclasses.InitVar[str]',
                        ' dataclasses.InitVar[str]',
                        'dataclasses .InitVar[str]',
                        'dataclasses. InitVar[str]',
                        'dataclasses.InitVar [str]',
                        'dataclasses.InitVar [ str]',

                        # Not syntactically valid, but these will
                        #  be treated as InitVars.
                        'dataclasses.InitVar.[int]',
                        'dataclasses.InitVar+',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is an InitVar, so doesn't create a member.
                with self.assertRaisesRegex(AttributeError,
                                            "object has no attribute 'x'"):
                    C(1).x

    def test_isnt_initvar(self):
        for typestr in ('IV',
                        'dc.InitVar',
                        'xdataclasses.xInitVar',
                        'typing.xInitVar[int]',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is not an InitVar, so there will be a member x.
                self.assertEqual(C(10).x, 10)

    def test_classvar_module_level_import(self):
        from . import dataclass_module_1
        from . import dataclass_module_1_str
        from . import dataclass_module_2
        from . import dataclass_module_2_str

        for m in (dataclass_module_1, dataclass_module_1_str,
                  dataclass_module_2, dataclass_module_2_str,
                  ):
            with self.subTest(m=m):
                # There's a difference in how the ClassVars are
                # interpreted when using string annotations or
                # not. See the imported modules for details.
                if m.USING_STRINGS:
                    c = m.CV(10)
                else:
                    c = m.CV()
                self.assertEqual(c.cv0, 20)

                # There's a difference in how the InitVars are
                # interpreted when using string annotations or
                # not. See the imported modules for details.
                c = m.IV(0, 1, 2, 3, 4)

                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
                    with self.subTest(field_name=field_name):
                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
                            # Since field_name is an InitVar, it's
                            # not an instance field.
                            getattr(c, field_name)

                if m.USING_STRINGS:
                    # iv4 is interpreted as a normal field.
                    self.assertIn('not_iv4', c.__dict__)
                    self.assertEqual(c.not_iv4, 4)
                else:
                    # iv4 is interpreted as an InitVar, so it
                    # won't exist on the instance.
                    self.assertNotIn('not_iv4', c.__dict__)
2840
2841
class TestMakeDataclass(unittest.TestCase):
    """Tests for the functional make_dataclass() API."""

    def test_simple(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)

    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        make_dataclass('C',
                       [('x', int),
                        ('y', int, field(default=5))],
                       namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        # The inherited field x is required too.
        with self.assertRaisesRegex(TypeError, 'required positional'):
            C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            make_dataclass('C',
                           [],
                           xxinit=False)

    def test_no_types(self):
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x: x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # Loop variables below are called `name`, not `field`, so they don't
    # shadow dataclasses.field imported at module level.

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        # Accept both the historical misspelling ("identifers") and the
        # corrected message ("identifiers"), so the test checks the error
        # and is not tied to a typo in its wording.
        not_identifier = r'must be valid identifi?ers'
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, not_identifier):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, not_identifier):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, not_identifier):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)
3005
3006
if __name__ == "__main__":
    # Run as a script: let unittest discover and run this module's tests.
    unittest.main(module="__main__")