blob: 74d05fc5c412cd8455ba4bf51a64c106b418952d [file] [log] [blame]
Yury Selivanovf23746a2018-01-22 19:11:18 -05001import concurrent.futures
2import contextvars
3import functools
4import gc
5import random
6import time
7import unittest
8import weakref
9
10try:
11 from _testcapi import hamt
12except ImportError:
13 hamt = None
14
15
def isolated_context(func):
    """Decorate *func* so that every call executes inside a new, empty Context.

    Needed to make reftracking test mode work: ContextVar values set by the
    wrapped callable are stored in the throwaway Context created per call,
    not in the caller's current context.
    """
    @functools.wraps(func)
    def run_in_fresh_context(*args, **kwargs):
        # A brand-new Context per call; forward all arguments unchanged.
        return contextvars.Context().run(func, *args, **kwargs)

    return run_in_fresh_context
23
24
class ContextTest(unittest.TestCase):
    """Tests for the contextvars module: ContextVar, Context and Token."""

    def test_context_var_new_1(self):
        with self.assertRaisesRegex(TypeError, 'takes exactly 1'):
            contextvars.ContextVar()

        with self.assertRaisesRegex(TypeError, 'must be a str'):
            contextvars.ContextVar(1)

        c = contextvars.ContextVar('a')
        self.assertNotEqual(hash(c), hash('a'))

    def test_context_var_new_2(self):
        self.assertIsNone(contextvars.ContextVar[int])

    @isolated_context
    def test_context_var_repr_1(self):
        c = contextvars.ContextVar('a')
        self.assertIn('a', repr(c))

        c = contextvars.ContextVar('a', default=123)
        self.assertIn('123', repr(c))

        # A self-referencing default must not make repr() recurse forever.
        lst = []
        c = contextvars.ContextVar('a', default=lst)
        lst.append(c)
        self.assertIn('...', repr(c))
        self.assertIn('...', repr(lst))

        # A Token's repr reports whether it has already been used by reset().
        t = c.set(1)
        self.assertIn(repr(c), repr(t))
        self.assertNotIn(' used ', repr(t))
        c.reset(t)
        self.assertIn(' used ', repr(t))

    def test_context_subclassing_1(self):
        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContextVar(contextvars.ContextVar):
                # Potentially we might want ContextVars to be subclassable.
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyContext(contextvars.Context):
                pass

        with self.assertRaisesRegex(TypeError, 'not an acceptable base type'):
            class MyToken(contextvars.Token):
                pass

    def test_context_new_1(self):
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(1, a=1)
        with self.assertRaisesRegex(TypeError, 'any arguments'):
            contextvars.Context(a=1)
        # An empty **kwargs expansion is accepted.
        contextvars.Context(**{})

    def test_context_typerrors_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx[1]
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            1 in ctx
        with self.assertRaisesRegex(TypeError, 'ContextVar key was expected'):
            ctx.get(1)

    def test_context_get_context_1(self):
        ctx = contextvars.copy_context()
        self.assertIsInstance(ctx, contextvars.Context)

    def test_context_run_1(self):
        ctx = contextvars.Context()

        with self.assertRaisesRegex(TypeError, 'missing 1 required'):
            ctx.run()

    def test_context_run_2(self):
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            kwargs['spam'] = 'foo'
            args += ('bar',)
            return args, kwargs

        for f in (func, functools.partial(func)):
            # partial doesn't support FASTCALL, so running both a plain
            # function and a partial exercises both calling conventions.

            self.assertEqual(ctx.run(f), (('bar',), {'spam': 'foo'}))
            self.assertEqual(ctx.run(f, 1), ((1, 'bar'), {'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, a=2),
                (('bar',), {'a': 2, 'spam': 'foo'}))

            self.assertEqual(
                ctx.run(f, 11, a=2),
                ((11, 'bar'), {'a': 2, 'spam': 'foo'}))

            # The callee mutating its kwargs must not leak into the caller's dict.
            a = {}
            self.assertEqual(
                ctx.run(f, 11, **a),
                ((11, 'bar'), {'spam': 'foo'}))
            self.assertEqual(a, {})

    def test_context_run_3(self):
        ctx = contextvars.Context()

        def func(*args, **kwargs):
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2)
        with self.assertRaises(ZeroDivisionError):
            ctx.run(func, 1, 2, a=123)

    @isolated_context
    def test_context_run_4(self):
        ctx1 = contextvars.Context()
        ctx2 = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func2():
            # ctx2 is independent of ctx1, so the value set in func1 is absent.
            self.assertIsNone(var.get(None))

        def func1():
            self.assertIsNone(var.get(None))
            var.set('spam')
            ctx2.run(func2)
            self.assertEqual(var.get(None), 'spam')

            cur = contextvars.copy_context()
            self.assertEqual(len(cur), 1)
            self.assertEqual(cur[var], 'spam')
            return cur

        returned_ctx = ctx1.run(func1)
        self.assertEqual(ctx1, returned_ctx)
        self.assertEqual(returned_ctx[var], 'spam')
        self.assertIn(var, returned_ctx)

    def test_context_run_5(self):
        ctx = contextvars.Context()
        var = contextvars.ContextVar('var')

        def func():
            self.assertIsNone(var.get(None))
            var.set('spam')
            1 / 0

        with self.assertRaises(ZeroDivisionError):
            ctx.run(func)

        # The failed run must not leak 'spam' into the caller's context.
        self.assertIsNone(var.get(None))

    def test_context_run_6(self):
        ctx = contextvars.Context()
        c = contextvars.ContextVar('a', default=0)

        def fun():
            # The var's default is visible via get(), but it is not stored in ctx.
            self.assertEqual(c.get(), 0)
            self.assertIsNone(ctx.get(c))

            c.set(42)
            self.assertEqual(c.get(), 42)
            self.assertEqual(ctx.get(c), 42)

        ctx.run(fun)

    def test_context_run_7(self):
        ctx = contextvars.Context()

        def fun():
            # Re-entering a Context that is currently running is rejected.
            with self.assertRaisesRegex(RuntimeError, 'is already entered'):
                ctx.run(fun)

        ctx.run(fun)

    @isolated_context
    def test_context_getset_1(self):
        c = contextvars.ContextVar('c')
        with self.assertRaises(LookupError):
            c.get()

        self.assertIsNone(c.get(None))

        t0 = c.set(42)
        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)
        self.assertIs(t0.old_value, t0.MISSING)
        self.assertIs(t0.old_value, contextvars.Token.MISSING)
        self.assertIs(t0.var, c)

        t = c.set('spam')
        self.assertEqual(c.get(), 'spam')
        self.assertEqual(c.get(None), 'spam')
        self.assertEqual(t.old_value, 42)
        c.reset(t)

        self.assertEqual(c.get(), 42)
        self.assertEqual(c.get(None), 42)

        # A token can only be used once.
        c.set('spam2')
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t)
        self.assertEqual(c.get(), 'spam2')

        ctx1 = contextvars.copy_context()
        self.assertIn(c, ctx1)

        c.reset(t0)
        with self.assertRaisesRegex(RuntimeError, 'has already been used'):
            c.reset(t0)
        self.assertIsNone(c.get(None))

        # ctx1 is a snapshot taken before the reset and keeps 'spam2'.
        self.assertIn(c, ctx1)
        self.assertEqual(ctx1[c], 'spam2')
        self.assertEqual(ctx1.get(c, 'aa'), 'spam2')
        self.assertEqual(len(ctx1), 1)
        self.assertEqual(list(ctx1.items()), [(c, 'spam2')])
        self.assertEqual(list(ctx1.values()), ['spam2'])
        self.assertEqual(list(ctx1.keys()), [c])
        self.assertEqual(list(ctx1), [c])

        ctx2 = contextvars.copy_context()
        self.assertNotIn(c, ctx2)
        with self.assertRaises(KeyError):
            ctx2[c]
        self.assertEqual(ctx2.get(c, 'aa'), 'aa')
        self.assertEqual(len(ctx2), 0)
        self.assertEqual(list(ctx2), [])

    @isolated_context
    def test_context_getset_2(self):
        v1 = contextvars.ContextVar('v1')
        v2 = contextvars.ContextVar('v2')

        t1 = v1.set(42)
        with self.assertRaisesRegex(ValueError, 'by a different'):
            v2.reset(t1)

    @isolated_context
    def test_context_getset_3(self):
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        def fun():
            self.assertEqual(c.get(), 42)
            with self.assertRaises(KeyError):
                ctx[c]
            self.assertIsNone(ctx.get(c))
            self.assertEqual(ctx.get(c, 'spam'), 'spam')
            self.assertNotIn(c, ctx)
            self.assertEqual(list(ctx.keys()), [])

            t = c.set(1)
            self.assertEqual(list(ctx.keys()), [c])
            self.assertEqual(ctx[c], 1)

            # Resetting to a MISSING old value removes the var from ctx.
            c.reset(t)
            self.assertEqual(list(ctx.keys()), [])
            with self.assertRaises(KeyError):
                ctx[c]

        ctx.run(fun)

    @isolated_context
    def test_context_getset_4(self):
        c = contextvars.ContextVar('c', default=42)
        ctx = contextvars.Context()

        tok = ctx.run(c.set, 1)

        # The token was created in ctx, so resetting it elsewhere fails.
        with self.assertRaisesRegex(ValueError, 'different Context'):
            c.reset(tok)

    @isolated_context
    def test_context_getset_5(self):
        c = contextvars.ContextVar('c', default=42)
        c.set([])

        def fun():
            c.set([])
            c.get().append(42)
            self.assertEqual(c.get(), [42])

        contextvars.copy_context().run(fun)
        self.assertEqual(c.get(), [])

    def test_context_copy_1(self):
        ctx1 = contextvars.Context()
        c = contextvars.ContextVar('c', default=42)

        def ctx1_fun():
            c.set(10)

            # ctx2 is a snapshot: later writes to ctx1 do not affect it ...
            ctx2 = ctx1.copy()
            self.assertEqual(ctx2[c], 10)

            c.set(20)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 10)

            # ... and writes made inside ctx2 do not affect ctx1.
            ctx2.run(ctx2_fun)
            self.assertEqual(ctx1[c], 20)
            self.assertEqual(ctx2[c], 30)

        def ctx2_fun():
            self.assertEqual(c.get(), 10)
            c.set(30)
            self.assertEqual(c.get(), 30)

        ctx1.run(ctx1_fun)

    @isolated_context
    def test_context_threads_1(self):
        cvar = contextvars.ContextVar('cvar')

        def sub(num):
            # Each worker thread sees only its own value despite interleaving.
            for i in range(10):
                cvar.set(num + i)
                time.sleep(random.uniform(0.001, 0.05))
                self.assertEqual(cvar.get(), num + i)
            return num

        # The executor's context manager shuts the pool down (waiting for the
        # workers) even if a worker assertion propagates out of map().
        with concurrent.futures.ThreadPoolExecutor(max_workers=10) as tp:
            results = list(tp.map(sub, range(10)))
        self.assertEqual(results, list(range(10)))
