blob: 2d8b63a1f59581e1c471ca82dc8c13fb8fda3e71 [file] [log] [blame]
Yury Selivanovf23746a2018-01-22 19:11:18 -05001import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9
10try:
11 from _testcapi import hamt
12except ImportError:
13 hamt = None
14
15
def isolated_context(func):
    """Decorate *func* so that every call runs in a brand-new Context.

    Needed to make reftracking test mode work: each invocation starts from
    an empty ``contextvars.Context``, so any ContextVar values the decorated
    function sets stay inside that throwaway context instead of leaking
    into the caller's context.
    """
    @functools.wraps(func)
    def run_in_fresh_context(*args, **kwargs):
        fresh_ctx = contextvars.Context()
        return fresh_ctx.run(func, *args, **kwargs)

    return run_in_fresh_context
23
24
class ContextTest(unittest.TestCase):
    """Tests for contextvars.ContextVar, contextvars.Context and Token."""

    def test_context_var_new_1(self):
        # Constructor argument validation and the read-only ``name`` attribute.
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        c = contextvars.ContextVar('aaa')
        self.assertEqual(c.name, 'aaa')

        with self.assertRaises(AttributeError):
            c.name = 'bbb'

        self.assertNotEqual(hash(c), hash('aaa'))

    @isolated_context
    def test_context_var_repr_1(self):
        # repr() of ContextVar and Token: includes the name and default,
        # survives a self-referential default, and marks reset tokens
        # with ' used '.
        c = contextvars.ContextVar('a')
        self.assertIn('a', repr(c))

        c = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(c))

        lst = []
        c = contextvars.ContextVar('a', default=lst)
        lst.append(c)
        self.assertIn('...', repr(c))
        self.assertIn('...', repr(lst))

        t = c.set(1)
        self.assertIn(repr(c), repr(t))
        self.assertNotIn(' used ', repr(t))
        c.reset(t)
        self.assertIn(' used ', repr(t))

    def test_context_subclassing_1(self):
        # ContextVar, Context and Token cannot be subclassed.
        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContextVar(contextvars.ContextVar):
                # Potentially we might want ContextVars to be subclassable.
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContext(contextvars.Context):
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyToken(contextvars.Token):
                pass

    def test_context_new_1(self):
        # Context() rejects any arguments; an empty **kwargs is accepted.
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(a=1)
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        # Mapping-style access on a Context requires ContextVar keys.
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx[1]
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            1 in ctx
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx.get(1)

    def test_context_get_context_1(self):
        # copy_context() returns a Context instance.
        ctx = contextvars.copy_context()
        self.assertIsInstance(ctx, contextvars.Context)

    def test_context_run_1(self):
        # run() requires a callable argument.
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    def test_context_run_2(self):
        # run() forwards positional and keyword arguments to the callable
        # and does not mutate a kwargs dict passed via ** unpacking.
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        for f in (func, functools.partial(func)):
            # partial doesn't support FASTCALL

            self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
            self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            a = {}
            self.assertEqual(
                ctx.run(f, 11, **a),
                ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(a, {})

    def test_context_run_3(self):
        # Exceptions raised by the callable propagate out of run().
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2, a=123)

    @isolated_context
    def test_context_run_4(self):
        # A nested run() in another Context does not see this context's
        # values, and values set during run() are kept in the Context object.
        ctx1 = contextvars.Context()
        ctx2 = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func2():
            self.assertIsNone(var.get(None))

        def func1():
            self.assertIsNone(var.get(None))
            var.set('spam')
            ctx2.run(func2)
            self.assertEqual(var.get(None), 'spam')

            cur = contextvars.copy_context()
            self.assertEqual(len(cur), 1)
            self.assertEqual(cur[var], 'spam')
            return cur

        returned_ctx = ctx1.run(func1)
        self.assertEqual(ctx1, returned_ctx)
        self.assertEqual(returned_ctx[var], 'spam')
        self.assertIn(var, returned_ctx)

    def test_context_run_5(self):
        # A value set inside run() before an exception is not visible in
        # the caller's context afterwards.
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)

        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        # A ContextVar default is returned by get() but is not stored in
        # the Context until set() is called.
        ctx = contextvars.Context()
        c = contextvars.ContextVar('a', default=0)

        def fun():
            self.assertEqual(c.get(), 0)
            self.assertIsNone(ctx.get(c))

            c.set(42)
            self.assertEqual(c.get(), 42)
            self.assertEqual(ctx.get(c), 42)

        ctx.run(fun)

    def test_context_run_7(self):
        # A Context cannot be entered again while it is already running.
        ctx = contextvars.Context()

        def fun():
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(fun)

        ctx.run(fun)

    @isolated_context
    def test_context_getset_1(self):
        # Token bookkeeping for set()/reset(): old_value, one-shot tokens,
        # and snapshots taken with copy_context().
        c = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            c.get()

        self.assertIsNone(c.get(None))

        t0 = c.set(42)
        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)
        self.assertIs(t0.old_value, t0.MISSING)
        self.assertIs(t0.old_value, contextvars.Token.MISSING)
        self.assertIs(t0.var, c)

        t = c.set('spam')
        self.assertEqual(c.get(), 'spam')
        self.assertEqual(c.get(None), 'spam')
        self.assertEqual(t.old_value, 42)
        c.reset(t)

        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)

        c.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t)
        self.assertEqual(c.get(), 'spam2')

        ctx1 = contextvars.copy_context()
        self.assertIn(c, ctx1)

        c.reset(t0)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t0)
        self.assertIsNone(c.get(None))

        self.assertIn(c, ctx1)
        self.assertEqual(ctx1[c], 'spam2')
        self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
        self.assertEqual(len(ctx1), 1)
        self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
        self.assertEqual(list(ctx1.values()), ['spam2'])
        self.assertEqual(list(ctx1.keys()), [c])
        self.assertEqual(list(ctx1), [c])

        ctx2 = contextvars.copy_context()
        self.assertNotIn(c, ctx2)
        with self.assertRaises(KeyError):
            ctx2[c]
        self.assertEqual(ctx2.get(c, 'aa'), 'aa')
        self.assertEqual(len(ctx2), 0)
        self.assertEqual(list(ctx2), [])

    @isolated_context
    def test_context_getset_2(self):
        # A token can only be reset by the ContextVar that created it.
        v1 = contextvars.ContextVar('v1')
        v2 = contextvars.ContextVar('v2')

        t1 = v1.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            v2.reset(t1)

    @isolated_context
    def test_context_getset_3(self):
        # Inside run(), the Context's mapping view tracks set()/reset();
        # the ContextVar default is never stored in the mapping.
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def fun():
            self.assertEqual(c.get(), 42)
            with self.assertRaises(KeyError):
                ctx[c]
            self.assertIsNone(ctx.get(c))
            self.assertEqual(ctx.get(c, 'spam'), 'spam')
            self.assertNotIn(c, ctx)
            self.assertEqual(list(ctx.keys()), [])

            t = c.set(1)
            self.assertEqual(list(ctx.keys()), [c])
            self.assertEqual(ctx[c], 1)

            c.reset(t)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[c]

        ctx.run(fun)

    @isolated_context
    def test_context_getset_4(self):
        # A token created inside one Context cannot be reset from another.
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        tok = ctx.run(c.set, 1)

        with self.assertRaisesRegex(ValueError, 'different Context'):
            c.reset(tok)

    @isolated_context
    def test_context_getset_5(self):
        # Values set inside a copied context do not leak back to the caller.
        c = contextvars.ContextVar('c', default=42)
        c.set([])

        def fun():
            c.set([])
            c.get().append(42)
            self.assertEqual(c.get(), [42])

        contextvars.copy_context().run(fun)
        self.assertEqual(c.get(), [])

    def test_context_copy_1(self):
        # Context.copy() takes a snapshot; later changes made in either
        # context are not visible in the other.
        ctx1 = contextvars.Context()
        c = contextvars.ContextVar('c', default=42)

        def ctx1_fun():
            c.set(10)

            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[c], 10)

            c.set(20)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 10)

            ctx2.run(ctx2_fun)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 30)

        def ctx2_fun():
            self.assertEqual(c.get(), 10)
            c.set(30)
            self.assertEqual(c.get(), 30)

        ctx1.run(ctx1_fun)

    @isolated_context
    def test_context_threads_1(self):
        # Worker threads each keep their own value of the variable across
        # randomized sleeps, i.e. they do not observe each other's set().
        cvar = contextvars.ContextVar('cvar')

        def sub(num):
            for i in range(10):
                cvar.set(num + i)
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)
            return num

        # The with-block shuts the pool down (waiting for workers) on exit,
        # even if map() raises.
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as tp:
            results = list(tp.map(sub, range(10)))
        self.assertEqual(results, list(range(10)))
