blob: 4c93513956a2d5c1ad0732e2ec8c4422a7a03035 [file] [log] [blame]
Miss Islington (bot)4ddc99d2018-03-21 14:44:23 -07001# Deliberately use "from dataclasses import *". Every name in __all__
2# is tested, so they all must be present. This is a way to catch
3# missing ones.
4
5from dataclasses import *
Eric V. Smithf0db54a2017-12-04 16:58:55 -05006
7import pickle
8import inspect
Miss Islington (bot)32e58fc2018-08-12 20:32:44 -07009import builtins
Eric V. Smithf0db54a2017-12-04 16:58:55 -050010import unittest
11from unittest.mock import Mock
Miss Islington (bot)d063ad82018-04-01 04:33:13 -070012from typing import ClassVar, Any, List, Union, Tuple, Dict, Generic, TypeVar, Optional
Eric V. Smithf0db54a2017-12-04 16:58:55 -050013from collections import deque, OrderedDict, namedtuple
Eric V. Smithea8fc522018-01-27 19:07:40 -050014from functools import total_ordering
Eric V. Smithf0db54a2017-12-04 16:58:55 -050015
Miss Islington (bot)c73268a2018-05-15 21:22:13 -070016import typing # Needed for the string "typing.ClassVar[int]" to work as an annotation.
17import dataclasses # Needed for the string "dataclasses.InitVar[int]" to work as an annotation.
18
Eric V. Smithf0db54a2017-12-04 16:58:55 -050019# Just any custom exception we can catch.
class CustomError(Exception):
    # Arbitrary exception type that tests raise on purpose and then catch.
    pass