358
359
360# HAMT Tests
361
362
class HashKey:
    """Test key whose hash value is chosen explicitly by the caller.

    Two keys are equal when both *name* and *hash* match, so hash collisions
    can be produced on demand.  When ``HashKey._crasher`` is set, hashing or
    comparison raises ``HashingError`` / ``EqError`` according to its flags.
    Passing ``error_on_eq_to=other`` makes any comparison between this key
    and ``other`` raise ``ValueError``.
    """

    # Active crasher object (or None); read by __hash__/__eq__ here and by KeyStr.
    _crasher = None

    def __init__(self, hash, name, *, error_on_eq_to=None):
        assert hash != -1
        self.name = name
        self.hash = hash
        self.error_on_eq_to = error_on_eq_to

    def __repr__(self):
        return f'<Key name:{self.name} hash:{self.hash}>'

    def __hash__(self):
        crasher = self._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return self.hash

    def __eq__(self, other):
        if not isinstance(other, HashKey):
            return NotImplemented

        crasher = self._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError

        # Check self->other first, then other->self; the message always
        # names the key that carries error_on_eq_to first.
        for lhs, rhs in ((self, other), (other, self)):
            if lhs.error_on_eq_to is rhs:
                raise ValueError(f'cannot compare {lhs!r} to {rhs!r}')

        return (self.name, self.hash) == (other.name, other.hash)