360

# HAMT Tests: key helpers below, then HamtTest (skipped when _testcapi lacks hamt())

364
class HashKey:
    """Hashable test key whose hash value is chosen by the caller.

    Two keys with the same ``hash`` but different ``name`` collide in hash
    tables without being equal, which lets tests build HAMT collisions.
    ``error_on_eq_to`` names one specific key: comparing this key with it
    (in either direction) raises ValueError.  While a HaskKeyCrasher is
    installed in the class-level ``_crasher`` slot, __hash__ / __eq__ raise
    HashingError / EqError as that crasher requests.
    """

    # The active HaskKeyCrasher, or None when hashing/equality behave normally.
    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        # -1 is rejected; presumably because CPython's hash() never returns
        # -1, so such a key would not hash to self.hash.
        assert hash != -1
        self.name = name
        self.hash = hash
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        crasher = self._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        crasher = self._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError

        # ``other`` is a HashKey (never None) here, so an identity match
        # alone is enough to detect a forbidden comparison.
        if self.error_on_eq_to is other:
            raise ValueError(f'cannot compare {self!r} to {other!r}')
        if other.error_on_eq_to is self:
            raise ValueError(f'cannot compare {other!r} to {self!r}')

        return (self.name, self.hash) == (other.name, other.hash)
