blob: efd7319a23ae0268b2d037aa76174adc343474bc [file] [log] [blame]
Yury Selivanovf23746a2018-01-22 19:11:18 -05001import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9
10try:
11 from _testcapi import hamt
12except ImportError:
13 hamt = None
14
15
def isolated_context(func):
    """Decorator: run every call of *func* inside a brand-new ``Context``.

    Needed to make reftracking test mode work.  Because each call executes
    via ``Context().run(...)``, ContextVar values the decorated function sets
    stay in that throwaway context instead of the caller's one.
    """
    @functools.wraps(func)
    def _run_in_new_context(*args, **kwargs):
        return contextvars.Context().run(func, *args, **kwargs)

    return _run_in_new_context
23
24
class ContextTest(unittest.TestCase):
    """Tests for the public contextvars API: ContextVar, Context and Token."""

    def test_context_var_new_1(self):
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        var = contextvars.ContextVar('aaa')
        self.assertEqual(var.name, 'aaa')

        # The name attribute is read-only.
        with self.assertRaises(AttributeError):
            var.name = 'bbb'

        # A variable does not hash like its name.
        self.assertNotEqual(hash(var), hash('aaa'))

    def test_context_var_new_2(self):
        # Subscription is accepted and (per this assertion) evaluates to None.
        self.assertIsNone(contextvars.ContextVar[int])

    @isolated_context
    def test_context_var_repr_1(self):
        plain = contextvars.ContextVar('a')
        self.assertIn('a', repr(plain))

        with_default = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(with_default))

        # A default that refers back to the variable must be shown as '...'
        # rather than recursing.
        cyclic = []
        var = contextvars.ContextVar('a', default=cyclic)
        cyclic.append(var)
        self.assertIn('...', repr(var))
        self.assertIn('...', repr(cyclic))

        token = var.set(1)
        self.assertIn(repr(var), repr(token))
        self.assertNotIn(' used ', repr(token))
        var.reset(token)
        self.assertIn(' used ', repr(token))

    def test_context_subclassing_1(self):
        # Potentially we might want ContextVars to be subclassable.
        for base in (contextvars.ContextVar,
                     contextvars.Context,
                     contextvars.Token):
            with self.assertRaisesRegex(TypeError,
                                        'not an acceptable base type'):
                class Sub(base):
                    pass

    def test_context_new_1(self):
        rejected_calls = [
            ((1,), {}),
            ((1,), {'a': 1}),
            ((), {'a': 1}),
        ]
        for call_args, call_kwargs in rejected_calls:
            with self.assertRaisesRegex(TypeError, 'any arguments'):
                contextvars.Context(*call_args, **call_kwargs)
        # An empty ** expansion counts as "no arguments".
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        ctx = contextvars.Context()

        probes = (
            lambda: ctx[1],
            lambda: 1 in ctx,
            lambda: ctx.get(1),
        )
        for probe in probes:
            with self.assertRaisesRegex(TypeError,
                                        'ContextVar key was expected'):
                probe()

    def test_context_get_context_1(self):
        self.assertIsInstance(contextvars.copy_context(), contextvars.Context)

    def test_context_run_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    def test_context_run_2(self):
        ctx = contextvars.Context()

        def echo(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        # functools.partial is included because partial doesn't support
        # FASTCALL, presumably so that a second calling path is covered.
        for fn in (echo, functools.partial(echo)):
            self.assertEqual(ctx.run(fn), (('bar',), {'spam': 'foo'}))
            self.assertEqual(ctx.run(fn, 1), ((1, 'bar'), {'spam': 'foo'}))
            self.assertEqual(ctx.run(fn, a=2),
                             (('bar',), {'a': 2, 'spam': 'foo'}))
            self.assertEqual(ctx.run(fn, 11, a=2),
                             ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            # The callee mutates its kwargs; the caller's dict must not change.
            no_kwargs = {}
            self.assertEqual(ctx.run(fn, 11, **no_kwargs),
                             ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(no_kwargs, {})

    def test_context_run_3(self):
        ctx = contextvars.Context()

        def explode(*args, **kwargs):
            1 / 0

        call_shapes = (
            ((), {}),
            ((1, 2), {}),
            ((1, 2), {'a': 123}),
        )
        for call_args, call_kwargs in call_shapes:
            with self.assertRaises(ZeroDivisionError):
                ctx.run(explode, *call_args, **call_kwargs)

    @isolated_context
    def test_context_run_4(self):
        outer = contextvars.Context()
        inner = contextvars.Context()
        var = contextvars.ContextVar('var')

        def in_inner():
            # The value set while running `outer` is not visible here.
            self.assertIsNone(var.get(None))

        def in_outer():
            self.assertIsNone(var.get(None))
            var.set('spam')
            inner.run(in_inner)
            # Running another context must not disturb this one.
            self.assertEqual(var.get(None), 'spam')

            current = contextvars.copy_context()
            self.assertEqual(len(current), 1)
            self.assertEqual(current[var], 'spam')
            return current

        returned = outer.run(in_outer)
        self.assertEqual(outer, returned)
        self.assertEqual(returned[var], 'spam')
        self.assertIn(var, returned)

    def test_context_run_5(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def set_then_fail():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(set_then_fail)

        # The change made inside the failed run does not leak out.
        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('a', default=0)

        def body():
            # The default is returned by get() but is not stored in ctx.
            self.assertEqual(var.get(), 0)
            self.assertIsNone(ctx.get(var))

            var.set(42)
            self.assertEqual(var.get(), 42)
            self.assertEqual(ctx.get(var), 42)

        ctx.run(body)

    def test_context_run_7(self):
        ctx = contextvars.Context()

        def reenter():
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(reenter)

        ctx.run(reenter)

    @isolated_context
    def test_context_getset_1(self):
        var = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            var.get()

        self.assertIsNone(var.get(None))

        first = var.set(42)
        self.assertEqual(var.get(), 42)
        self.assertEqual(var.get(None), 42)
        # The very first set() has no previous value to restore.
        self.assertIs(first.old_value, first.MISSING)
        self.assertIs(first.old_value, contextvars.Token.MISSING)
        self.assertIs(first.var, var)

        second = var.set('spam')
        self.assertEqual(var.get(), 'spam')
        self.assertEqual(var.get(None), 'spam')
        self.assertEqual(second.old_value, 42)
        var.reset(second)

        self.assertEqual(var.get(), 42)
        self.assertEqual(var.get(None), 42)

        # A token may only be used once.
        var.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            var.reset(second)
        self.assertEqual(var.get(), 'spam2')

        snapshot = contextvars.copy_context()
        self.assertIn(var, snapshot)

        var.reset(first)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            var.reset(first)
        self.assertIsNone(var.get(None))

        # The snapshot taken earlier is unaffected by the reset.
        self.assertIn(var, snapshot)
        self.assertEqual(snapshot[var], 'spam2')
        self.assertEqual(snapshot.get(var, 'aa'), 'spam2')
        self.assertEqual(len(snapshot), 1)
        self.assertEqual(list(snapshot.items()), [(var, 'spam2')])
        self.assertEqual(list(snapshot.values()), ['spam2'])
        self.assertEqual(list(snapshot.keys()), [var])
        self.assertEqual(list(snapshot), [var])

        after_reset = contextvars.copy_context()
        self.assertNotIn(var, after_reset)
        with self.assertRaises(KeyError):
            after_reset[var]
        self.assertEqual(after_reset.get(var, 'aa'), 'aa')
        self.assertEqual(len(after_reset), 0)
        self.assertEqual(list(after_reset), [])

    @isolated_context
    def test_context_getset_2(self):
        owner = contextvars.ContextVar('v1')
        stranger = contextvars.ContextVar('v2')

        token = owner.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            stranger.reset(token)

    @isolated_context
    def test_context_getset_3(self):
        var = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def body():
            self.assertEqual(var.get(), 42)
            with self.assertRaises(KeyError):
                ctx[var]
            self.assertIsNone(ctx.get(var))
            self.assertEqual(ctx.get(var, 'spam'), 'spam')
            self.assertNotIn(var, ctx)
            self.assertEqual(list(ctx.keys()), [])

            token = var.set(1)
            self.assertEqual(list(ctx.keys()), [var])
            self.assertEqual(ctx[var], 1)

            var.reset(token)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[var]

        ctx.run(body)

    @isolated_context
    def test_context_getset_4(self):
        var = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        token = ctx.run(var.set, 1)

        with self.assertRaisesRegex(ValueError, 'different Context'):
            var.reset(token)

    @isolated_context
    def test_context_getset_5(self):
        var = contextvars.ContextVar('c', default=42)
        var.set([])

        def mutate_in_copy():
            var.set([])
            var.get().append(42)
            self.assertEqual(var.get(), [42])

        contextvars.copy_context().run(mutate_in_copy)
        self.assertEqual(var.get(), [])

    def test_context_copy_1(self):
        ctx1 = contextvars.Context()
        var = contextvars.ContextVar('c', default=42)

        def in_ctx2():
            self.assertEqual(var.get(), 10)
            var.set(30)
            self.assertEqual(var.get(), 30)

        def in_ctx1():
            var.set(10)

            # The copy starts out with the values current at copy time...
            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[var], 10)

            # ...and evolves independently from the original afterwards.
            var.set(20)
            self.assertEqual(ctx1[var], 20)
            self.assertEqual(ctx2[var], 10)

            ctx2.run(in_ctx2)
            self.assertEqual(ctx1[var], 20)
            self.assertEqual(ctx2[var], 30)

        ctx1.run(in_ctx1)

    @isolated_context
    def test_context_threads_1(self):
        cvar = contextvars.ContextVar('cvar')

        def worker(num):
            # Despite interleaving with other threads, each worker reads
            # back the value it just set.
            for offset in range(10):
                cvar.set(num + offset)
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + offset)
            return num

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
            results = list(pool.map(worker, range(10)))
        self.assertEqual(results, list(range(10)))
363
364
365# HAMT Tests
366
367
class HashKey:
    """Test key whose hash value is chosen explicitly by the caller.

    Choosing hashes lets tests force HAMT collisions and specific node
    layouts.  Equality compares ``(name, hash)``.  Hashing or comparing can
    be made to raise, either globally via the class-level ``_crasher`` (set
    by ``HaskKeyCrasher``) or against one particular peer via
    ``error_on_eq_to``.
    """

    # Active HaskKeyCrasher, or None when no failures are being injected.
    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        assert hash != -1
        self.name = name
        self.hash = hash
        # Comparing with this exact object (in either direction) raises.
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        crasher = self._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        crasher = self._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError

        # `other` is a HashKey here (never None), so an identity test alone
        # is enough to detect the "poisoned" pairing in either direction.
        for left, right in ((self, other), (other, self)):
            if left.error_on_eq_to is right:
                raise ValueError(f'cannot compare {left!r} to {right!r}')

        return (self.name, self.hash) == (other.name, other.hash)
399
400
class KeyStr(str):
    """``str`` subclass whose hashing/equality fail while a crasher is active.

    It otherwise behaves exactly like ``str``.  It reads the same
    ``HashKey._crasher`` switch that ``HaskKeyCrasher`` toggles.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return super().__eq__(other)
411
412
class HaskKeyCrasher:
    """Context manager that makes HashKey/KeyStr hashing or equality raise.

    While the ``with`` block is active, the instance is installed as
    ``HashKey._crasher``.  ``error_on_hash`` makes ``__hash__`` raise
    HashingError and ``error_on_eq`` makes ``__eq__`` raise EqError.
    Crashers cannot be nested.  (The class name's spelling is kept as-is
    because tests refer to it.)
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_hash = error_on_hash
        self.error_on_eq = error_on_eq

    def __enter__(self):
        if HashKey._crasher is None:
            HashKey._crasher = self
            return
        raise RuntimeError('cannot nest crashers')

    def __exit__(self, *exc):
        # Always uninstall; returning None lets any exception propagate.
        HashKey._crasher = None
425
426
class HashingError(Exception):
    """Raised by test keys when hashing is forced to fail."""
429
430
class EqError(Exception):
    """Raised by test keys when equality comparison is forced to fail."""
433
434
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the immutable HAMT mapping exposed by ``_testcapi.hamt()``.

    Several tests pick explicit hashes via HashKey.  The tree dumps in their
    comments show the node layout those hashes are meant to produce, so that
    bitmap, collision and array nodes (and transitions between them) get
    exercised.
    """

    def test_hashkey_helper_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        # Re-setting a key to the identical value returns the same mapping.
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        # An equal-but-distinct value produces a new mapping.
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        # All three keys share hash 10.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertIsNone(h.get(k1))
        self.assertIsNone(h.get(k2))

        self.assertEqual(h2.get(k1), 'a')
        self.assertIsNone(h2.get(k2))

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertIsNone(h4.get(k3))
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_stress(self):
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            # Phase 1: insert every key, mirroring each step in a dict and
            # periodically checking that failing __hash__/__eq__ surface.
            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            # Phase 2: delete in random order, snapshotting the mapping (and
            # a copy of the dict) half way through.
            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # Phase 3: the half-way snapshot must be unaffected by the later
            # deletions.  Drain it and compare it against its own dict copy.
            # h and d are already empty here, so comparing them would prove
            # nothing.
            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        # Deleting a missing key returns the very same mapping.
        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        # Build a reference cycle through the mapping; the collector must
        # be able to reclaim it.
        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        A = HashKey(100, 'A')

        # The mapping stores itself as a value, and a live items() iterator
        # also references it.
        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        ref = weakref.ref(h)
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1066
1067
if __name__ == "__main__":
    # Only when executed as a script (not when imported by a test runner).
    unittest.main()