blob: fb9e8d73db36bbe62b12908364ab49d8cc97c197 [file] [log] [blame]
Michael Foorde6410c52010-03-29 20:04:23 +00001import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Featureless object: it has no behaviour of its own.

    Instances accept arbitrary attributes and are used as throwaway
    WeakSet elements by the tests.
    """
18
class SomeClass(object):
    """Hashable, weak-referenceable wrapper around a single ``value``.

    Two instances compare equal only when they have exactly the same type
    and equal ``value`` attributes. The hash is built from the same data,
    so equal instances hash alike and can stand in for one another inside
    sets and WeakSets.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Exact type match: subclass instances and foreign objects are
        # never equal to a SomeClass.
        if type(other) is not type(self):
            return False
        return other.value == self.value

    def __ne__(self, other):
        # Spelled out explicitly because Python 2 does not derive __ne__
        # from __eq__.
        return not self.__eq__(other)

    def __hash__(self):
        return hash((SomeClass, self.value))

    def __repr__(self):
        # Show the wrapped value so failing assertIn/assertNotIn checks
        # identify the element instead of printing a bare address.
        return '%s(%r)' % (type(self).__name__, self.value)
32
class RefCycle(object):
    """Object that references itself, forming a one-object reference cycle.

    The len() tests use it to build WeakSet elements that need a cyclic
    garbage-collection pass (``gc.collect()``) to go away.
    """

    def __init__(self):
        # The self-reference is the whole point of the class.
        self.cycle = self
36
class TestWeakSet(unittest.TestCase):
    """Behavioural tests for weakref.WeakSet.

    Most tests compare a WeakSet operation against the builtin set/dict
    equivalent. Many then drop a strong reference (``del`` / ``pop``
    followed by ``gc.collect()``) to check that dead elements disappear
    from the results.
    """

    def setUp(self):
        # WeakSets hold only weak references, so every element must also be
        # kept alive by one of these attributes for the duration of a test.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        # self.obj is the only strong reference to fs's single element;
        # tests delete it to watch that element vanish.
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        x = WeakSet(self.items + self.items2)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # c holds strong references to items2's elements; drop it so the
            # pop() below really frees an element.
            del c
        # The union must only weakly reference its elements.
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        empty = WeakSet([])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            self.assertEqual(i.intersection(C(self.items)), empty)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        # Operator forms, WeakSet against WeakSet.
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # Named methods must also accept plain iterables of elements.
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.abcde_weakset.issubset(self.def_items))
        self.assertFalse(self.abcde_weakset.issuperset(self.def_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # The elements are kept alive by `items` while the cycles are built;
        # otherwise they would be dropped as soon as they entered the WeakSet
        # and the loop would find (almost) nothing to work on.
        # There are no assertions: the test only has to run cleanly.
        items = [Foo() for _ in range(1000)]
        s = WeakSet(items)
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Re-running __init__ must replace, not extend, the contents.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The new Foo() has no other strong reference, so it must not
        # remain in the set.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_ne(self):
        self.assertTrue(self.s != set(self.items))
        s1 = WeakSet()
        s2 = WeakSet()
        self.assertFalse(s1 != s2)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()         # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                # Start the iterator. Keep only the yielded element's value
                # so that no extra strong reference to it survives here.
                yielded = next(it).value
                # Schedule an item for removal and recreate an *equal* one.
                # It must be rebuilt from .value: any other derived string
                # would give an object that never equals the removed
                # element, making the checks below vacuous.
                u = SomeClass(items.pop().value)
                if yielded == u.value:
                    # The iterator still strongly references the element it
                    # yielded last, so that element cannot be removed yet;
                    # advance the iterator to release it.
                    next(it)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for _ in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        N = 20
        for th in range(1, 100):
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for _ in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)
450
Michael Foorde6410c52010-03-29 20:04:23 +0000451
def test_main(verbose=None):
    """Run this module's WeakSet tests through test_support.run_unittest.

    ``verbose`` is accepted for signature compatibility but is not consulted.
    """
    test_classes = [TestWeakSet]
    test_support.run_unittest(*test_classes)


if __name__ == "__main__":
    test_main(verbose=True)