21
22class TestCase(unittest.TestCase):
23 def test_no_fields(self):
24 @dataclass
25 class C:
26 pass
27
28 o = C()
29 self.assertEqual(len(fields(C)), 0)
30
Miss Islington (bot)3b4c6b12018-03-22 13:58:59 -070031 def test_no_fields_but_member_variable(self):
32 @dataclass
33 class C:
34 i = 0
35
36 o = C()
37 self.assertEqual(len(fields(C)), 0)
38
Eric V. Smithf0db54a2017-12-04 16:58:55 -050039 def test_one_field_no_default(self):
40 @dataclass
41 class C:
42 x: int
43
44 o = C(42)
45 self.assertEqual(o.x, 42)
46
47 def test_named_init_params(self):
48 @dataclass
49 class C:
50 x: int
51
52 o = C(x=32)
53 self.assertEqual(o.x, 32)
54
55 def test_two_fields_one_default(self):
56 @dataclass
57 class C:
58 x: int
59 y: int = 0
60
61 o = C(3)
62 self.assertEqual((o.x, o.y), (3, 0))
63
64 # Non-defaults following defaults.
65 with self.assertRaisesRegex(TypeError,
66 "non-default argument 'y' follows "
67 "default argument"):
68 @dataclass
69 class C:
70 x: int = 0
71 y: int
72
73 # A derived class adds a non-default field after a default one.
74 with self.assertRaisesRegex(TypeError,
75 "non-default argument 'y' follows "
76 "default argument"):
77 @dataclass
78 class B:
79 x: int = 0
80
81 @dataclass
82 class C(B):
83 y: int
84
85 # Override a base class field and add a default to
86 # a field which didn't use to have a default.
87 with self.assertRaisesRegex(TypeError,
88 "non-default argument 'y' follows "
89 "default argument"):
90 @dataclass
91 class B:
92 x: int
93 y: int
94
95 @dataclass
96 class C(B):
97 x: int = 0
98
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -080099 def test_overwrite_hash(self):
100 # Test that declaring this class isn't an error. It should
101 # use the user-provided __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500102 @dataclass(frozen=True)
103 class C:
104 x: int
105 def __hash__(self):
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800106 return 301
107 self.assertEqual(hash(C(100)), 301)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500108
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800109 # Test that declaring this class isn't an error. It should
110 # use the generated __hash__.
Eric V. Smithea8fc522018-01-27 19:07:40 -0500111 @dataclass(frozen=True)
112 class C:
113 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800114 def __eq__(self, other):
115 return False
116 self.assertEqual(hash(C(100)), hash((100,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500117
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800118 # But this one should generate an exception, because with
119 # unsafe_hash=True, it's an error to have a __hash__ defined.
120 with self.assertRaisesRegex(TypeError,
121 'Cannot overwrite attribute __hash__'):
122 @dataclass(unsafe_hash=True)
123 class C:
124 def __hash__(self):
125 pass
126
127 # Creating this class should not generate an exception,
128 # because even though __hash__ exists before @dataclass is
129 # called, (due to __eq__ being defined), since it's None
130 # that's okay.
131 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500132 class C:
133 x: int
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800134 def __eq__(self):
135 pass
136 # The generated hash function works as we'd expect.
137 self.assertEqual(hash(C(10)), hash((10,)))
138
139 # Creating this class should generate an exception, because
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700140 # __hash__ exists and is not None, which it would be if it
141 # had been auto-generated due to __eq__ being defined.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800142 with self.assertRaisesRegex(TypeError,
143 'Cannot overwrite attribute __hash__'):
144 @dataclass(unsafe_hash=True)
145 class C:
146 x: int
147 def __eq__(self):
148 pass
149 def __hash__(self):
150 pass
151
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500152 def test_overwrite_fields_in_derived_class(self):
153 # Note that x from C1 replaces x in Base, but the order remains
154 # the same as defined in Base.
155 @dataclass
156 class Base:
157 x: Any = 15.0
158 y: int = 0
159
160 @dataclass
161 class C1(Base):
162 z: int = 10
163 x: int = 15
164
165 o = Base()
166 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.Base(x=15.0, y=0)')
167
168 o = C1()
169 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=15, y=0, z=10)')
170
171 o = C1(x=5)
172 self.assertEqual(repr(o), 'TestCase.test_overwrite_fields_in_derived_class.<locals>.C1(x=5, y=0, z=10)')
173
174 def test_field_named_self(self):
175 @dataclass
176 class C:
177 self: str
178 c=C('foo')
179 self.assertEqual(c.self, 'foo')
180
181 # Make sure the first parameter is not named 'self'.
182 sig = inspect.signature(C.__init__)
183 first = next(iter(sig.parameters))
184 self.assertNotEqual('self', first)
185
186 # But we do use 'self' if no field named self.
187 @dataclass
188 class C:
189 selfx: str
190
191 # Make sure the first parameter is named 'self'.
192 sig = inspect.signature(C.__init__)
193 first = next(iter(sig.parameters))
194 self.assertEqual('self', first)
195
Miss Islington (bot)32e58fc2018-08-12 20:32:44 -0700196 def test_field_named_object(self):
197 @dataclass
198 class C:
199 object: str
200 c = C('foo')
201 self.assertEqual(c.object, 'foo')
202
203 def test_field_named_object_frozen(self):
204 @dataclass(frozen=True)
205 class C:
206 object: str
207 c = C('foo')
208 self.assertEqual(c.object, 'foo')
209
210 def test_field_named_like_builtin(self):
211 # Attribute names can shadow built-in names
212 # since code generation is used.
213 # Ensure that this is not happening.
214 exclusions = {'None', 'True', 'False'}
215 builtins_names = sorted(
216 b for b in builtins.__dict__.keys()
217 if not b.startswith('__') and b not in exclusions
218 )
219 attributes = [(name, str) for name in builtins_names]
220 C = make_dataclass('C', attributes)
221
222 c = C(*[name for name in builtins_names])
223
224 for name in builtins_names:
225 self.assertEqual(getattr(c, name), name)
226
227 def test_field_named_like_builtin_frozen(self):
228 # Attribute names can shadow built-in names
229 # since code generation is used.
230 # Ensure that this is not happening
231 # for frozen data classes.
232 exclusions = {'None', 'True', 'False'}
233 builtins_names = sorted(
234 b for b in builtins.__dict__.keys()
235 if not b.startswith('__') and b not in exclusions
236 )
237 attributes = [(name, str) for name in builtins_names]
238 C = make_dataclass('C', attributes, frozen=True)
239
240 c = C(*[name for name in builtins_names])
241
242 for name in builtins_names:
243 self.assertEqual(getattr(c, name), name)
244
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500245 def test_0_field_compare(self):
246 # Ensure that order=False is the default.
247 @dataclass
248 class C0:
249 pass
250
251 @dataclass(order=False)
252 class C1:
253 pass
254
255 for cls in [C0, C1]:
256 with self.subTest(cls=cls):
257 self.assertEqual(cls(), cls())
258 for idx, fn in enumerate([lambda a, b: a < b,
259 lambda a, b: a <= b,
260 lambda a, b: a > b,
261 lambda a, b: a >= b]):
262 with self.subTest(idx=idx):
263 with self.assertRaisesRegex(TypeError,
264 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
265 fn(cls(), cls())
266
267 @dataclass(order=True)
268 class C:
269 pass
270 self.assertLessEqual(C(), C())
271 self.assertGreaterEqual(C(), C())
272
273 def test_1_field_compare(self):
274 # Ensure that order=False is the default.
275 @dataclass
276 class C0:
277 x: int
278
279 @dataclass(order=False)
280 class C1:
281 x: int
282
283 for cls in [C0, C1]:
284 with self.subTest(cls=cls):
285 self.assertEqual(cls(1), cls(1))
286 self.assertNotEqual(cls(0), cls(1))
287 for idx, fn in enumerate([lambda a, b: a < b,
288 lambda a, b: a <= b,
289 lambda a, b: a > b,
290 lambda a, b: a >= b]):
291 with self.subTest(idx=idx):
292 with self.assertRaisesRegex(TypeError,
293 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
294 fn(cls(0), cls(0))
295
296 @dataclass(order=True)
297 class C:
298 x: int
299 self.assertLess(C(0), C(1))
300 self.assertLessEqual(C(0), C(1))
301 self.assertLessEqual(C(1), C(1))
302 self.assertGreater(C(1), C(0))
303 self.assertGreaterEqual(C(1), C(0))
304 self.assertGreaterEqual(C(1), C(1))
305
306 def test_simple_compare(self):
307 # Ensure that order=False is the default.
308 @dataclass
309 class C0:
310 x: int
311 y: int
312
313 @dataclass(order=False)
314 class C1:
315 x: int
316 y: int
317
318 for cls in [C0, C1]:
319 with self.subTest(cls=cls):
320 self.assertEqual(cls(0, 0), cls(0, 0))
321 self.assertEqual(cls(1, 2), cls(1, 2))
322 self.assertNotEqual(cls(1, 0), cls(0, 0))
323 self.assertNotEqual(cls(1, 0), cls(1, 1))
324 for idx, fn in enumerate([lambda a, b: a < b,
325 lambda a, b: a <= b,
326 lambda a, b: a > b,
327 lambda a, b: a >= b]):
328 with self.subTest(idx=idx):
329 with self.assertRaisesRegex(TypeError,
330 f"not supported between instances of '{cls.__name__}' and '{cls.__name__}'"):
331 fn(cls(0, 0), cls(0, 0))
332
333 @dataclass(order=True)
334 class C:
335 x: int
336 y: int
337
338 for idx, fn in enumerate([lambda a, b: a == b,
339 lambda a, b: a <= b,
340 lambda a, b: a >= b]):
341 with self.subTest(idx=idx):
342 self.assertTrue(fn(C(0, 0), C(0, 0)))
343
344 for idx, fn in enumerate([lambda a, b: a < b,
345 lambda a, b: a <= b,
346 lambda a, b: a != b]):
347 with self.subTest(idx=idx):
348 self.assertTrue(fn(C(0, 0), C(0, 1)))
349 self.assertTrue(fn(C(0, 1), C(1, 0)))
350 self.assertTrue(fn(C(1, 0), C(1, 1)))
351
352 for idx, fn in enumerate([lambda a, b: a > b,
353 lambda a, b: a >= b,
354 lambda a, b: a != b]):
355 with self.subTest(idx=idx):
356 self.assertTrue(fn(C(0, 1), C(0, 0)))
357 self.assertTrue(fn(C(1, 0), C(0, 1)))
358 self.assertTrue(fn(C(1, 1), C(1, 0)))
359
360 def test_compare_subclasses(self):
361 # Comparisons fail for subclasses, even if no fields
362 # are added.
363 @dataclass
364 class B:
365 i: int
366
367 @dataclass
368 class C(B):
369 pass
370
371 for idx, (fn, expected) in enumerate([(lambda a, b: a == b, False),
372 (lambda a, b: a != b, True)]):
373 with self.subTest(idx=idx):
374 self.assertEqual(fn(B(0), C(0)), expected)
375
376 for idx, fn in enumerate([lambda a, b: a < b,
377 lambda a, b: a <= b,
378 lambda a, b: a > b,
379 lambda a, b: a >= b]):
380 with self.subTest(idx=idx):
381 with self.assertRaisesRegex(TypeError,
382 "not supported between instances of 'B' and 'C'"):
383 fn(B(0), C(0))
384
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500385 def test_eq_order(self):
Eric V. Smithea8fc522018-01-27 19:07:40 -0500386 # Test combining eq and order.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500387 for (eq, order, result ) in [
388 (False, False, 'neither'),
389 (False, True, 'exception'),
390 (True, False, 'eq_only'),
391 (True, True, 'both'),
392 ]:
393 with self.subTest(eq=eq, order=order):
394 if result == 'exception':
395 with self.assertRaisesRegex(ValueError, 'eq must be true if order is true'):
396 @dataclass(eq=eq, order=order)
397 class C:
398 pass
399 else:
400 @dataclass(eq=eq, order=order)
401 class C:
402 pass
403
404 if result == 'neither':
405 self.assertNotIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500406 self.assertNotIn('__lt__', C.__dict__)
407 self.assertNotIn('__le__', C.__dict__)
408 self.assertNotIn('__gt__', C.__dict__)
409 self.assertNotIn('__ge__', C.__dict__)
410 elif result == 'both':
411 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500412 self.assertIn('__lt__', C.__dict__)
413 self.assertIn('__le__', C.__dict__)
414 self.assertIn('__gt__', C.__dict__)
415 self.assertIn('__ge__', C.__dict__)
416 elif result == 'eq_only':
417 self.assertIn('__eq__', C.__dict__)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500418 self.assertNotIn('__lt__', C.__dict__)
419 self.assertNotIn('__le__', C.__dict__)
420 self.assertNotIn('__gt__', C.__dict__)
421 self.assertNotIn('__ge__', C.__dict__)
422 else:
423 assert False, f'unknown result {result!r}'
424
425 def test_field_no_default(self):
426 @dataclass
427 class C:
428 x: int = field()
429
430 self.assertEqual(C(5).x, 5)
431
432 with self.assertRaisesRegex(TypeError,
433 r"__init__\(\) missing 1 required "
434 "positional argument: 'x'"):
435 C()
436
437 def test_field_default(self):
438 default = object()
439 @dataclass
440 class C:
441 x: object = field(default=default)
442
443 self.assertIs(C.x, default)
444 c = C(10)
445 self.assertEqual(c.x, 10)
446
447 # If we delete the instance attribute, we should then see the
448 # class attribute.
449 del c.x
450 self.assertIs(c.x, default)
451
452 self.assertIs(C().x, default)
453
454 def test_not_in_repr(self):
455 @dataclass
456 class C:
457 x: int = field(repr=False)
458 with self.assertRaises(TypeError):
459 C()
460 c = C(10)
461 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C()')
462
463 @dataclass
464 class C:
465 x: int = field(repr=False)
466 y: int
467 c = C(10, 20)
468 self.assertEqual(repr(c), 'TestCase.test_not_in_repr.<locals>.C(y=20)')
469
470 def test_not_in_compare(self):
471 @dataclass
472 class C:
473 x: int = 0
474 y: int = field(compare=False, default=4)
475
476 self.assertEqual(C(), C(0, 20))
477 self.assertEqual(C(1, 10), C(1, 20))
478 self.assertNotEqual(C(3), C(4, 10))
479 self.assertNotEqual(C(3, 10), C(4, 10))
480
481 def test_hash_field_rules(self):
482 # Test all 6 cases of:
483 # hash=True/False/None
484 # compare=True/False
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800485 for (hash_, compare, result ) in [
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500486 (True, False, 'field' ),
487 (True, True, 'field' ),
488 (False, False, 'absent'),
489 (False, True, 'absent'),
490 (None, False, 'absent'),
491 (None, True, 'field' ),
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800492 ]:
493 with self.subTest(hash=hash_, compare=compare):
494 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500495 class C:
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800496 x: int = field(compare=compare, hash=hash_, default=5)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500497
498 if result == 'field':
499 # __hash__ contains the field.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800500 self.assertEqual(hash(C(5)), hash((5,)))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500501 elif result == 'absent':
502 # The field is not present in the hash.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800503 self.assertEqual(hash(C(5)), hash(()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500504 else:
505 assert False, f'unknown result {result!r}'
506
507 def test_init_false_no_default(self):
508 # If init=False and no default value, then the field won't be
509 # present in the instance.
510 @dataclass
511 class C:
512 x: int = field(init=False)
513
514 self.assertNotIn('x', C().__dict__)
515
516 @dataclass
517 class C:
518 x: int
519 y: int = 0
520 z: int = field(init=False)
521 t: int = 10
522
523 self.assertNotIn('z', C(0).__dict__)
524 self.assertEqual(vars(C(5)), {'t': 10, 'x': 5, 'y': 0})
525
526 def test_class_marker(self):
527 @dataclass
528 class C:
529 x: int
530 y: str = field(init=False, default=None)
531 z: str = field(repr=False)
532
533 the_fields = fields(C)
534 # the_fields is a tuple of 3 items, each value
535 # is in __annotations__.
536 self.assertIsInstance(the_fields, tuple)
537 for f in the_fields:
538 self.assertIs(type(f), Field)
539 self.assertIn(f.name, C.__annotations__)
540
541 self.assertEqual(len(the_fields), 3)
542
543 self.assertEqual(the_fields[0].name, 'x')
544 self.assertEqual(the_fields[0].type, int)
545 self.assertFalse(hasattr(C, 'x'))
546 self.assertTrue (the_fields[0].init)
547 self.assertTrue (the_fields[0].repr)
548 self.assertEqual(the_fields[1].name, 'y')
549 self.assertEqual(the_fields[1].type, str)
550 self.assertIsNone(getattr(C, 'y'))
551 self.assertFalse(the_fields[1].init)
552 self.assertTrue (the_fields[1].repr)
553 self.assertEqual(the_fields[2].name, 'z')
554 self.assertEqual(the_fields[2].type, str)
555 self.assertFalse(hasattr(C, 'z'))
556 self.assertTrue (the_fields[2].init)
557 self.assertFalse(the_fields[2].repr)
558
559 def test_field_order(self):
560 @dataclass
561 class B:
562 a: str = 'B:a'
563 b: str = 'B:b'
564 c: str = 'B:c'
565
566 @dataclass
567 class C(B):
568 b: str = 'C:b'
569
570 self.assertEqual([(f.name, f.default) for f in fields(C)],
571 [('a', 'B:a'),
572 ('b', 'C:b'),
573 ('c', 'B:c')])
574
575 @dataclass
576 class D(B):
577 c: str = 'D:c'
578
579 self.assertEqual([(f.name, f.default) for f in fields(D)],
580 [('a', 'B:a'),
581 ('b', 'B:b'),
582 ('c', 'D:c')])
583
584 @dataclass
585 class E(D):
586 a: str = 'E:a'
587 d: str = 'E:d'
588
589 self.assertEqual([(f.name, f.default) for f in fields(E)],
590 [('a', 'E:a'),
591 ('b', 'B:b'),
592 ('c', 'D:c'),
593 ('d', 'E:d')])
594
595 def test_class_attrs(self):
596 # We only have a class attribute if a default value is
597 # specified, either directly or via a field with a default.
598 default = object()
599 @dataclass
600 class C:
601 x: int
602 y: int = field(repr=False)
603 z: object = default
604 t: int = field(default=100)
605
606 self.assertFalse(hasattr(C, 'x'))
607 self.assertFalse(hasattr(C, 'y'))
608 self.assertIs (C.z, default)
609 self.assertEqual(C.t, 100)
610
611 def test_disallowed_mutable_defaults(self):
612 # For the known types, don't allow mutable default values.
613 for typ, empty, non_empty in [(list, [], [1]),
614 (dict, {}, {0:1}),
615 (set, set(), set([1])),
616 ]:
617 with self.subTest(typ=typ):
618 # Can't use a zero-length value.
619 with self.assertRaisesRegex(ValueError,
620 f'mutable default {typ} for field '
621 'x is not allowed'):
622 @dataclass
623 class Point:
624 x: typ = empty
625
626
627 # Nor a non-zero-length value
628 with self.assertRaisesRegex(ValueError,
629 f'mutable default {typ} for field '
630 'y is not allowed'):
631 @dataclass
632 class Point:
633 y: typ = non_empty
634
635 # Check subtypes also fail.
636 class Subclass(typ): pass
637
638 with self.assertRaisesRegex(ValueError,
639 f"mutable default .*Subclass'>"
640 ' for field z is not allowed'
641 ):
642 @dataclass
643 class Point:
644 z: typ = Subclass()
645
646 # Because this is a ClassVar, it can be mutable.
647 @dataclass
648 class C:
649 z: ClassVar[typ] = typ()
650
651 # Because this is a ClassVar, it can be mutable.
652 @dataclass
653 class C:
654 x: ClassVar[typ] = Subclass()
655
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500656 def test_deliberately_mutable_defaults(self):
657 # If a mutable default isn't in the known list of
658 # (list, dict, set), then it's okay.
659 class Mutable:
660 def __init__(self):
661 self.l = []
662
663 @dataclass
664 class C:
665 x: Mutable
666
667 # These 2 instances will share this value of x.
668 lst = Mutable()
669 o1 = C(lst)
670 o2 = C(lst)
671 self.assertEqual(o1, o2)
672 o1.x.l.extend([1, 2])
673 self.assertEqual(o1, o2)
674 self.assertEqual(o1.x.l, [1, 2])
675 self.assertIs(o1.x, o2.x)
676
677 def test_no_options(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700678 # Call with dataclass().
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500679 @dataclass()
680 class C:
681 x: int
682
683 self.assertEqual(C(42).x, 42)
684
685 def test_not_tuple(self):
686 # Make sure we can't be compared to a tuple.
687 @dataclass
688 class Point:
689 x: int
690 y: int
691 self.assertNotEqual(Point(1, 2), (1, 2))
692
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700693 # And that we can't compare to another unrelated dataclass.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500694 @dataclass
695 class C:
696 x: int
697 y: int
698 self.assertNotEqual(Point(1, 3), C(1, 3))
699
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500700 def test_not_tuple(self):
701 # Test that some of the problems with namedtuple don't happen
702 # here.
703 @dataclass
704 class Point3D:
705 x: int
706 y: int
707 z: int
708
709 @dataclass
710 class Date:
711 year: int
712 month: int
713 day: int
714
715 self.assertNotEqual(Point3D(2017, 6, 3), Date(2017, 6, 3))
716 self.assertNotEqual(Point3D(1, 2, 3), (1, 2, 3))
717
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700718 # Make sure we can't unpack.
Serhiy Storchaka13a6c092017-12-26 12:30:41 +0200719 with self.assertRaisesRegex(TypeError, 'unpack'):
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500720 x, y, z = Point3D(4, 5, 6)
721
Eric V. Smith7c99e932018-01-28 19:18:55 -0500722 # Make sure another class with the same field names isn't
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500723 # equal.
724 @dataclass
725 class Point3Dv1:
726 x: int = 0
727 y: int = 0
728 z: int = 0
729 self.assertNotEqual(Point3D(0, 0, 0), Point3Dv1())
730
731 def test_function_annotations(self):
732 # Some dummy class and instance to use as a default.
733 class F:
734 pass
735 f = F()
736
737 def validate_class(cls):
738 # First, check __annotations__, even though they're not
739 # function annotations.
740 self.assertEqual(cls.__annotations__['i'], int)
741 self.assertEqual(cls.__annotations__['j'], str)
742 self.assertEqual(cls.__annotations__['k'], F)
743 self.assertEqual(cls.__annotations__['l'], float)
744 self.assertEqual(cls.__annotations__['z'], complex)
745
746 # Verify __init__.
747
748 signature = inspect.signature(cls.__init__)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700749 # Check the return type, should be None.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500750 self.assertIs(signature.return_annotation, None)
751
752 # Check each parameter.
753 params = iter(signature.parameters.values())
754 param = next(params)
755 # This is testing an internal name, and probably shouldn't be tested.
756 self.assertEqual(param.name, 'self')
757 param = next(params)
758 self.assertEqual(param.name, 'i')
759 self.assertIs (param.annotation, int)
760 self.assertEqual(param.default, inspect.Parameter.empty)
761 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
762 param = next(params)
763 self.assertEqual(param.name, 'j')
764 self.assertIs (param.annotation, str)
765 self.assertEqual(param.default, inspect.Parameter.empty)
766 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
767 param = next(params)
768 self.assertEqual(param.name, 'k')
769 self.assertIs (param.annotation, F)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700770 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500771 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
772 param = next(params)
773 self.assertEqual(param.name, 'l')
774 self.assertIs (param.annotation, float)
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700775 # Don't test for the default, since it's set to MISSING.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500776 self.assertEqual(param.kind, inspect.Parameter.POSITIONAL_OR_KEYWORD)
777 self.assertRaises(StopIteration, next, params)
778
779
780 @dataclass
781 class C:
782 i: int
783 j: str
784 k: F = f
785 l: float=field(default=None)
786 z: complex=field(default=3+4j, init=False)
787
788 validate_class(C)
789
790 # Now repeat with __hash__.
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -0800791 @dataclass(frozen=True, unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500792 class C:
793 i: int
794 j: str
795 k: F = f
796 l: float=field(default=None)
797 z: complex=field(default=3+4j, init=False)
798
799 validate_class(C)
800
Eric V. Smith03220fd2017-12-29 13:59:58 -0500801 def test_missing_default(self):
802 # Test that MISSING works the same as a default not being
803 # specified.
804 @dataclass
805 class C:
806 x: int=field(default=MISSING)
807 with self.assertRaisesRegex(TypeError,
808 r'__init__\(\) missing 1 required '
809 'positional argument'):
810 C()
811 self.assertNotIn('x', C.__dict__)
812
813 @dataclass
814 class D:
815 x: int
816 with self.assertRaisesRegex(TypeError,
817 r'__init__\(\) missing 1 required '
818 'positional argument'):
819 D()
820 self.assertNotIn('x', D.__dict__)
821
822 def test_missing_default_factory(self):
823 # Test that MISSING works the same as a default factory not
824 # being specified (which is really the same as a default not
825 # being specified, too).
826 @dataclass
827 class C:
828 x: int=field(default_factory=MISSING)
829 with self.assertRaisesRegex(TypeError,
830 r'__init__\(\) missing 1 required '
831 'positional argument'):
832 C()
833 self.assertNotIn('x', C.__dict__)
834
835 @dataclass
836 class D:
837 x: int=field(default=MISSING, default_factory=MISSING)
838 with self.assertRaisesRegex(TypeError,
839 r'__init__\(\) missing 1 required '
840 'positional argument'):
841 D()
842 self.assertNotIn('x', D.__dict__)
843
844 def test_missing_repr(self):
845 self.assertIn('MISSING_TYPE object', repr(MISSING))
846
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500847 def test_dont_include_other_annotations(self):
848 @dataclass
849 class C:
850 i: int
851 def foo(self) -> int:
852 return 4
853 @property
854 def bar(self) -> int:
855 return 5
856 self.assertEqual(list(C.__annotations__), ['i'])
857 self.assertEqual(C(10).foo(), 4)
858 self.assertEqual(C(10).bar, 5)
Miss Islington (bot)5666a552018-03-25 06:27:50 -0700859 self.assertEqual(C(10).i, 10)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500860
def test_post_init(self):
    # __post_init__ is invoked by the generated __init__.
    @dataclass
    class AlwaysRaises:
        def __post_init__(self):
            raise CustomError()
    with self.assertRaises(CustomError):
        AlwaysRaises()

    # It receives the instance with its fields already set.
    @dataclass
    class RaisesOnDefault:
        i: int = 10
        def __post_init__(self):
            if self.i == 10:
                raise CustomError()
    with self.assertRaises(CustomError):
        RaisesOnDefault()
    # Here __post_init__ still runs, but the condition is false, so
    # nothing is raised.  This confirms self is the new instance.
    RaisesOnDefault(5)

    # With init=False no __init__ is generated, so __post_init__ is
    # never called and instantiation succeeds.
    @dataclass(init=False)
    class NoGeneratedInit:
        def __post_init__(self):
            raise CustomError()
    NoGeneratedInit()

    # __post_init__ may modify fields of a regular dataclass.
    @dataclass
    class Doubler:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    self.assertEqual(Doubler().x, 0)
    self.assertEqual(Doubler(2).x, 4)

    # On a frozen dataclass, assigning from __post_init__ is refused
    # like any other assignment.
    @dataclass(frozen=True)
    class FrozenDoubler:
        x: int = 0
        def __post_init__(self):
            self.x *= 2
    with self.assertRaises(FrozenInstanceError):
        FrozenDoubler()
907
def test_post_init_super(self):
    # A plain (non-dataclass) base whose __post_init__ always raises.
    class B:
        def __post_init__(self):
            raise CustomError()

    # An overriding __post_init__ does not call the base version
    # implicitly.
    @dataclass
    class C(B):
        def __post_init__(self):
            self.x = 5

    instance = C()
    self.assertEqual(instance.x, 5)

    # Calling super() explicitly reaches the base version, which raises.
    @dataclass
    class C(B):
        def __post_init__(self):
            super().__post_init__()

    with self.assertRaises(CustomError):
        C()

    # An inherited __post_init__ is called even when the dataclass
    # itself does not define one.
    @dataclass
    class C(B):
        pass

    with self.assertRaises(CustomError):
        C()
938
939 def test_post_init_staticmethod(self):
940 flag = False
941 @dataclass
942 class C:
943 x: int
944 y: int
945 @staticmethod
946 def __post_init__():
947 nonlocal flag
948 flag = True
949
950 self.assertFalse(flag)
951 c = C(3, 4)
952 self.assertEqual((c.x, c.y), (3, 4))
953 self.assertTrue(flag)
954
955 def test_post_init_classmethod(self):
956 @dataclass
957 class C:
958 flag = False
959 x: int
960 y: int
961 @classmethod
962 def __post_init__(cls):
963 cls.flag = True
964
965 self.assertFalse(C.flag)
966 c = C(3, 4)
967 self.assertEqual((c.x, c.y), (3, 4))
968 self.assertTrue(C.flag)
969
970 def test_class_var(self):
971 # Make sure ClassVars are ignored in __init__, __repr__, etc.
972 @dataclass
973 class C:
974 x: int
975 y: int = 10
976 z: ClassVar[int] = 1000
977 w: ClassVar[int] = 2000
978 t: ClassVar[int] = 3000
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700979 s: ClassVar = 4000
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500980
981 c = C(5)
982 self.assertEqual(repr(c), 'TestCase.test_class_var.<locals>.C(x=5, y=10)')
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -0700983 self.assertEqual(len(fields(C)), 2) # We have 2 fields.
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700984 self.assertEqual(len(C.__annotations__), 6) # And 4 ClassVars.
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500985 self.assertEqual(c.z, 1000)
986 self.assertEqual(c.w, 2000)
987 self.assertEqual(c.t, 3000)
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700988 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500989 C.z += 1
990 self.assertEqual(c.z, 1001)
991 c = C(20)
992 self.assertEqual((c.x, c.y), (20, 10))
993 self.assertEqual(c.z, 1001)
994 self.assertEqual(c.w, 2000)
995 self.assertEqual(c.t, 3000)
Miss Islington (bot)c73268a2018-05-15 21:22:13 -0700996 self.assertEqual(c.s, 4000)
Eric V. Smithf0db54a2017-12-04 16:58:55 -0500997
998 def test_class_var_no_default(self):
999 # If a ClassVar has no default value, it should not be set on the class.
1000 @dataclass
1001 class C:
1002 x: ClassVar[int]
1003
1004 self.assertNotIn('x', C.__dict__)
1005
1006 def test_class_var_default_factory(self):
1007 # It makes no sense for a ClassVar to have a default factory. When
1008 # would it be called? Call it yourself, since it's class-wide.
1009 with self.assertRaisesRegex(TypeError,
1010 'cannot have a default factory'):
1011 @dataclass
1012 class C:
1013 x: ClassVar[int] = field(default_factory=int)
1014
1015 self.assertNotIn('x', C.__dict__)
1016
1017 def test_class_var_with_default(self):
1018 # If a ClassVar has a default value, it should be set on the class.
1019 @dataclass
1020 class C:
1021 x: ClassVar[int] = 10
1022 self.assertEqual(C.x, 10)
1023
1024 @dataclass
1025 class C:
1026 x: ClassVar[int] = field(default=10)
1027 self.assertEqual(C.x, 10)
1028
1029 def test_class_var_frozen(self):
1030 # Make sure ClassVars work even if we're frozen.
1031 @dataclass(frozen=True)
1032 class C:
1033 x: int
1034 y: int = 10
1035 z: ClassVar[int] = 1000
1036 w: ClassVar[int] = 2000
1037 t: ClassVar[int] = 3000
1038
1039 c = C(5)
1040 self.assertEqual(repr(C(5)), 'TestCase.test_class_var_frozen.<locals>.C(x=5, y=10)')
1041 self.assertEqual(len(fields(C)), 2) # We have 2 fields
1042 self.assertEqual(len(C.__annotations__), 5) # And 3 ClassVars
1043 self.assertEqual(c.z, 1000)
1044 self.assertEqual(c.w, 2000)
1045 self.assertEqual(c.t, 3000)
1046 # We can still modify the ClassVar, it's only instances that are
1047 # frozen.
1048 C.z += 1
1049 self.assertEqual(c.z, 1001)
1050 c = C(20)
1051 self.assertEqual((c.x, c.y), (20, 10))
1052 self.assertEqual(c.z, 1001)
1053 self.assertEqual(c.w, 2000)
1054 self.assertEqual(c.t, 3000)
1055
1056 def test_init_var_no_default(self):
1057 # If an InitVar has no default value, it should not be set on the class.
1058 @dataclass
1059 class C:
1060 x: InitVar[int]
1061
1062 self.assertNotIn('x', C.__dict__)
1063
1064 def test_init_var_default_factory(self):
1065 # It makes no sense for an InitVar to have a default factory. When
1066 # would it be called? Call it yourself, since it's class-wide.
1067 with self.assertRaisesRegex(TypeError,
1068 'cannot have a default factory'):
1069 @dataclass
1070 class C:
1071 x: InitVar[int] = field(default_factory=int)
1072
1073 self.assertNotIn('x', C.__dict__)
1074
1075 def test_init_var_with_default(self):
1076 # If an InitVar has a default value, it should be set on the class.
1077 @dataclass
1078 class C:
1079 x: InitVar[int] = 10
1080 self.assertEqual(C.x, 10)
1081
1082 @dataclass
1083 class C:
1084 x: InitVar[int] = field(default=10)
1085 self.assertEqual(C.x, 10)
1086
1087 def test_init_var(self):
1088 @dataclass
1089 class C:
1090 x: int = None
1091 init_param: InitVar[int] = None
1092
1093 def __post_init__(self, init_param):
1094 if self.x is None:
1095 self.x = init_param*2
1096
1097 c = C(init_param=10)
1098 self.assertEqual(c.x, 20)
1099
1100 def test_init_var_inheritance(self):
1101 # Note that this deliberately tests that a dataclass need not
1102 # have a __post_init__ function if it has an InitVar field.
1103 # It could just be used in a derived class, as shown here.
1104 @dataclass
1105 class Base:
1106 x: int
1107 init_base: InitVar[int]
1108
1109 # We can instantiate by passing the InitVar, even though
1110 # it's not used.
1111 b = Base(0, 10)
1112 self.assertEqual(vars(b), {'x': 0})
1113
1114 @dataclass
1115 class C(Base):
1116 y: int
1117 init_derived: InitVar[int]
1118
1119 def __post_init__(self, init_base, init_derived):
1120 self.x = self.x + init_base
1121 self.y = self.y + init_derived
1122
1123 c = C(10, 11, 50, 51)
1124 self.assertEqual(vars(c), {'x': 21, 'y': 101})
1125
1126 def test_default_factory(self):
1127 # Test a factory that returns a new list.
1128 @dataclass
1129 class C:
1130 x: int
1131 y: list = field(default_factory=list)
1132
1133 c0 = C(3)
1134 c1 = C(3)
1135 self.assertEqual(c0.x, 3)
1136 self.assertEqual(c0.y, [])
1137 self.assertEqual(c0, c1)
1138 self.assertIsNot(c0.y, c1.y)
1139 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1140
1141 # Test a factory that returns a shared list.
1142 l = []
1143 @dataclass
1144 class C:
1145 x: int
1146 y: list = field(default_factory=lambda: l)
1147
1148 c0 = C(3)
1149 c1 = C(3)
1150 self.assertEqual(c0.x, 3)
1151 self.assertEqual(c0.y, [])
1152 self.assertEqual(c0, c1)
1153 self.assertIs(c0.y, c1.y)
1154 self.assertEqual(astuple(C(5, [1])), (5, [1]))
1155
1156 # Test various other field flags.
1157 # repr
1158 @dataclass
1159 class C:
1160 x: list = field(default_factory=list, repr=False)
1161 self.assertEqual(repr(C()), 'TestCase.test_default_factory.<locals>.C()')
1162 self.assertEqual(C().x, [])
1163
1164 # hash
Miss Islington (bot)4cffe2f2018-02-26 01:43:35 -08001165 @dataclass(unsafe_hash=True)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001166 class C:
1167 x: list = field(default_factory=list, hash=False)
1168 self.assertEqual(astuple(C()), ([],))
1169 self.assertEqual(hash(C()), hash(()))
1170
1171 # init (see also test_default_factory_with_no_init)
1172 @dataclass
1173 class C:
1174 x: list = field(default_factory=list, init=False)
1175 self.assertEqual(astuple(C()), ([],))
1176
1177 # compare
1178 @dataclass
1179 class C:
1180 x: list = field(default_factory=list, compare=False)
1181 self.assertEqual(C(), C([1]))
1182
1183 def test_default_factory_with_no_init(self):
1184 # We need a factory with a side effect.
1185 factory = Mock()
1186
1187 @dataclass
1188 class C:
1189 x: list = field(default_factory=factory, init=False)
1190
1191 # Make sure the default factory is called for each new instance.
1192 C().x
1193 self.assertEqual(factory.call_count, 1)
1194 C().x
1195 self.assertEqual(factory.call_count, 2)
1196
1197 def test_default_factory_not_called_if_value_given(self):
1198 # We need a factory that we can test if it's been called.
1199 factory = Mock()
1200
1201 @dataclass
1202 class C:
1203 x: int = field(default_factory=factory)
1204
1205 # Make sure that if a field has a default factory function,
1206 # it's not called if a value is specified.
1207 C().x
1208 self.assertEqual(factory.call_count, 1)
1209 self.assertEqual(C(10).x, 10)
1210 self.assertEqual(factory.call_count, 1)
1211 C().x
1212 self.assertEqual(factory.call_count, 2)
1213
Miss Islington (bot)22136c92018-03-21 02:17:30 -07001214 def test_default_factory_derived(self):
1215 # See bpo-32896.
1216 @dataclass
1217 class Foo:
1218 x: dict = field(default_factory=dict)
1219
1220 @dataclass
1221 class Bar(Foo):
1222 y: int = 1
1223
1224 self.assertEqual(Foo().x, {})
1225 self.assertEqual(Bar().x, {})
1226 self.assertEqual(Bar().y, 1)
1227
1228 @dataclass
1229 class Baz(Foo):
1230 pass
1231 self.assertEqual(Baz().x, {})
1232
1233 def test_intermediate_non_dataclass(self):
1234 # Test that an intermediate class that defines
1235 # annotations does not define fields.
1236
1237 @dataclass
1238 class A:
1239 x: int
1240
1241 class B(A):
1242 y: int
1243
1244 @dataclass
1245 class C(B):
1246 z: int
1247
1248 c = C(1, 3)
1249 self.assertEqual((c.x, c.z), (1, 3))
1250
1251 # .y was not initialized.
1252 with self.assertRaisesRegex(AttributeError,
1253 'object has no attribute'):
1254 c.y
1255
1256 # And if we again derive a non-dataclass, no fields are added.
1257 class D(C):
1258 t: int
1259 d = D(4, 5)
1260 self.assertEqual((d.x, d.z), (4, 5))
1261
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001262 def test_classvar_default_factory(self):
1263 # It's an error for a ClassVar to have a factory function.
1264 with self.assertRaisesRegex(TypeError,
1265 'cannot have a default factory'):
1266 @dataclass
1267 class C:
1268 x: ClassVar[int] = field(default_factory=int)
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001269
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001270 def test_is_dataclass(self):
1271 class NotDataClass:
1272 pass
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001273
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001274 self.assertFalse(is_dataclass(0))
1275 self.assertFalse(is_dataclass(int))
1276 self.assertFalse(is_dataclass(NotDataClass))
1277 self.assertFalse(is_dataclass(NotDataClass()))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001278
1279 @dataclass
1280 class C:
1281 x: int
1282
Eric V. Smithe7ba0132018-01-06 12:41:53 -05001283 @dataclass
1284 class D:
1285 d: C
1286 e: int
1287
1288 c = C(10)
1289 d = D(c, 4)
1290
1291 self.assertTrue(is_dataclass(C))
1292 self.assertTrue(is_dataclass(c))
1293 self.assertFalse(is_dataclass(c.x))
1294 self.assertTrue(is_dataclass(d.d))
1295 self.assertFalse(is_dataclass(d.e))
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001296
1297 def test_helper_fields_with_class_instance(self):
1298 # Check that we can call fields() on either a class or instance,
1299 # and get back the same thing.
1300 @dataclass
1301 class C:
1302 x: int
1303 y: float
1304
1305 self.assertEqual(fields(C), fields(C(0, 0.0)))
1306
1307 def test_helper_fields_exception(self):
1308 # Check that TypeError is raised if not passed a dataclass or
1309 # instance.
1310 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1311 fields(0)
1312
1313 class C: pass
1314 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1315 fields(C)
1316 with self.assertRaisesRegex(TypeError, 'dataclass type or instance'):
1317 fields(C())
1318
1319 def test_helper_asdict(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001320 # Basic tests for asdict(), it should return a new dictionary.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001321 @dataclass
1322 class C:
1323 x: int
1324 y: int
1325 c = C(1, 2)
1326
1327 self.assertEqual(asdict(c), {'x': 1, 'y': 2})
1328 self.assertEqual(asdict(c), asdict(c))
1329 self.assertIsNot(asdict(c), asdict(c))
1330 c.x = 42
1331 self.assertEqual(asdict(c), {'x': 42, 'y': 2})
1332 self.assertIs(type(asdict(c)), dict)
1333
1334 def test_helper_asdict_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001335 # asdict() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001336 @dataclass
1337 class C:
1338 x: int
1339 y: int
1340 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1341 asdict(C)
1342 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1343 asdict(int)
1344
1345 def test_helper_asdict_copy_values(self):
1346 @dataclass
1347 class C:
1348 x: int
1349 y: List[int] = field(default_factory=list)
1350 initial = []
1351 c = C(1, initial)
1352 d = asdict(c)
1353 self.assertEqual(d['y'], initial)
1354 self.assertIsNot(d['y'], initial)
1355 c = C(1)
1356 d = asdict(c)
1357 d['y'].append(1)
1358 self.assertEqual(c.y, [])
1359
1360 def test_helper_asdict_nested(self):
1361 @dataclass
1362 class UserId:
1363 token: int
1364 group: int
1365 @dataclass
1366 class User:
1367 name: str
1368 id: UserId
1369 u = User('Joe', UserId(123, 1))
1370 d = asdict(u)
1371 self.assertEqual(d, {'name': 'Joe', 'id': {'token': 123, 'group': 1}})
1372 self.assertIsNot(asdict(u), asdict(u))
1373 u.id.group = 2
1374 self.assertEqual(asdict(u), {'name': 'Joe',
1375 'id': {'token': 123, 'group': 2}})
1376
1377 def test_helper_asdict_builtin_containers(self):
1378 @dataclass
1379 class User:
1380 name: str
1381 id: int
1382 @dataclass
1383 class GroupList:
1384 id: int
1385 users: List[User]
1386 @dataclass
1387 class GroupTuple:
1388 id: int
1389 users: Tuple[User, ...]
1390 @dataclass
1391 class GroupDict:
1392 id: int
1393 users: Dict[str, User]
1394 a = User('Alice', 1)
1395 b = User('Bob', 2)
1396 gl = GroupList(0, [a, b])
1397 gt = GroupTuple(0, (a, b))
1398 gd = GroupDict(0, {'first': a, 'second': b})
1399 self.assertEqual(asdict(gl), {'id': 0, 'users': [{'name': 'Alice', 'id': 1},
1400 {'name': 'Bob', 'id': 2}]})
1401 self.assertEqual(asdict(gt), {'id': 0, 'users': ({'name': 'Alice', 'id': 1},
1402 {'name': 'Bob', 'id': 2})})
1403 self.assertEqual(asdict(gd), {'id': 0, 'users': {'first': {'name': 'Alice', 'id': 1},
1404 'second': {'name': 'Bob', 'id': 2}}})
1405
1406 def test_helper_asdict_builtin_containers(self):
1407 @dataclass
1408 class Child:
1409 d: object
1410
1411 @dataclass
1412 class Parent:
1413 child: Child
1414
1415 self.assertEqual(asdict(Parent(Child([1]))), {'child': {'d': [1]}})
1416 self.assertEqual(asdict(Parent(Child({1: 2}))), {'child': {'d': {1: 2}}})
1417
1418 def test_helper_asdict_factory(self):
1419 @dataclass
1420 class C:
1421 x: int
1422 y: int
1423 c = C(1, 2)
1424 d = asdict(c, dict_factory=OrderedDict)
1425 self.assertEqual(d, OrderedDict([('x', 1), ('y', 2)]))
1426 self.assertIsNot(d, asdict(c, dict_factory=OrderedDict))
1427 c.x = 42
1428 d = asdict(c, dict_factory=OrderedDict)
1429 self.assertEqual(d, OrderedDict([('x', 42), ('y', 2)]))
1430 self.assertIs(type(d), OrderedDict)
1431
1432 def test_helper_astuple(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001433 # Basic tests for astuple(), it should return a new tuple.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001434 @dataclass
1435 class C:
1436 x: int
1437 y: int = 0
1438 c = C(1)
1439
1440 self.assertEqual(astuple(c), (1, 0))
1441 self.assertEqual(astuple(c), astuple(c))
1442 self.assertIsNot(astuple(c), astuple(c))
1443 c.y = 42
1444 self.assertEqual(astuple(c), (1, 42))
1445 self.assertIs(type(astuple(c)), tuple)
1446
1447 def test_helper_astuple_raises_on_classes(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001448 # astuple() should raise on a class object.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001449 @dataclass
1450 class C:
1451 x: int
1452 y: int
1453 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1454 astuple(C)
1455 with self.assertRaisesRegex(TypeError, 'dataclass instance'):
1456 astuple(int)
1457
1458 def test_helper_astuple_copy_values(self):
1459 @dataclass
1460 class C:
1461 x: int
1462 y: List[int] = field(default_factory=list)
1463 initial = []
1464 c = C(1, initial)
1465 t = astuple(c)
1466 self.assertEqual(t[1], initial)
1467 self.assertIsNot(t[1], initial)
1468 c = C(1)
1469 t = astuple(c)
1470 t[1].append(1)
1471 self.assertEqual(c.y, [])
1472
1473 def test_helper_astuple_nested(self):
1474 @dataclass
1475 class UserId:
1476 token: int
1477 group: int
1478 @dataclass
1479 class User:
1480 name: str
1481 id: UserId
1482 u = User('Joe', UserId(123, 1))
1483 t = astuple(u)
1484 self.assertEqual(t, ('Joe', (123, 1)))
1485 self.assertIsNot(astuple(u), astuple(u))
1486 u.id.group = 2
1487 self.assertEqual(astuple(u), ('Joe', (123, 2)))
1488
1489 def test_helper_astuple_builtin_containers(self):
1490 @dataclass
1491 class User:
1492 name: str
1493 id: int
1494 @dataclass
1495 class GroupList:
1496 id: int
1497 users: List[User]
1498 @dataclass
1499 class GroupTuple:
1500 id: int
1501 users: Tuple[User, ...]
1502 @dataclass
1503 class GroupDict:
1504 id: int
1505 users: Dict[str, User]
1506 a = User('Alice', 1)
1507 b = User('Bob', 2)
1508 gl = GroupList(0, [a, b])
1509 gt = GroupTuple(0, (a, b))
1510 gd = GroupDict(0, {'first': a, 'second': b})
1511 self.assertEqual(astuple(gl), (0, [('Alice', 1), ('Bob', 2)]))
1512 self.assertEqual(astuple(gt), (0, (('Alice', 1), ('Bob', 2))))
1513 self.assertEqual(astuple(gd), (0, {'first': ('Alice', 1), 'second': ('Bob', 2)}))
1514
1515 def test_helper_astuple_builtin_containers(self):
1516 @dataclass
1517 class Child:
1518 d: object
1519
1520 @dataclass
1521 class Parent:
1522 child: Child
1523
1524 self.assertEqual(astuple(Parent(Child([1]))), (([1],),))
1525 self.assertEqual(astuple(Parent(Child({1: 2}))), (({1: 2},),))
1526
1527 def test_helper_astuple_factory(self):
1528 @dataclass
1529 class C:
1530 x: int
1531 y: int
1532 NT = namedtuple('NT', 'x y')
1533 def nt(lst):
1534 return NT(*lst)
1535 c = C(1, 2)
1536 t = astuple(c, tuple_factory=nt)
1537 self.assertEqual(t, NT(1, 2))
1538 self.assertIsNot(t, astuple(c, tuple_factory=nt))
1539 c.x = 42
1540 t = astuple(c, tuple_factory=nt)
1541 self.assertEqual(t, NT(42, 2))
1542 self.assertIs(type(t), NT)
1543
1544 def test_dynamic_class_creation(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001545 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001546 }
1547
1548 # Create the class.
1549 cls = type('C', (), cls_dict)
1550
1551 # Make it a dataclass.
1552 cls1 = dataclass(cls)
1553
1554 self.assertEqual(cls1, cls)
1555 self.assertEqual(asdict(cls(1, 2)), {'x': 1, 'y': 2})
1556
1557 def test_dynamic_class_creation_using_field(self):
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001558 cls_dict = {'__annotations__': {'x': int, 'y': int},
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001559 'y': field(default=5),
1560 }
1561
1562 # Create the class.
1563 cls = type('C', (), cls_dict)
1564
1565 # Make it a dataclass.
1566 cls1 = dataclass(cls)
1567
1568 self.assertEqual(cls1, cls)
1569 self.assertEqual(asdict(cls1(1)), {'x': 1, 'y': 5})
1570
1571 def test_init_in_order(self):
1572 @dataclass
1573 class C:
1574 a: int
1575 b: int = field()
1576 c: list = field(default_factory=list, init=False)
1577 d: list = field(default_factory=list)
1578 e: int = field(default=4, init=False)
1579 f: int = 4
1580
1581 calls = []
1582 def setattr(self, name, value):
1583 calls.append((name, value))
1584
1585 C.__setattr__ = setattr
1586 c = C(0, 1)
1587 self.assertEqual(('a', 0), calls[0])
1588 self.assertEqual(('b', 1), calls[1])
1589 self.assertEqual(('c', []), calls[2])
1590 self.assertEqual(('d', []), calls[3])
1591 self.assertNotIn(('e', 4), calls)
1592 self.assertEqual(('f', 4), calls[4])
1593
1594 def test_items_in_dicts(self):
1595 @dataclass
1596 class C:
1597 a: int
1598 b: list = field(default_factory=list, init=False)
1599 c: list = field(default_factory=list)
1600 d: int = field(default=4, init=False)
1601 e: int = 0
1602
1603 c = C(0)
1604 # Class dict
1605 self.assertNotIn('a', C.__dict__)
1606 self.assertNotIn('b', C.__dict__)
1607 self.assertNotIn('c', C.__dict__)
1608 self.assertIn('d', C.__dict__)
1609 self.assertEqual(C.d, 4)
1610 self.assertIn('e', C.__dict__)
1611 self.assertEqual(C.e, 0)
1612 # Instance dict
1613 self.assertIn('a', c.__dict__)
1614 self.assertEqual(c.a, 0)
1615 self.assertIn('b', c.__dict__)
1616 self.assertEqual(c.b, [])
1617 self.assertIn('c', c.__dict__)
1618 self.assertEqual(c.c, [])
1619 self.assertNotIn('d', c.__dict__)
1620 self.assertIn('e', c.__dict__)
1621 self.assertEqual(c.e, 0)
1622
1623 def test_alternate_classmethod_constructor(self):
1624 # Since __post_init__ can't take params, use a classmethod
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001625 # alternate constructor. This is mostly an example to show
1626 # how to use this technique.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001627 @dataclass
1628 class C:
1629 x: int
1630 @classmethod
1631 def from_file(cls, filename):
1632 # In a real example, create a new instance
1633 # and populate 'x' from contents of a file.
1634 value_in_file = 20
1635 return cls(value_in_file)
1636
1637 self.assertEqual(C.from_file('filename').x, 20)
1638
1639 def test_field_metadata_default(self):
1640 # Make sure the default metadata is read-only and of
1641 # zero length.
1642 @dataclass
1643 class C:
1644 i: int
1645
1646 self.assertFalse(fields(C)[0].metadata)
1647 self.assertEqual(len(fields(C)[0].metadata), 0)
1648 with self.assertRaisesRegex(TypeError,
1649 'does not support item assignment'):
1650 fields(C)[0].metadata['test'] = 3
1651
1652 def test_field_metadata_mapping(self):
1653 # Make sure only a mapping can be passed as metadata
1654 # zero length.
1655 with self.assertRaises(TypeError):
1656 @dataclass
1657 class C:
1658 i: int = field(metadata=0)
1659
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001660 # Make sure an empty dict works.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001661 @dataclass
1662 class C:
1663 i: int = field(metadata={})
1664 self.assertFalse(fields(C)[0].metadata)
1665 self.assertEqual(len(fields(C)[0].metadata), 0)
1666 with self.assertRaisesRegex(TypeError,
1667 'does not support item assignment'):
1668 fields(C)[0].metadata['test'] = 3
1669
1670 # Make sure a non-empty dict works.
1671 @dataclass
1672 class C:
1673 i: int = field(metadata={'test': 10, 'bar': '42', 3: 'three'})
1674 self.assertEqual(len(fields(C)[0].metadata), 3)
1675 self.assertEqual(fields(C)[0].metadata['test'], 10)
1676 self.assertEqual(fields(C)[0].metadata['bar'], '42')
1677 self.assertEqual(fields(C)[0].metadata[3], 'three')
1678 with self.assertRaises(KeyError):
1679 # Non-existent key.
1680 fields(C)[0].metadata['baz']
1681 with self.assertRaisesRegex(TypeError,
1682 'does not support item assignment'):
1683 fields(C)[0].metadata['test'] = 3
1684
1685 def test_field_metadata_custom_mapping(self):
1686 # Try a custom mapping.
1687 class SimpleNameSpace:
1688 def __init__(self, **kw):
1689 self.__dict__.update(kw)
1690
1691 def __getitem__(self, item):
1692 if item == 'xyzzy':
1693 return 'plugh'
1694 return getattr(self, item)
1695
1696 def __len__(self):
1697 return self.__dict__.__len__()
1698
1699 @dataclass
1700 class C:
1701 i: int = field(metadata=SimpleNameSpace(a=10))
1702
1703 self.assertEqual(len(fields(C)[0].metadata), 1)
1704 self.assertEqual(fields(C)[0].metadata['a'], 10)
1705 with self.assertRaises(AttributeError):
1706 fields(C)[0].metadata['b']
1707 # Make sure we're still talking to our custom mapping.
1708 self.assertEqual(fields(C)[0].metadata['xyzzy'], 'plugh')
1709
1710 def test_generic_dataclasses(self):
1711 T = TypeVar('T')
1712
1713 @dataclass
1714 class LabeledBox(Generic[T]):
1715 content: T
1716 label: str = '<unknown>'
1717
1718 box = LabeledBox(42)
1719 self.assertEqual(box.content, 42)
1720 self.assertEqual(box.label, '<unknown>')
1721
Miss Islington (bot)5fc6fc82018-03-25 18:00:43 -07001722 # Subscripting the resulting class should work, etc.
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001723 Alias = List[LabeledBox[int]]
1724
1725 def test_generic_extending(self):
1726 S = TypeVar('S')
1727 T = TypeVar('T')
1728
1729 @dataclass
1730 class Base(Generic[T, S]):
1731 x: T
1732 y: S
1733
1734 @dataclass
1735 class DataDerived(Base[int, T]):
1736 new_field: str
1737 Alias = DataDerived[str]
1738 c = Alias(0, 'test1', 'test2')
1739 self.assertEqual(astuple(c), (0, 'test1', 'test2'))
1740
1741 class NonDataDerived(Base[int, T]):
1742 def new_method(self):
1743 return self.y
1744 Alias = NonDataDerived[float]
1745 c = Alias(10, 1.0)
1746 self.assertEqual(c.new_method(), 1.0)
1747
Miss Islington (bot)d063ad82018-04-01 04:33:13 -07001748 def test_generic_dynamic(self):
1749 T = TypeVar('T')
1750
1751 @dataclass
1752 class Parent(Generic[T]):
1753 x: T
1754 Child = make_dataclass('Child', [('y', T), ('z', Optional[T], None)],
1755 bases=(Parent[int], Generic[T]), namespace={'other': 42})
1756 self.assertIs(Child[int](1, 2).z, None)
1757 self.assertEqual(Child[int](1, 2, 3).z, 3)
1758 self.assertEqual(Child[int](1, 2, 3).other, 42)
1759 # Check that type aliases work correctly.
1760 Alias = Child[T]
1761 self.assertEqual(Alias[int](1, 2).x, 1)
1762 # Check MRO resolution.
1763 self.assertEqual(Child.__mro__, (Child, Parent, Generic, object))
1764
Eric V. Smithf0db54a2017-12-04 16:58:55 -05001765 def test_dataclassses_pickleable(self):
1766 global P, Q, R
1767 @dataclass
1768 class P:
1769 x: int
1770 y: int = 0
1771 @dataclass
1772 class Q:
1773 x: int
1774 y: int = field(default=0, init=False)
1775 @dataclass
1776 class R:
1777 x: int
1778 y: List[int] = field(default_factory=list)
1779 q = Q(1)
1780 q.y = 2
1781 samples = [P(1), P(1, 2), Q(1), q, R(1), R(1, [2, 3, 4])]
1782 for sample in samples:
1783 for proto in range(pickle.HIGHEST_PROTOCOL + 1):
1784 with self.subTest(sample=sample, proto=proto):
1785 new_sample = pickle.loads(pickle.dumps(sample, proto))
1786 self.assertEqual(sample.x, new_sample.x)
1787 self.assertEqual(sample.y, new_sample.y)
1788 self.assertIsNot(sample, new_sample)
1789 new_sample.x = 42
1790 another_new_sample = pickle.loads(pickle.dumps(new_sample, proto))
1791 self.assertEqual(new_sample.x, another_new_sample.x)
1792 self.assertEqual(sample.y, another_new_sample.y)
1793
Eric V. Smithea8fc522018-01-27 19:07:40 -05001794
class TestFieldNoAnnotation(unittest.TestCase):
    # Assigning field() to a name that has no annotation in the class
    # being decorated must be rejected.

    def test_field_without_annotation(self):
        expected_message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected_message):
            @dataclass
            class C:
                f = field()

    def test_field_without_annotation_but_annotation_in_base(self):
        @dataclass
        class B:
            f: int

        expected_message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected_message):
            # Still an error: the annotation in the base class must
            # not be picked up.
            @dataclass
            class C(B):
                f = field()

    def test_field_without_annotation_but_annotation_in_base_not_dataclass(self):
        # As above, except that the base class is not a dataclass.
        class B:
            f: int

        expected_message = "'f' is a field but has no type annotation"
        with self.assertRaisesRegex(TypeError, expected_message):
            # Still an error: the annotation in the base class must
            # not be picked up.
            @dataclass
            class C(B):
                f = field()
1828
1829
class TestDocString(unittest.TestCase):
    # Checks the docstring that @dataclass generates when the class
    # does not define one itself.

    def assertDocStrEqual(self, a, b):
        # Because 3.6 and 3.7 differ in how inspect.signature work
        # (see bpo #32108), for the time being just compare them with
        # whitespace stripped.
        self.assertEqual(a.replace(' ', ''), b.replace(' ', ''))

    def test_existing_docstring_not_overridden(self):
        @dataclass
        class C:
            """Lorem ipsum"""
            x: int

        self.assertEqual(C.__doc__, "Lorem ipsum")

    def test_docstring_no_fields(self):
        @dataclass
        class C:
            pass

        self.assertDocStrEqual(C.__doc__, "C()")

    def test_docstring_one_field(self):
        @dataclass
        class C:
            x: int

        self.assertDocStrEqual(C.__doc__, "C(x:int)")

    def test_docstring_two_fields(self):
        @dataclass
        class C:
            x: int
            y: int

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int)")

    def test_docstring_three_fields(self):
        @dataclass
        class C:
            x: int
            y: int
            z: str

        self.assertDocStrEqual(C.__doc__, "C(x:int, y:int, z:str)")

    def test_docstring_one_field_with_default(self):
        @dataclass
        class C:
            x: int = 3

        self.assertDocStrEqual(C.__doc__, "C(x:int=3)")

    def test_docstring_one_field_with_default_none(self):
        @dataclass
        class C:
            x: Union[int, type(None)] = None

        # How a Union with None is spelled depends on the interpreter
        # version, so take the expected text from the annotation's own
        # repr, minus the 'typing.' module prefix, instead of
        # hard-coding one spelling such as "Union[int, NoneType]".
        union_text = repr(Union[int, type(None)]).replace('typing.', '')
        self.assertDocStrEqual(C.__doc__, f"C(x:{union_text}=None)")

    def test_docstring_list_field(self):
        @dataclass
        class C:
            x: List[int]

        self.assertDocStrEqual(C.__doc__, "C(x:List[int])")

    def test_docstring_list_field_with_default_factory(self):
        @dataclass
        class C:
            x: List[int] = field(default_factory=list)

        self.assertDocStrEqual(C.__doc__, "C(x:List[int]=<factory>)")

    def test_docstring_deque_field(self):
        @dataclass
        class C:
            x: deque

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque)")

    def test_docstring_deque_field_with_default_factory(self):
        @dataclass
        class C:
            x: deque = field(default_factory=deque)

        self.assertDocStrEqual(C.__doc__, "C(x:collections.deque=<factory>)")
1917
1918
class TestInit(unittest.TestCase):
    # Checks how the init= parameter interacts with __init__ methods
    # defined on the class itself or on a base class.

    def test_base_has_init(self):
        class B:
            def __init__(self):
                self.z = 100

        # Make sure that declaring this class doesn't raise an error.
        #  The issue is that we can't override __init__ in our class,
        #  but it should be okay to add __init__ to us if our base has
        #  an __init__.
        @dataclass
        class C(B):
            x: int = 0
        c = C(10)
        self.assertEqual(c.x, 10)
        # The generated __init__ does not call B.__init__, so z is unset.
        self.assertNotIn('z', vars(c))

        # Make sure that if we don't add an init, the base __init__
        #  gets called.
        @dataclass(init=False)
        class C(B):
            x: int = 10
        c = C()
        self.assertEqual(c.x, 10)
        self.assertEqual(c.z, 100)

    def test_no_init(self):
        # The decorator has to be applied with '@'.  A bare
        # `dataclass(init=False)` statement only builds a decorator and
        # throws it away, leaving C an ordinary class.

        # init=False and no __init__ in the class: none is generated,
        # and an instance reports the class-level default.
        @dataclass(init=False)
        class C:
            i: int = 0
        self.assertNotIn('__init__', vars(C))
        self.assertEqual(C().i, 0)

        # init=False and a user-written __init__: that one is used.
        @dataclass(init=False)
        class C:
            i: int = 2
            def __init__(self):
                self.i = 3
        self.assertEqual(C().i, 3)

    def test_overwriting_init(self):
        # If the class has __init__, use it no matter the value of
        #  init=.

        @dataclass
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(3).x, 6)

        @dataclass(init=True)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(4).x, 8)

        @dataclass(init=False)
        class C:
            x: int
            def __init__(self, x):
                self.x = 2 * x
        self.assertEqual(C(5).x, 10)
1983
1984
class TestRepr(unittest.TestCase):
    # Checks the generated __repr__ and how the repr= parameter
    # interacts with a __repr__ defined by the class.

    def test_repr(self):
        @dataclass
        class B:
            x: int

        @dataclass
        class C(B):
            y: int = 10

        # Inherited fields are listed before the class's own.
        instance = C(4)
        self.assertEqual(repr(instance), 'TestRepr.test_repr.<locals>.C(x=4, y=10)')

        # Redefining x in a subclass changes its default but keeps its
        # position.
        @dataclass
        class D(C):
            x: int = 20
        self.assertEqual(repr(D()), 'TestRepr.test_repr.<locals>.D(x=20, y=10)')

        # Dataclasses nested inside another class show the full
        # qualified name.
        @dataclass
        class C:
            @dataclass
            class D:
                i: int
            @dataclass
            class E:
                pass
        nested_with_field = C.D(0)
        nested_empty = C.E()
        self.assertEqual(repr(nested_with_field), 'TestRepr.test_repr.<locals>.C.D(i=0)')
        self.assertEqual(repr(nested_empty), 'TestRepr.test_repr.<locals>.C.E()')

    def test_no_repr(self):
        # repr=False and no __repr__ in the class: the default object
        # repr is used.
        @dataclass(repr=False)
        class C:
            x: int
        default_repr = repr(C(3))
        self.assertIn(f'{__name__}.TestRepr.test_no_repr.<locals>.C object at',
                      default_repr)

        # repr=False and a __repr__ in the class: that one is used.
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'C-class'
        self.assertEqual(repr(C(3)), 'C-class')

    def test_overwriting_repr(self):
        # A __repr__ written in the class always wins, whatever the
        # value of repr= is.

        # repr= left at its default
        @dataclass
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        # repr=True
        @dataclass(repr=True)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')

        # repr=False
        @dataclass(repr=False)
        class C:
            x: int
            def __repr__(self):
                return 'x'
        self.assertEqual(repr(C(0)), 'x')
2054
2055
class TestEq(unittest.TestCase):
    # Checks how the eq= parameter interacts with an __eq__ defined
    # by the class.

    def test_no_eq(self):
        # eq=False and no __eq__ in the class: two separate instances
        # compare unequal, an instance equals itself.
        @dataclass(eq=False)
        class C:
            x: int
        self.assertNotEqual(C(0), C(0))
        instance = C(3)
        self.assertEqual(instance, instance)

        # eq=False and an __eq__ in the class: that one is used.
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 10
        self.assertEqual(C(3), 10)

    def test_overwriting_eq(self):
        # An __eq__ written in the class always wins, whatever the
        # value of eq= is.

        # eq= left at its default
        @dataclass
        class C:
            x: int
            def __eq__(self, other):
                return other == 3
        instance = C(1)
        self.assertEqual(instance, 3)
        self.assertNotEqual(instance, 1)

        # eq=True
        @dataclass(eq=True)
        class C:
            x: int
            def __eq__(self, other):
                return other == 4
        instance = C(1)
        self.assertEqual(instance, 4)
        self.assertNotEqual(instance, 1)

        # eq=False
        @dataclass(eq=False)
        class C:
            x: int
            def __eq__(self, other):
                return other == 5
        instance = C(1)
        self.assertEqual(instance, 5)
        self.assertNotEqual(instance, 1)
2101
2102
class TestOrdering(unittest.TestCase):
    # Checks the order= parameter and its interaction with ordering
    # methods defined by the class.

    def test_functools_total_ordering(self):
        # functools.total_ordering must be usable on top of @dataclass.
        @total_ordering
        @dataclass
        class C:
            x: int
            def __lt__(self, other):
                # Deliberately "backward", so the assertions below can
                # only pass if this method is the one being called.
                return self.x >= other

        zero = C(0)
        self.assertLess(zero, -1)
        self.assertLessEqual(zero, -1)
        self.assertGreater(zero, 1)
        self.assertGreaterEqual(zero, 1)

    def test_no_order(self):
        # With order=False none of the ordering methods are generated.
        @dataclass(order=False)
        class C:
            x: int
        class_namespace = C.__dict__
        for dunder in ('__le__', '__lt__', '__ge__', '__gt__'):
            self.assertNotIn(dunder, class_namespace)

        # A user-written __lt__ is allowed with order=False ...
        @dataclass(order=False)
        class C:
            x: int
            def __lt__(self, other):
                return False
        # ... and the other three are still not generated.
        class_namespace = C.__dict__
        for dunder in ('__le__', '__ge__', '__gt__'):
            self.assertNotIn(dunder, class_namespace)

    def test_overwriting_order(self):
        # With order=True, a class that already defines any one of the
        # four ordering methods must be rejected.

        # __lt__
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __lt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __lt__(self):
                    pass

        # __le__
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __le__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __le__(self):
                    pass

        # __gt__
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __gt__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __gt__(self):
                    pass

        # __ge__
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __ge__'
                                    '.*using functools.total_ordering'):
            @dataclass(order=True)
            class C:
                x: int
                def __ge__(self):
                    pass