394
395
class KeyStr(str):
    """``str`` subclass whose hashing and equality honour ``HashKey._crasher``.

    Behaves like a plain ``str`` unless a crasher is active, in which case
    ``__hash__`` raises ``HashingError`` and/or ``__eq__`` raises ``EqError``.
    """

    def __hash__(self):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_hash:
            raise HashingError
        return super().__hash__()

    def __eq__(self, other):
        crasher = HashKey._crasher
        if crasher is not None and crasher.error_on_eq:
            raise EqError
        return super().__eq__(other)
406
407
class HaskKeyCrasher:
    """Context manager that installs itself as ``HashKey._crasher``.

    While active, ``HashKey``/``KeyStr`` hashing raises ``HashingError`` if
    *error_on_hash* is true, and comparison raises ``EqError`` if
    *error_on_eq* is true.  Crashers cannot be nested.
    """

    def __init__(self, *, error_on_hash=False, error_on_eq=False):
        self.error_on_eq = error_on_eq
        self.error_on_hash = error_on_hash

    def __enter__(self):
        if HashKey._crasher is None:
            HashKey._crasher = self
            return None
        raise RuntimeError('cannot nest crashers')

    def __exit__(self, *exc):
        # Always uninstall; returning None lets any exception propagate.
        HashKey._crasher = None
