blob: b9e991a400092915695ef23e971bf0a5c187f4f6 [file] [log] [blame]
Yury Selivanovf23746a2018-01-22 19:11:18 -05001import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9
# The HAMT tests need the private ``hamt()`` constructor from _testcapi.
# If the module or the function is unavailable, bind ``hamt`` to None so
# that HamtTest is skipped (see its ``skipIf``) instead of this module
# failing to import.
try:
    from _testcapi import hamt
except ImportError:
    hamt = None
14
15
def isolated_context(func):
    """Decorate *func* so every call runs inside a brand-new, empty Context.

    Needed to make reftracking test mode work: context variables set by the
    decorated function are written into a throw-away ``contextvars.Context``
    rather than the caller's current context (presumably so repeated runs
    under reference-leak tracking start from the same state -- TODO confirm).

    All positional and keyword arguments are forwarded to *func*, and its
    return value is returned unchanged.
    """
    @functools.wraps(func)
    def run_in_fresh_context(*args, **kwargs):
        return contextvars.Context().run(func, *args, **kwargs)
    return run_in_fresh_context
23
24
class ContextTest(unittest.TestCase):
    """Tests for ``contextvars.ContextVar``, ``Context`` and ``Token``.

    Methods decorated with ``isolated_context`` run inside a fresh, empty
    ``Context`` so that variables they set do not leak into the caller's
    context.
    """

    def test_context_var_new_1(self):
        # The name argument is mandatory and must be a str.
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        var = contextvars.ContextVar('aaa')
        self.assertEqual(var.name, 'aaa')

        # ``name`` is read-only.
        with self.assertRaises(AttributeError):
            var.name = 'bbb'

        # A ContextVar must not hash the same as its name string.
        self.assertNotEqual(hash(var), hash('aaa'))

    @isolated_context
    def test_context_var_repr_1(self):
        var = contextvars.ContextVar('a')
        self.assertIn('a', repr(var))

        var = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(var))

        # A default that refers back to the variable must not make repr()
        # recurse forever; the cycle shows up as '...'.
        cycle = []
        var = contextvars.ContextVar('a', default=cycle)
        cycle.append(var)
        self.assertIn('...', repr(var))
        self.assertIn('...', repr(cycle))

        # A Token's repr embeds its variable's repr and gains ' used ' once
        # the token has been consumed by reset().
        token = var.set(1)
        self.assertIn(repr(var), repr(token))
        self.assertNotIn(' used ', repr(token))
        var.reset(token)
        self.assertIn(' used ', repr(token))

    def test_context_subclassing_1(self):
        # None of the contextvars types may be subclassed.
        # (Potentially we might want ContextVars to be subclassable.)
        for base in (contextvars.ContextVar,
                     contextvars.Context,
                     contextvars.Token):
            with self.assertRaisesRegex(TypeError,
                                        'not an acceptable base type'):
                class Subclass(base):
                    pass

    def test_context_new_1(self):
        # Context() takes no positional or keyword arguments; an empty
        # ``**{}`` expansion is still accepted.
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(a=1)
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        # Mapping-style access on a Context only accepts ContextVar keys.
        context = contextvars.Context()
        lookups = (
            lambda: context[1],
            lambda: 1 in context,
            lambda: context.get(1),
        )
        for lookup in lookups:
            with self.assertRaisesRegex(TypeError,
                                        'ContextVar key was expected'):
                lookup()

    def test_context_get_context_1(self):
        self.assertIsInstance(contextvars.copy_context(), contextvars.Context)

    def test_context_run_1(self):
        # run() requires the callable argument.
        context = contextvars.Context()
        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            context.run()

    def test_context_run_2(self):
        context = contextvars.Context()

        def echo(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        # functools.partial is included because partial doesn't support
        # FASTCALL, so run() is exercised with both kinds of callable.
        for callee in (echo, functools.partial(echo)):
            self.assertEqual(context.run(callee), (('bar',), {'spam': 'foo'}))
            self.assertEqual(context.run(callee, 1),
                             ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                context.run(callee, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                context.run(callee, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            empty_kwargs = {}
            self.assertEqual(
                context.run(callee, 11, **empty_kwargs),
                ((11, 'bar'), {'spam': 'foo'}))
            # The callee's mutation of its kwargs must not leak back out.
            self.assertEqual(empty_kwargs, {})

    def test_context_run_3(self):
        context = contextvars.Context()

        def explode(*args, **kwargs):
            1 / 0

        # Exceptions raised by the callee propagate out of run(),
        # whatever arguments were passed along.
        call_shapes = (((), {}), ((1, 2), {}), ((1, 2), {'a': 123}))
        for args, kwargs in call_shapes:
            with self.assertRaises(ZeroDivisionError):
                context.run(explode, *args, **kwargs)

    @isolated_context
    def test_context_run_4(self):
        outer = contextvars.Context()
        inner = contextvars.Context()
        var = contextvars.ContextVar('var')

        def in_inner():
            # A value set while running in `outer` is invisible here.
            self.assertIsNone(var.get(None))

        def in_outer():
            self.assertIsNone(var.get(None))
            var.set('spam')
            inner.run(in_inner)
            self.assertEqual(var.get(None), 'spam')

            snapshot = contextvars.copy_context()
            self.assertEqual(len(snapshot), 1)
            self.assertEqual(snapshot[var], 'spam')
            return snapshot

        returned = outer.run(in_outer)
        self.assertEqual(outer, returned)
        self.assertEqual(returned[var], 'spam')
        self.assertIn(var, returned)

    def test_context_run_5(self):
        context = contextvars.Context()
        var = contextvars.ContextVar('var')

        def set_then_fail():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            context.run(set_then_fail)

        # The assignment happened inside `context`, not the current one.
        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        context = contextvars.Context()
        var = contextvars.ContextVar('a', default=0)

        def body():
            self.assertEqual(var.get(), 0)
            self.assertIsNone(context.get(var))

            var.set(42)
            self.assertEqual(var.get(), 42)
            self.assertEqual(context.get(var), 42)

        context.run(body)

    def test_context_run_7(self):
        context = contextvars.Context()

        def reenter():
            # A Context cannot be entered while it is already running.
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                context.run(reenter)

        context.run(reenter)

    @isolated_context
    def test_context_getset_1(self):
        var = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            var.get()

        self.assertIsNone(var.get(None))

        first_token = var.set(42)
        self.assertEqual(var.get(), 42)
        self.assertEqual(var.get(None), 42)
        self.assertIs(first_token.old_value, first_token.MISSING)
        self.assertIs(first_token.old_value, contextvars.Token.MISSING)
        self.assertIs(first_token.var, var)

        second_token = var.set('spam')
        self.assertEqual(var.get(), 'spam')
        self.assertEqual(var.get(None), 'spam')
        self.assertEqual(second_token.old_value, 42)
        var.reset(second_token)

        self.assertEqual(var.get(), 42)
        self.assertEqual(var.get(None), 42)

        # A token can be used for reset() only once.
        var.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            var.reset(second_token)
        self.assertEqual(var.get(), 'spam2')

        snapshot = contextvars.copy_context()
        self.assertIn(var, snapshot)

        var.reset(first_token)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            var.reset(first_token)
        self.assertIsNone(var.get(None))

        # The earlier snapshot is unaffected by the later reset().
        self.assertIn(var, snapshot)
        self.assertEqual(snapshot[var], 'spam2')
        self.assertEqual(snapshot.get(var, 'aa'), 'spam2')
        self.assertEqual(len(snapshot), 1)
        self.assertEqual(list(snapshot.items()), [(var, 'spam2')])
        self.assertEqual(list(snapshot.values()), ['spam2'])
        self.assertEqual(list(snapshot.keys()), [var])
        self.assertEqual(list(snapshot), [var])

        fresh = contextvars.copy_context()
        self.assertNotIn(var, fresh)
        with self.assertRaises(KeyError):
            fresh[var]
        self.assertEqual(fresh.get(var, 'aa'), 'aa')
        self.assertEqual(len(fresh), 0)
        self.assertEqual(list(fresh), [])

    @isolated_context
    def test_context_getset_2(self):
        # A token may only be used with the variable that produced it.
        owner = contextvars.ContextVar('v1')
        stranger = contextvars.ContextVar('v2')

        token = owner.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            stranger.reset(token)

    @isolated_context
    def test_context_getset_3(self):
        var = contextvars.ContextVar('c', default=42)
        context = contextvars.Context()

        def body():
            # A default is visible via var.get() but is not stored in the
            # Context mapping itself.
            self.assertEqual(var.get(), 42)
            with self.assertRaises(KeyError):
                context[var]
            self.assertIsNone(context.get(var))
            self.assertEqual(context.get(var, 'spam'), 'spam')
            self.assertNotIn(var, context)
            self.assertEqual(list(context.keys()), [])

            token = var.set(1)
            self.assertEqual(list(context.keys()), [var])
            self.assertEqual(context[var], 1)

            var.reset(token)
            self.assertEqual(list(context.keys()), [])
            with self.assertRaises(KeyError):
                context[var]

        context.run(body)

    @isolated_context
    def test_context_getset_4(self):
        # A token created inside another Context cannot be reset here.
        var = contextvars.ContextVar('c', default=42)
        context = contextvars.Context()

        token = context.run(var.set, 1)

        with self.assertRaisesRegex(ValueError, 'different Context'):
            var.reset(token)

    @isolated_context
    def test_context_getset_5(self):
        var = contextvars.ContextVar('c', default=42)
        var.set([])

        def body():
            var.set([])
            var.get().append(42)
            self.assertEqual(var.get(), [42])

        # Work done inside a copied context stays in that copy.
        contextvars.copy_context().run(body)
        self.assertEqual(var.get(), [])

    def test_context_copy_1(self):
        first = contextvars.Context()
        var = contextvars.ContextVar('c', default=42)

        def in_second():
            self.assertEqual(var.get(), 10)
            var.set(30)
            self.assertEqual(var.get(), 30)

        def in_first():
            var.set(10)

            second = first.copy()
            self.assertEqual(second[var], 10)

            # The copy is independent of later changes to the original...
            var.set(20)
            self.assertEqual(first[var], 20)
            self.assertEqual(second[var], 10)

            # ...and vice versa.
            second.run(in_second)
            self.assertEqual(first[var], 20)
            self.assertEqual(second[var], 30)

        first.run(in_first)

    @isolated_context
    def test_context_threads_1(self):
        shared = contextvars.ContextVar('cvar')

        def worker(base):
            # Each thread sees only the values it set itself.
            for offset in range(10):
                shared.set(base + offset)
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(shared.get(), base + offset)
            return base

        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as pool:
            results = list(pool.map(worker, range(10)))
        self.assertEqual(results, list(range(10)))

    def test_contextvar_getitem(self):
        # NOTE(review): this expects ContextVar[str] to equal ContextVar
        # itself; on versions where __class_getitem__ returns a
        # types.GenericAlias (PEP 585) that may not hold -- confirm target.
        var_cls = contextvars.ContextVar
        self.assertEqual(var_cls[str], var_cls)
364
Yury Selivanovf23746a2018-01-22 19:11:18 -0500365
# HAMT (hash array mapped trie) tests
367
368
class HashKey:
    """Test key with an explicit hash value, used to force HAMT collisions.

    Two keys are equal only when both ``name`` and ``hash`` match, so keys
    that share ``hash`` but differ in ``name`` collide without being equal.

    While a ``HaskKeyCrasher`` is installed as ``HashKey._crasher``, hashing
    raises ``HashingError`` and/or comparing raises ``EqError``, depending
    on the crasher's flags.  ``error_on_eq_to`` makes comparing this key
    with that one specific other key raise ``ValueError``.
    """

    # The currently active HaskKeyCrasher, or None.  KeyStr reads it too.
    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        # CPython maps a __hash__ result of -1 to -2, so hash(key) would
        # silently differ from key.hash.  Use an explicit check instead of
        # ``assert`` so the guard still applies under ``python -O``.
        if hash == -1:
            raise ValueError('HashKey hash must not be -1')
        self.name = name
        self.hash = hash
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        if self._crasher is not None and self._crasher.error_on_hash:
            raise HashingError

        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        if self._crasher is not None and self._crasher.error_on_eq:
            raise EqError

        # Either side may be the key that refuses to be compared.
        if self.error_on_eq_to is not None and self.error_on_eq_to is other:
            raise ValueError(f'cannot compare {self!r} to {other!r}')
        if other.error_on_eq_to is not None and other.error_on_eq_to is self:
            raise ValueError(f'cannot compare {other!r} to {self!r}')

        return (self.name, self.hash) == (other.name, other.hash)
400
401
class KeyStr(str):
    """``str`` subclass whose hashing and equality can be made to fail.

    It behaves exactly like ``str`` unless a crasher is installed in
    ``HashKey._crasher``.  In that case ``__hash__`` raises
    ``HashingError`` when the crasher's ``error_on_hash`` is set, and
    ``__eq__`` raises ``EqError`` when its ``error_on_eq`` is set.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return super().__eq__(other)
412
413
class HaskKeyCrasher:
    """Context manager that makes HashKey/KeyStr hashing or equality raise.

    Entering installs this object as ``HashKey._crasher``; exiting always
    uninstalls it.  Only one crasher may be active at a time; entering a
    second one raises ``RuntimeError``.

    Keyword-only arguments:
        error_on_hash: when true, key ``__hash__`` raises ``HashingError``.
        error_on_eq: when true, key ``__eq__`` raises ``EqError``.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_hash, self.error_on_eq = error_on_hash, error_on_eq

    def __enter__(self):
        already_active = HashKey._crasher is not None
        if already_active:
            raise RuntimeError('cannot nest crashers')
        HashKey._crasher = self

    def __exit__(self, *exc):
        # Returning None means any in-flight exception keeps propagating.
        HashKey._crasher = None
426
427
class HashingError(Exception):
    """Raised by HashKey/KeyStr ``__hash__`` while a crasher has error_on_hash set."""
430
431
class EqError(Exception):
    """Raised by HashKey/KeyStr ``__eq__`` while a crasher has error_on_eq set."""
434
435
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the persistent HAMT mapping returned by ``_testcapi.hamt()``.

    ``set()`` and ``delete()`` never modify a mapping in place.  They return
    a new mapping, or the very same object when nothing changes, so most
    tests also check that older versions are left untouched.
    """

    def test_hashkey_helper_1(self):
        # Sanity-check the helper: equal hashes, yet distinct dict keys.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        # Smoke test: creating and dropping an empty HAMT must not crash.
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        # Re-setting a key to the identical value returns the same mapping.
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        # An equal-but-distinct value still produces a new mapping.
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        # All three keys share hash 10, forcing a collision node.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_stress(self):
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            # `d` is a plain dict mirroring every operation on `h`.
            h = hamt()
            d = dict()

            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                # Keep a half-deleted snapshot for the replay below.
                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============
            # Replay every deletion against the half-way snapshot, using
            # plain ``str`` keys.  The checks must target hm/dm: h and d are
            # already empty here, so asserting on them would verify nothing.

            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        # Er shares D's hash, so deleting it must compare it against D.
        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        # Deleting an absent key returns the very same mapping.
        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        # A and B are already gone; deleting them again must be harmless.
        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        # Deleting the same key twice leaves the size unchanged.
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        # Both keys() and plain iteration yield the keys.
        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        # Mappings are equal only with identical keys and equal values.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        # Errors raised while comparing keys propagate out of == and !=.
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        # A HAMT referenced only from a garbage reference cycle must be
        # freed once the cycle collector runs.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        # A partially consumed items() iterator must not keep its mapping
        # alive once both have been dropped.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        ref = weakref.ref(h)
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        # `in` uses hash + equality, and errors from either propagate.
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        # h[key] finds equal (not just identical) keys; a miss raises
        # KeyError, and hashing/equality errors propagate.
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1067
1068
if __name__ == "__main__":
    # Running this file directly executes the tests it defines.
    unittest.main()