2178
# Tests for how @dataclass decides whether to add, suppress, or leave
# alone a class's __hash__, depending on unsafe_hash=, eq= and frozen=.
class TestHash(unittest.TestCase):
    def test_unsafe_hash(self):
        # With unsafe_hash=True the hash is that of the tuple of fields.
        @dataclass(unsafe_hash=True)
        class C:
            x: int
            y: str
        self.assertEqual(hash(C(1, 'foo')), hash((1, 'foo')))

    def test_hash_rules(self):
        def non_bool(value):
            # Map a flag to a non-bool object with the same truth value:
            # None stays None, truthy becomes (3,), falsy becomes 0.
            if value is None:
                return None
            if value:
                return (3,)
            return 0

        def test(case, unsafe_hash, eq, frozen, with_hash, result):
            # Build a dataclass with the given flags (optionally with a
            # user-defined __hash__) and check what ended up in the
            # class's __dict__ against the expected ``result`` code.
            with self.subTest(case=case, unsafe_hash=unsafe_hash, eq=eq,
                              frozen=frozen, with_hash=with_hash):
                if result != 'exception':
                    if with_hash:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0
                    else:
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            pass

                # See if the result matches what's expected.
                if result == 'fn':
                    # __hash__ contains the function we generated.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNotNone(C.__dict__['__hash__'])

                elif result == '':
                    # __hash__ is not present in our class.
                    if not with_hash:
                        self.assertNotIn('__hash__', C.__dict__)

                elif result == 'none':
                    # __hash__ is set to None.
                    self.assertIn('__hash__', C.__dict__)
                    self.assertIsNone(C.__dict__['__hash__'])

                elif result == 'exception':
                    # Creating the class should cause an exception.
                    # This only happens with with_hash==True.
                    # (assertTrue rather than a bare assert, so the
                    # check survives running under -O.)
                    self.assertTrue(with_hash)
                    with self.assertRaisesRegex(TypeError, 'Cannot overwrite attribute __hash__'):
                        @dataclass(unsafe_hash=unsafe_hash, eq=eq, frozen=frozen)
                        class C:
                            def __hash__(self):
                                return 0

                else:
                    self.fail(f'unknown result {result!r}')

        # There are 8 cases of:
        #  unsafe_hash=True/False
        #  eq=True/False
        #  frozen=True/False
        # And for each of these, a different result if
        #  __hash__ is defined or not.
        for case, (unsafe_hash, eq, frozen, res_no_defined_hash, res_defined_hash) in enumerate([
                (False, False, False, '',     ''),
                (False, False, True,  '',     ''),
                (False, True,  False, 'none', ''),
                (False, True,  True,  'fn',   ''),
                (True,  False, False, 'fn',   'exception'),
                (True,  False, True,  'fn',   'exception'),
                (True,  True,  False, 'fn',   'exception'),
                (True,  True,  True,  'fn',   'exception'),
                ], 1):
            test(case, unsafe_hash, eq, frozen, False, res_no_defined_hash)
            test(case, unsafe_hash, eq, frozen, True, res_defined_hash)

            # Test non-bool truth values, too.  This is just to
            # make sure the data-driven table in the decorator
            # handles non-bool values.
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), False, res_no_defined_hash)
            test(case, non_bool(unsafe_hash), non_bool(eq), non_bool(frozen), True, res_defined_hash)

    def test_eq_only(self):
        # If a class defines __eq__, __hash__ is automatically added
        # and set to None.  This is normal Python behavior, not
        # related to dataclasses.  Make sure we don't interfere with
        # that (see bpo-32546).

        @dataclass
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1))
        self.assertNotEqual(C(1), C(4))

        # And make sure things work in this case if we specify
        # unsafe_hash=True.
        @dataclass(unsafe_hash=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == other.i
        self.assertEqual(C(1), C(1.0))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

        # And check that the class's __eq__ is being used, despite
        # specifying eq=True.
        @dataclass(unsafe_hash=True, eq=True)
        class C:
            i: int
            def __eq__(self, other):
                return self.i == 3 and self.i == other.i
        self.assertEqual(C(3), C(3))
        self.assertNotEqual(C(1), C(1))
        self.assertEqual(hash(C(1)), hash(C(1.0)))

    def test_0_field_hash(self):
        # A dataclass with no fields hashes like the empty tuple.
        @dataclass(frozen=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

        @dataclass(unsafe_hash=True)
        class C:
            pass
        self.assertEqual(hash(C()), hash(()))

    def test_1_field_hash(self):
        # A dataclass with one field hashes like a 1-tuple of that field.
        @dataclass(frozen=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

        @dataclass(unsafe_hash=True)
        class C:
            x: int
        self.assertEqual(hash(C(4)), hash((4,)))
        self.assertEqual(hash(C(42)), hash((42,)))

    def test_hash_no_args(self):
        # Test dataclasses with no hash= argument.  This exists to
        # make sure that if the @dataclass parameter name is changed
        # or the non-default hashing behavior changes, the default
        # hashability keeps working the same way.

        class Base:
            def __hash__(self):
                return 301

        # If frozen or eq is None, then use the default value (do not
        # specify any value in the decorator).
        for frozen, eq, base, expected in [
                (None,  None,  object, 'unhashable'),
                (None,  None,  Base,   'unhashable'),
                (None,  False, object, 'object'),
                (None,  False, Base,   'base'),
                (None,  True,  object, 'unhashable'),
                (None,  True,  Base,   'unhashable'),
                (False, None,  object, 'unhashable'),
                (False, None,  Base,   'unhashable'),
                (False, False, object, 'object'),
                (False, False, Base,   'base'),
                (False, True,  object, 'unhashable'),
                (False, True,  Base,   'unhashable'),
                (True,  None,  object, 'tuple'),
                (True,  None,  Base,   'tuple'),
                (True,  False, object, 'object'),
                (True,  False, Base,   'base'),
                (True,  True,  object, 'tuple'),
                (True,  True,  Base,   'tuple'),
                ]:

            with self.subTest(frozen=frozen, eq=eq, base=base, expected=expected):
                # First, create the class.
                if frozen is None and eq is None:
                    @dataclass
                    class C(base):
                        i: int
                elif frozen is None:
                    @dataclass(eq=eq)
                    class C(base):
                        i: int
                elif eq is None:
                    @dataclass(frozen=frozen)
                    class C(base):
                        i: int
                else:
                    @dataclass(frozen=frozen, eq=eq)
                    class C(base):
                        i: int

                # Now, make sure it hashes as expected.
                if expected == 'unhashable':
                    c = C(10)
                    with self.assertRaisesRegex(TypeError, 'unhashable type'):
                        hash(c)

                elif expected == 'base':
                    self.assertEqual(hash(C(10)), 301)

                elif expected == 'object':
                    # I'm not sure what test to use here.  object's
                    # hash isn't based on id(), so calling hash()
                    # won't tell us much.  So, just check the
                    # function used is object's.
                    self.assertIs(C.__hash__, object.__hash__)

                elif expected == 'tuple':
                    self.assertEqual(hash(C(42)), hash((42,)))

                else:
                    self.fail(f'unknown value for expected={expected!r}')


# Tests for @dataclass(frozen=True): attribute assignment is rejected,
# and frozen/non-frozen dataclasses cannot be mixed through inheritance.
class TestFrozen(unittest.TestCase):
    def test_frozen(self):
        # Assigning to a field of a frozen instance raises and leaves
        # the value untouched.
        @dataclass(frozen=True)
        class C:
            i: int

        c = C(10)
        self.assertEqual(c.i, 10)
        with self.assertRaises(FrozenInstanceError):
            c.i = 5
        self.assertEqual(c.i, 10)

    def test_inherit(self):
        # Both inherited and newly added fields are frozen.
        @dataclass(frozen=True)
        class C:
            i: int

        @dataclass(frozen=True)
        class D(C):
            j: int

        d = D(0, 10)
        with self.assertRaises(FrozenInstanceError):
            d.i = 5
        with self.assertRaises(FrozenInstanceError):
            d.j = 6
        self.assertEqual(d.i, 0)
        self.assertEqual(d.j, 10)

    # Test both ways: with an intermediate normal (non-dataclass)
    # class and without an intermediate class.
    def test_inherit_nonfrozen_from_frozen(self):
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass(frozen=True)
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit non-frozen dataclass from a frozen one'):
                    @dataclass
                    class D(I):
                        pass

    def test_inherit_frozen_from_nonfrozen(self):
        # Mirror image of the test above: a frozen dataclass may not
        # derive from a non-frozen one, directly or via a plain class.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                @dataclass
                class C:
                    i: int

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                with self.assertRaisesRegex(TypeError,
                                            'cannot inherit frozen dataclass from a non-frozen one'):
                    @dataclass(frozen=True)
                    class D(I):
                        pass

    def test_inherit_from_normal_class(self):
        # A frozen dataclass may derive from ordinary (non-dataclass)
        # classes and is still frozen.
        for intermediate_class in [True, False]:
            with self.subTest(intermediate_class=intermediate_class):
                class C:
                    pass

                if intermediate_class:
                    class I(C): pass
                else:
                    I = C

                @dataclass(frozen=True)
                class D(I):
                    i: int

            d = D(10)
            with self.assertRaises(FrozenInstanceError):
                d.i = 5

    def test_non_frozen_normal_derived(self):
        # See bpo-32953.

        @dataclass(frozen=True)
        class D:
            x: int
            y: int = 10

        class S(D):
            pass

        s = S(3)
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        # A plain subclass instance accepts new, non-field attributes.
        s.cached = True

        # But can't change the frozen attributes.
        with self.assertRaises(FrozenInstanceError):
            s.x = 5
        with self.assertRaises(FrozenInstanceError):
            s.y = 5
        self.assertEqual(s.x, 3)
        self.assertEqual(s.y, 10)
        self.assertEqual(s.cached, True)

    def test_overwriting_frozen(self):
        # frozen uses __setattr__ and __delattr__.
        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __setattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __setattr__(self):
                    pass

        with self.assertRaisesRegex(TypeError,
                                    'Cannot overwrite attribute __delattr__'):
            @dataclass(frozen=True)
            class C:
                x: int
                def __delattr__(self):
                    pass

        # With frozen=False a user-defined __setattr__ is kept and is
        # what __init__ ends up using (hence the doubled value).
        @dataclass(frozen=False)
        class C:
            x: int
            def __setattr__(self, name, value):
                self.__dict__['x'] = value * 2
        self.assertEqual(C(10).x, 20)

    def test_frozen_hash(self):
        @dataclass(frozen=True)
        class C:
            x: Any

        # If x is immutable, we can compute the hash.  No exception is
        # raised.
        hash(C(3))

        # If x is mutable, computing the hash is an error.
        with self.assertRaisesRegex(TypeError, 'unhashable type'):
            hash(C({}))


# Tests for dataclasses that declare __slots__.
class TestSlots(unittest.TestCase):
    def test_simple(self):
        @dataclass
        class C:
            __slots__ = ('x',)
            x: Any

        # There was a bug where a variable in a slot was assumed to
        # also have a default value (of type
        # types.MemberDescriptorType).
        with self.assertRaisesRegex(TypeError,
                                    r"__init__\(\) missing 1 required positional argument: 'x'"):
            C()

        # We can create an instance, and assign to x.
        c = C(10)
        self.assertEqual(c.x, 10)
        c.x = 5
        self.assertEqual(c.x, 5)

        # We can't assign to anything else.
        with self.assertRaisesRegex(AttributeError, "'C' object has no attribute 'y'"):
            c.y = 5

    def test_derived_added_field(self):
        # See bpo-33100.
        @dataclass
        class Base:
            __slots__ = ('x',)
            x: Any

        @dataclass
        class Derived(Base):
            x: int
            y: int

        d = Derived(1, 2)
        self.assertEqual((d.x, d.y), (1, 2))

        # We can add a new field to the derived instance.
        d.z = 10
        self.assertEqual(d.z, 10)

# Tests that __set_name__ is invoked correctly on field default values.
class TestDescriptors(unittest.TestCase):
    def test_set_name(self):
        # See bpo-33141.

        # Create a descriptor.  It records the name it was bound to
        # (with an 'x' suffix so the recorded value is recognizable).
        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'
            def __get__(self, instance, owner):
                if instance is not None:
                    return 1
                return self

        # This is the case of just normal descriptor behavior, no
        # dataclass code is involved in initializing the descriptor.
        @dataclass
        class C:
            c: int=D()
        self.assertEqual(C.c.name, 'cx')

        # Now test with a default value and init=False, which is the
        # only time this is really meaningful.  If not using
        # init=False, then the descriptor will be overwritten, anyway.
        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')
        self.assertEqual(C().c, 1)

    def test_non_descriptor(self):
        # PEP 487 says __set_name__ should work on non-descriptors.
        # Create a non-descriptor: it defines only __set_name__.

        class D:
            def __set_name__(self, owner, name):
                self.name = name + 'x'

        @dataclass
        class C:
            c: int=field(default=D(), init=False)
        self.assertEqual(C.c.name, 'cx')

    def test_lookup_on_instance(self):
        # See bpo-33175.
        class D:
            pass

        d = D()
        # Create an attribute on the instance, not type.
        d.__set_name__ = Mock()

        # Make sure d.__set_name__ is not called.
        @dataclass
        class C:
            i: int=field(default=d, init=False)

        self.assertEqual(d.__set_name__.call_count, 0)

    def test_lookup_on_class(self):
        # See bpo-33175.
        class D:
            pass
        D.__set_name__ = Mock()

        # Make sure D.__set_name__ is called.
        @dataclass
        class C:
            i: int=field(default=D(), init=False)

        self.assertEqual(D.__set_name__.call_count, 1)


# Tests for how string annotations are recognized (or not) as ClassVar
# and InitVar.
class TestStringAnnotations(unittest.TestCase):
    def test_classvar(self):
        # Some expressions recognized as ClassVar really aren't.  But
        # if you're using string annotations, it's not an exact
        # science.
        # These tests assume that both "import typing" and "from
        # typing import *" have been run in this file.
        # NOTE: every entry needs its own trailing comma; adjacent
        # string literals would otherwise be silently concatenated
        # into a single test case.
        for typestr in ('ClassVar[int]',
                        'ClassVar [int]',
                        ' ClassVar [int]',
                        'ClassVar',
                        ' ClassVar ',
                        'typing.ClassVar[int]',
                        'typing.ClassVar[str]',
                        ' typing.ClassVar[str]',
                        'typing .ClassVar[str]',
                        'typing. ClassVar[str]',
                        'typing.ClassVar [str]',
                        'typing.ClassVar [ str]',

                        # Not syntactically valid, but these will
                        # be treated as ClassVars.
                        'typing.ClassVar.[int]',
                        'typing.ClassVar+',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is a ClassVar, so C() takes no args.
                C()

                # And it won't appear in the class's dict because it doesn't
                # have a default.
                self.assertNotIn('x', C.__dict__)

    def test_isnt_classvar(self):
        # Strings that merely resemble ClassVar must be ordinary fields.
        for typestr in ('CV',
                        't.ClassVar',
                        't.ClassVar[int]',
                        'typing..ClassVar[int]',
                        'Classvar',
                        'Classvar[int]',
                        'typing.ClassVarx[int]',
                        'typong.ClassVar[int]',
                        'dataclasses.ClassVar[int]',
                        'typingxClassVar[str]',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is not a ClassVar, so C() takes one arg.
                self.assertEqual(C(10).x, 10)

    def test_initvar(self):
        # These tests assume that both "import dataclasses" and "from
        # dataclasses import *" have been run in this file.
        # NOTE: every entry needs its own trailing comma; adjacent
        # string literals would otherwise be silently concatenated
        # into a single test case.
        for typestr in ('InitVar[int]',
                        'InitVar [int]',
                        ' InitVar [int]',
                        'InitVar',
                        ' InitVar ',
                        'dataclasses.InitVar[int]',
                        'dataclasses.InitVar[str]',
                        ' dataclasses.InitVar[str]',
                        'dataclasses .InitVar[str]',
                        'dataclasses. InitVar[str]',
                        'dataclasses.InitVar [str]',
                        'dataclasses.InitVar [ str]',

                        # Not syntactically valid, but these will
                        # be treated as InitVars.
                        'dataclasses.InitVar.[int]',
                        'dataclasses.InitVar+',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is an InitVar, so doesn't create a member.
                with self.assertRaisesRegex(AttributeError,
                                            "object has no attribute 'x'"):
                    C(1).x

    def test_isnt_initvar(self):
        # Strings that merely resemble InitVar must be ordinary fields.
        for typestr in ('IV',
                        'dc.InitVar',
                        'xdataclasses.xInitVar',
                        'typing.xInitVar[int]',
                        ):
            with self.subTest(typestr=typestr):
                @dataclass
                class C:
                    x: typestr

                # x is not an InitVar, so there will be a member x.
                self.assertEqual(C(10).x, 10)

    def test_classvar_module_level_import(self):
        # The fixture modules live in the "test" package, which is not
        # installed everywhere.  Skip (visibly) rather than error out
        # when they cannot be imported.
        try:
            from test import dataclass_module_1
            from test import dataclass_module_1_str
            from test import dataclass_module_2
            from test import dataclass_module_2_str
        except ImportError as exc:
            self.skipTest(f'dataclass fixture modules are unavailable: {exc}')

        for m in (dataclass_module_1, dataclass_module_1_str,
                  dataclass_module_2, dataclass_module_2_str,
                  ):
            with self.subTest(m=m):
                # There's a difference in how the ClassVars are
                # interpreted when using string annotations or
                # not.  See the imported modules for details.
                if m.USING_STRINGS:
                    c = m.CV(10)
                else:
                    c = m.CV()
                self.assertEqual(c.cv0, 20)


                # There's a difference in how the InitVars are
                # interpreted when using string annotations or
                # not.  See the imported modules for details.
                c = m.IV(0, 1, 2, 3, 4)

                for field_name in ('iv0', 'iv1', 'iv2', 'iv3'):
                    with self.subTest(field_name=field_name):
                        with self.assertRaisesRegex(AttributeError, f"object has no attribute '{field_name}'"):
                            # Since field_name is an InitVar, it's
                            # not an instance field.
                            getattr(c, field_name)

                if m.USING_STRINGS:
                    # iv4 is interpreted as a normal field.
                    self.assertIn('not_iv4', c.__dict__)
                    self.assertEqual(c.not_iv4, 4)
                else:
                    # iv4 is interpreted as an InitVar, so it
                    # won't exist on the instance.
                    self.assertNotIn('not_iv4', c.__dict__)


# Tests for dataclasses.make_dataclass().
class TestMakeDataclass(unittest.TestCase):
    def test_simple(self):
        # Fields may be given as (name, type) or (name, type, Field),
        # and namespace= supplies extra class attributes.
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace={'add_one': lambda self: self.x + 1})
        c = C(10)
        self.assertEqual((c.x, c.y), (10, 5))
        self.assertEqual(c.add_one(), 11)


    def test_no_mutate_namespace(self):
        # Make sure a provided namespace isn't mutated.
        ns = {}
        C = make_dataclass('C',
                           [('x', int),
                            ('y', int, field(default=5))],
                           namespace=ns)
        self.assertEqual(ns, {})

    def test_base(self):
        # bases= sets the base classes of the generated class.
        class Base1:
            pass
        class Base2:
            pass
        C = make_dataclass('C',
                           [('x', int)],
                           bases=(Base1, Base2))
        c = C(2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

    def test_base_dataclass(self):
        # When a base is itself a dataclass, its fields come first in
        # the generated __init__.
        @dataclass
        class Base1:
            x: int
        class Base2:
            pass
        C = make_dataclass('C',
                           [('y', int)],
                           bases=(Base1, Base2))
        with self.assertRaisesRegex(TypeError, 'required positional'):
            c = C(2)
        c = C(1, 2)
        self.assertIsInstance(c, C)
        self.assertIsInstance(c, Base1)
        self.assertIsInstance(c, Base2)

        self.assertEqual((c.x, c.y), (1, 2))

    def test_init_var(self):
        # InitVar pseudo-fields are passed to __post_init__ and are
        # not stored on the instance.
        def post_init(self, y):
            self.x *= y

        C = make_dataclass('C',
                           [('x', int),
                            ('y', InitVar[int]),
                            ],
                           namespace={'__post_init__': post_init},
                           )
        c = C(2, 3)
        self.assertEqual(vars(c), {'x': 6})
        self.assertEqual(len(fields(c)), 1)

    def test_class_var(self):
        # ClassVar entries become class attributes, not fields.
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ])
        c = C(1)
        self.assertEqual(vars(c), {'x': 1})
        self.assertEqual(len(fields(c)), 1)
        self.assertEqual(C.y, 10)
        self.assertEqual(C.z, 20)

    def test_other_params(self):
        C = make_dataclass('C',
                           [('x', int),
                            ('y', ClassVar[int], 10),
                            ('z', ClassVar[int], field(default=20)),
                            ],
                           init=False)
        # Make sure we have a repr, but no init.
        self.assertNotIn('__init__', vars(C))
        self.assertIn('__repr__', vars(C))

        # Make sure random other params don't work.
        with self.assertRaisesRegex(TypeError, 'unexpected keyword argument'):
            C = make_dataclass('C',
                               [],
                               xxinit=False)

    def test_no_types(self):
        # A bare field name gets the string annotation 'typing.Any'.
        C = make_dataclass('Point', ['x', 'y', 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': 'typing.Any',
                                             'z': 'typing.Any'})

        C = make_dataclass('Point', ['x', ('y', int), 'z'])
        c = C(1, 2, 3)
        self.assertEqual(vars(c), {'x': 1, 'y': 2, 'z': 3})
        self.assertEqual(C.__annotations__, {'x': 'typing.Any',
                                             'y': int,
                                             'z': 'typing.Any'})

    def test_invalid_type_specification(self):
        # Field specs of the wrong length are rejected.
        for bad_field in [(),
                          (1, 2, 3, 4),
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'Invalid field: '):
                    make_dataclass('C', ['a', bad_field])

        # And test for things with no len().
        for bad_field in [float,
                          lambda x:x,
                          ]:
            with self.subTest(bad_field=bad_field):
                with self.assertRaisesRegex(TypeError, r'has no len\(\)'):
                    make_dataclass('C', ['a', bad_field])

    # In the three tests below the loop variable is called ``name``
    # rather than ``field`` so that it does not shadow the ``field``
    # function imported from dataclasses.  The subTest label stays
    # ``field=`` to keep the reported parameters unchanged.

    def test_duplicate_field_names(self):
        for name in ['a', 'ab']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'Field name duplicated'):
                    make_dataclass('C', [name, 'a', name])

    def test_keyword_field_names(self):
        for name in ['for', 'async', 'await', 'as']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must not be keywords'):
                    make_dataclass('C', [name, 'a'])

    def test_non_identifier_field_names(self):
        # The pattern accepts both the misspelt "identifers" and the
        # corrected "identifiers", so the test does not hinge on a
        # typo in the library's error message.
        for name in ['()', 'x,y', '*', '2@3', '', 'little johnny tables']:
            with self.subTest(field=name):
                with self.assertRaisesRegex(TypeError, 'must be valid identifi?ers'):
                    make_dataclass('C', ['a', name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifi?ers'):
                    make_dataclass('C', [name])
                with self.assertRaisesRegex(TypeError, 'must be valid identifi?ers'):
                    make_dataclass('C', [name, 'a'])

    def test_underscore_field_names(self):
        # Unlike namedtuple, it's okay if dataclass field names have
        # an underscore.
        make_dataclass('C', ['_', '_a', 'a_a', 'a_'])

    def test_funny_class_names_names(self):
        # No reason to prevent weird class names, since
        # types.new_class allows them.
        for classname in ['()', 'x,y', '*', '2@3', '']:
            with self.subTest(classname=classname):
                C = make_dataclass(classname, ['a', 'b'])
                self.assertEqual(C.__name__, classname)

# Tests for dataclasses.replace().
class TestReplace(unittest.TestCase):
    def test(self):
        # replace() returns a new instance with the given fields changed.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual(c1.x, 3)
        self.assertEqual(c1.y, 2)

    def test_frozen(self):
        @dataclass(frozen=True)
        class C:
            x: int
            y: int
            z: int = field(init=False, default=10)
            t: int = field(init=False, default=100)

        c = C(1, 2)
        c1 = replace(c, x=3)
        self.assertEqual((c.x, c.y, c.z, c.t), (1, 2, 10, 100))
        self.assertEqual((c1.x, c1.y, c1.z, c1.t), (3, 2, 10, 100))


        # Fields declared with init=False can't be replaced.  Each
        # call gets its own context manager: anything after the first
        # raising statement inside a "with assertRaises" block would
        # never execute.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=3, z=20, t=50)
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, z=20)

        # Make sure the result is still frozen.
        with self.assertRaisesRegex(FrozenInstanceError, "cannot assign to field 'x'"):
            c1.x = 3

        # Make sure we can't replace an attribute that doesn't exist,
        # if we're also replacing one that does exist.  Test this
        # here, because setting attributes on frozen instances is
        # handled slightly differently from non-frozen ones.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'a'"):
            c1 = replace(c, x=20, a=5)

    def test_invalid_field_name(self):
        # Replacing a name that isn't a field is a TypeError.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        c = C(1, 2)
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an unexpected "
                                               "keyword argument 'z'"):
            c1 = replace(c, z=3)

    def test_invalid_object(self):
        # replace() only accepts dataclass instances: not the class
        # itself, and not arbitrary objects.
        @dataclass(frozen=True)
        class C:
            x: int
            y: int

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(C, x=3)

        with self.assertRaisesRegex(TypeError, 'dataclass instance'):
            replace(0, x=3)

    def test_no_init(self):
        @dataclass
        class C:
            x: int
            y: int = field(init=False, default=10)

        c = C(1)
        c.y = 20

        # Make sure y gets the default value.
        c1 = replace(c, x=5)
        self.assertEqual((c1.x, c1.y), (5, 10))

        # Trying to replace y is an error.
        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, x=2, y=30)

        with self.assertRaisesRegex(ValueError, 'init=False'):
            replace(c, y=30)

    def test_classvar(self):
        @dataclass
        class C:
            x: int
            y: ClassVar[int] = 1000

        c = C(1)
        d = C(2)

        self.assertIs(c.y, d.y)
        self.assertEqual(c.y, 1000)

        # Trying to replace y is an error: can't replace ClassVars.
        with self.assertRaisesRegex(TypeError, r"__init__\(\) got an "
                                               "unexpected keyword argument 'y'"):
            replace(c, y=30)

        replace(c, x=5)

    def test_initvar_is_specified(self):
        @dataclass
        class C:
            x: int
            y: InitVar[int]

            def __post_init__(self, y):
                self.x *= y

        c = C(1, 10)
        self.assertEqual(c.x, 10)
        # The parentheses are escaped so the pattern matches the
        # literal text "replace()" rather than an empty regex group.
        with self.assertRaisesRegex(ValueError, r"InitVar 'y' must be "
                                                r"specified with replace\(\)"):
            replace(c, x=3)
        c = replace(c, x=3, y=5)
        self.assertEqual(c.x, 15)

    ## def test_initvar(self):
    ##     @dataclass
    ##     class C:
    ##         x: int
    ##         y: InitVar[int]

    ##     c = C(1, 10)
    ##     d = C(2, 20)

    ##     # In our case, replacing an InitVar is a no-op
    ##     self.assertEqual(c, replace(c, y=5))

    ##     replace(c, x=5)


# Run the tests in this module when it is executed as a script.
if __name__ == '__main__':
    unittest.main()