420
421
class HashingError(Exception):
    """Raised by HashKey/KeyStr ``__hash__`` while an error_on_hash crasher is active."""
424
425
class EqError(Exception):
    """Raised by HashKey/KeyStr ``__eq__`` while an error_on_eq crasher is active."""
428
429
@unittest.skipIf(hamt is None, '_testcapi lacks "hamt()" function')
class HamtTest(unittest.TestCase):
    """Tests for the immutable HAMT mapping exposed as ``_testcapi.hamt()``."""

    def test_hashkey_helper_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')

        self.assertNotEqual(k1, k2)
        self.assertEqual(hash(k1), hash(k2))

        d = dict()
        d[k1] = 'a'
        d[k2] = 'b'

        self.assertEqual(d[k1], 'a')
        self.assertEqual(d[k2], 'b')

    def test_hamt_basics_1(self):
        h = hamt()
        h = None  # NoQA

    def test_hamt_basics_2(self):
        h = hamt()
        self.assertEqual(len(h), 0)

        h2 = h.set('a', 'b')
        self.assertIsNot(h, h2)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)

        self.assertIsNone(h.get('a'))
        self.assertEqual(h.get('a', 42), 42)

        self.assertEqual(h2.get('a'), 'b')

        h3 = h2.set('b', 10)
        self.assertIsNot(h2, h3)
        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(h3.get('a'), 'b')
        self.assertEqual(h3.get('b'), 10)

        self.assertIsNone(h.get('b'))
        self.assertIsNone(h2.get('b'))

        self.assertIsNone(h.get('a'))
        self.assertEqual(h2.get('a'), 'b')

        h = h2 = h3 = None

    def test_hamt_basics_3(self):
        h = hamt()
        o = object()
        h1 = h.set('1', o)
        # Re-setting the identical value returns the same mapping object.
        h2 = h1.set('1', o)
        self.assertIs(h1, h2)

    def test_hamt_basics_4(self):
        h = hamt()
        h1 = h.set('key', [])
        h2 = h1.set('key', [])
        self.assertIsNot(h1, h2)
        self.assertEqual(len(h1), 1)
        self.assertEqual(len(h2), 1)
        self.assertIsNot(h1.get('key'), h2.get('key'))

    def test_hamt_collision_1(self):
        k1 = HashKey(10, 'aaa')
        k2 = HashKey(10, 'bbb')
        k3 = HashKey(10, 'ccc')

        h = hamt()
        h2 = h.set(k1, 'a')
        h3 = h2.set(k2, 'b')

        self.assertEqual(h.get(k1), None)
        self.assertEqual(h.get(k2), None)

        self.assertEqual(h2.get(k1), 'a')
        self.assertEqual(h2.get(k2), None)

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')

        h4 = h3.set(k2, 'cc')
        h5 = h4.set(k3, 'aa')

        self.assertEqual(h3.get(k1), 'a')
        self.assertEqual(h3.get(k2), 'b')
        self.assertEqual(h4.get(k1), 'a')
        self.assertEqual(h4.get(k2), 'cc')
        self.assertEqual(h4.get(k3), None)
        self.assertEqual(h5.get(k1), 'a')
        self.assertEqual(h5.get(k2), 'cc')
        self.assertEqual(h5.get(k3), 'aa')

        self.assertEqual(len(h), 0)
        self.assertEqual(len(h2), 1)
        self.assertEqual(len(h3), 2)
        self.assertEqual(len(h4), 2)
        self.assertEqual(len(h5), 3)

    def test_hamt_stress(self):
        COLLECTION_SIZE = 7000
        TEST_ITERS_EVERY = 647
        CRASH_HASH_EVERY = 97
        CRASH_EQ_EVERY = 11
        RUN_XTIMES = 3

        for _ in range(RUN_XTIMES):
            h = hamt()
            d = dict()

            for i in range(COLLECTION_SIZE):
                key = KeyStr(i)

                # Periodically inject hash failures; the mapping must be
                # left usable afterwards.
                if not (i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.set(key, i)

                h = h.set(key, i)

                if not (i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.get(KeyStr(i))  # really trigger __eq__

                d[key] = i
                self.assertEqual(len(d), len(h))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.items()), set(d.items()))
                    self.assertEqual(len(h.items()), len(d.items()))

            self.assertEqual(len(h), COLLECTION_SIZE)

            for key in range(COLLECTION_SIZE):
                self.assertEqual(h.get(KeyStr(key), 'not found'), key)

            keys_to_delete = list(range(COLLECTION_SIZE))
            random.shuffle(keys_to_delete)
            for iter_i, i in enumerate(keys_to_delete):
                key = KeyStr(i)

                if not (iter_i % CRASH_HASH_EVERY):
                    with HaskKeyCrasher(error_on_hash=True):
                        with self.assertRaises(HashingError):
                            h.delete(key)

                if not (iter_i % CRASH_EQ_EVERY):
                    with HaskKeyCrasher(error_on_eq=True):
                        with self.assertRaises(EqError):
                            h.delete(KeyStr(i))

                h = h.delete(key)
                self.assertEqual(h.get(key, 'not found'), 'not found')
                del d[key]
                self.assertEqual(len(d), len(h))

                # Snapshot the half-deleted state for the second phase.
                if iter_i == COLLECTION_SIZE // 2:
                    hm = h
                    dm = d.copy()

                if not (iter_i % TEST_ITERS_EVERY):
                    self.assertEqual(set(h.keys()), set(d.keys()))
                    self.assertEqual(len(h.keys()), len(d.keys()))

            self.assertEqual(len(d), 0)
            self.assertEqual(len(h), 0)

            # ============
            # Second phase: replay the same deletions against the snapshot
            # (hm, dm), using plain str keys.  Keys already absent from the
            # snapshot must be no-ops.  The checks below must compare hm/dm;
            # h and d are already empty at this point.

            for key in dm:
                self.assertEqual(hm.get(str(key)), dm[key])
            self.assertEqual(len(dm), len(hm))

            for i, key in enumerate(keys_to_delete):
                hm = hm.delete(str(key))
                self.assertEqual(hm.get(str(key), 'not found'), 'not found')
                dm.pop(str(key), None)
                self.assertEqual(len(dm), len(hm))

                if not (i % TEST_ITERS_EVERY):
                    self.assertEqual(set(hm.values()), set(dm.values()))
                    self.assertEqual(len(hm.values()), len(dm.values()))

            self.assertEqual(len(dm), 0)
            self.assertEqual(len(hm), 0)
            self.assertEqual(list(hm.items()), [])

    def test_hamt_delete_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(102, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(103, 'Er', error_on_eq_to=D)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=10 bitmap=0b111110000 id=0x10eadc618):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:C hash:102>: 'c'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 2)

        # Deleting a missing key returns the very same mapping.
        h2 = h.delete(Z)
        self.assertIs(h2, h)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(A, 42), 42)
        self.assertEqual(h.get(B), 'b')
        self.assertEqual(h.get(E), 'e')

    def test_hamt_delete_2(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        Z = HashKey(-100, 'Z')

        Er = HashKey(201001, 'Er', error_on_eq_to=B)

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=8 bitmap=0b1110010000):
        #     <Key name:A hash:100>: 'a'
        #     <Key name:D hash:103>: 'd'
        #     <Key name:E hash:104>: 'e'
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b100000000001000000000):
        #             <Key name:B hash:201001>: 'b'
        #             <Key name:C hash:101001>: 'c'

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h.delete(Er)

        h = h.delete(Z)
        self.assertEqual(len(h), orig_len)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(B)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 3)

        self.assertEqual(h.get(D), 'd')
        self.assertEqual(h.get(E), 'e')

        h = h.delete(A)
        h = h.delete(B)
        h = h.delete(D)
        h = h.delete(E)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_3(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(104, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=6 bitmap=0b100110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=4 id=0x108572410):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #     <Key name:B hash:101>: 'b'
        #     <Key name:E hash:104>: 'e'

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        self.assertEqual(h.get(C), 'c')
        self.assertEqual(h.get(B), 'b')

    def test_hamt_delete_4(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')

        orig_len = len(h)

        # BitmapNode(size=4 bitmap=0b110000):
        #     NULL:
        #         BitmapNode(size=4 bitmap=0b1000000000000000000001000):
        #             <Key name:A hash:100>: 'a'
        #             NULL:
        #                 CollisionNode(size=6 id=0x10515ef30):
        #                     <Key name:C hash:100100>: 'c'
        #                     <Key name:D hash:100100>: 'd'
        #                     <Key name:E hash:100100>: 'e'
        #     <Key name:B hash:101>: 'b'

        h = h.delete(D)
        self.assertEqual(len(h), orig_len - 1)

        h = h.delete(E)
        self.assertEqual(len(h), orig_len - 2)

        h = h.delete(C)
        self.assertEqual(len(h), orig_len - 3)

        h = h.delete(A)
        self.assertEqual(len(h), orig_len - 4)

        h = h.delete(B)
        self.assertEqual(len(h), 0)

    def test_hamt_delete_5(self):
        h = hamt()

        keys = []
        for i in range(17):
            key = HashKey(i, str(i))
            keys.append(key)
            h = h.set(key, f'val-{i}')

        collision_key16 = HashKey(16, '18')
        h = h.set(collision_key16, 'collision')

        # ArrayNode(id=0x10f8b9318):
        #     0::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:0 hash:0>: 'val-0'
        #
        # ... 14 more BitmapNodes ...
        #
        #     15::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         <Key name:15 hash:15>: 'val-15'
        #
        #     16::
        #     BitmapNode(size=2 count=1 bitmap=0b1):
        #         NULL:
        #             CollisionNode(size=4 id=0x10f2f5af8):
        #                 <Key name:16 hash:16>: 'val-16'
        #                 <Key name:18 hash:16>: 'collision'

        self.assertEqual(len(h), 18)

        h = h.delete(keys[2])
        self.assertEqual(len(h), 17)

        h = h.delete(collision_key16)
        self.assertEqual(len(h), 16)
        h = h.delete(keys[16])
        self.assertEqual(len(h), 15)

        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)
        h = h.delete(keys[1])
        self.assertEqual(len(h), 14)

        for key in keys:
            h = h.delete(key)
        self.assertEqual(len(h), 0)

    def test_hamt_items_1(self):
        A = HashKey(100, 'A')
        B = HashKey(201001, 'B')
        C = HashKey(101001, 'C')
        D = HashKey(103, 'D')
        E = HashKey(104, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_items_2(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        it = h.items()
        self.assertEqual(
            set(list(it)),
            {(A, 'a'), (B, 'b'), (C, 'c'), (D, 'd'), (E, 'e'), (F, 'f')})

    def test_hamt_keys_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(100100, 'E')
        F = HashKey(110, 'F')

        h = hamt()
        h = h.set(A, 'a')
        h = h.set(B, 'b')
        h = h.set(C, 'c')
        h = h.set(D, 'd')
        h = h.set(E, 'e')
        h = h.set(F, 'f')

        self.assertEqual(set(list(h.keys())), {A, B, C, D, E, F})
        self.assertEqual(set(list(h)), {A, B, C, D, E, F})

    def test_hamt_items_3(self):
        h = hamt()
        self.assertEqual(len(h.items()), 0)
        self.assertEqual(list(h.items()), [])

    def test_hamt_eq_1(self):
        A = HashKey(100, 'A')
        B = HashKey(101, 'B')
        C = HashKey(100100, 'C')
        D = HashKey(100100, 'D')
        E = HashKey(120, 'E')

        h1 = hamt()
        h1 = h1.set(A, 'a')
        h1 = h1.set(B, 'b')
        h1 = h1.set(C, 'c')
        h1 = h1.set(D, 'd')

        h2 = hamt()
        h2 = h2.set(A, 'a')

        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(B, 'b')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(C, 'c')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd2')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(D, 'd')
        self.assertTrue(h1 == h2)
        self.assertFalse(h1 != h2)

        h2 = h2.set(E, 'e')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.delete(D)
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

        h2 = h2.set(E, 'd')
        self.assertFalse(h1 == h2)
        self.assertTrue(h1 != h2)

    def test_hamt_eq_2(self):
        A = HashKey(100, 'A')
        Er = HashKey(100, 'Er', error_on_eq_to=A)

        h1 = hamt()
        h1 = h1.set(A, 'a')

        h2 = hamt()
        h2 = h2.set(Er, 'a')

        # Errors raised by key comparison must propagate out of ==/!=.
        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 == h2

        with self.assertRaisesRegex(ValueError, 'cannot compare'):
            h1 != h2

    def test_hamt_gc_1(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(0, 0)  # empty HAMT node is memoized in hamt.c
        ref = weakref.ref(h)

        # Build a list cycle a <-> b that also holds the mapping tracked by
        # ref; a second mapping then stores b as a value.
        a = []
        a.append(a)
        a.append(h)
        b = []
        a.append(b)
        b.append(a)
        h = h.set(A, b)

        del h, a, b

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_gc_2(self):
        A = HashKey(100, 'A')

        h = hamt()
        h = h.set(A, 'a')
        # The new mapping stores the previous mapping as the value under A.
        h = h.set(A, h)

        ref = weakref.ref(h)
        # A partially consumed items() iterator must not keep the mapping
        # alive once both have been deleted.
        hi = h.items()
        next(hi)

        del h, hi

        gc.collect()
        gc.collect()
        gc.collect()

        self.assertIsNone(ref())

    def test_hamt_in_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertTrue(A in h)
        self.assertFalse(B in h)

        # AA is equal to A but a distinct object, so __eq__ is really called.
        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                AA in h

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                AA in h

    def test_hamt_getitem_1(self):
        A = HashKey(100, 'A')
        AA = HashKey(100, 'A')

        B = HashKey(101, 'B')

        h = hamt()
        h = h.set(A, 1)

        self.assertEqual(h[A], 1)
        self.assertEqual(h[AA], 1)

        with self.assertRaises(KeyError):
            h[B]

        with self.assertRaises(EqError):
            with HaskKeyCrasher(error_on_eq=True):
                h[AA]

        with self.assertRaises(HashingError):
            with HaskKeyCrasher(error_on_hash=True):
                h[AA]
1061
1062
if __name__ == '__main__':
    # Executed directly (not imported): run every TestCase in this module.
    unittest.main()