396
397
class KeyStr(str):
    """str subclass whose hashing and equality honour HashKey's crasher.

    It hashes and compares exactly like the underlying str, unless a
    HaskKeyCrasher is installed in ``HashKey._crasher``: then __hash__
    raises HashingError when the crasher has ``error_on_hash`` set, and
    __eq__ raises EqError when it has ``error_on_eq`` set.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return super().__eq__(other)
408
409
class HaskKeyCrasher:
    """Context manager that makes HashKey/KeyStr hashing or equality fail.

    On entry it installs itself as ``HashKey._crasher``; on exit it clears
    that slot.  While installed, HashKey and KeyStr raise HashingError from
    __hash__ if ``error_on_hash`` is true and EqError from __eq__ if
    ``error_on_eq`` is true.  Crashers cannot be nested.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_hash = error_on_hash
        self.error_on_eq = error_on_eq

    def __enter__(self):
        already_active = HashKey._crasher is not None
        if already_active:
            raise RuntimeError('cannot nest crashers')
        HashKey._crasher = self

    def __exit__(self, *exc):
        # Returns None, so any exception raised in the with-body propagates.
        HashKey._crasher = None
422
423
class HashingError(Exception):
    """Raised by HashKey/KeyStr __hash__ while an error_on_hash crasher is active."""
426
427
class EqError(Exception):
    """Raised by HashKey/KeyStr __eq__ while an error_on_eq crasher is active."""
430
431
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the HAMT mapping exposed by ``_testcapi.hamt()``.

    set() and delete() return a mapping instead of mutating the receiver.
    HashKey lets tests choose hash values to force collisions, and
    HaskKeyCrasher makes hashing/equality raise on demand.
    """

    def test_hashkey_helper_1(self):
        # Sanity check: equal hashes with different names are distinct keys.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        # Creating and dropping an empty HAMT must not crash.
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        # set() returns a new mapping; earlier versions are unchanged.
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        # Re-setting a key to the identical value returns the same mapping.
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        # Re-setting a key to a different (even if equal) value makes a copy.
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        # Keys sharing one hash value are stored and looked up independently.
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_stress(self):
        # Randomized cross-check of the HAMT against a dict.  Insert
        # COLLECTION_SIZE keys, then delete them in random order, injecting
        # hash/eq failures periodically.  A snapshot (hm, dm) taken halfway
        # through the deletions is then drained on its own, to verify that
        # later deletions from h did not affect it.
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                if iter_i == COLLECTION_SIZE // 2:
                    # Snapshot of the half-drained state, verified below.
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============
            # Drain the snapshot.  All checks must compare hm against dm:
            # h and d are already empty at this point, so comparing them
            # would pass trivially and verify nothing about the snapshot.

            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                # Keys deleted before the snapshot are already absent; the
                # delete() call and pop(..., None) are no-ops for them.
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        # Deletion from a single bitmap node, including a failing __eq__
        # and deleting a missing key (which returns the same mapping).
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        # Deletion when some keys live in a nested bitmap node.
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        # Deletion next to a nested collision node.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        # Draining a collision node of three keys, then everything else.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        # Deletion from an array node whose slots hold bitmap nodes,
        # including a collision; deleting a missing key keeps the length.
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        # items() yields every (key, value) pair across nested nodes.
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        # items() also walks collision nodes.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        # keys() and plain iteration both yield every key.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        # items() of an empty HAMT is empty.
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        # == / != compare keys and values, including collision nodes.
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        # Exceptions raised by key __eq__ propagate out of HAMT comparison.
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        # A HAMT that is part of a reference cycle is collected by gc.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        # A self-referencing HAMT with a live, partially consumed items()
        # iterator is collected once both are dropped.
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(A, h)

        ref = weakref.ref(h)
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        # ``in`` uses hash and __eq__, and propagates their exceptions.
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        # h[key] finds equal keys, raises KeyError for missing ones, and
        # propagates hash/eq exceptions.
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1063
1064
if __name__ == '__main__':
    # Run this test module directly as a script.
    unittest.main()