blob: d9e091ed02f86e12f50dda2da2c6b91fd7d22974 [file] [log] [blame]
Michael Foorde6410c52010-03-29 20:04:23 +00001import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Empty placeholder class; its instances are used as WeakSet elements
    by the tests in this module."""
18
class SomeClass(object):
    """Value wrapper used as a weakly referenceable WeakSet element.

    Two instances compare equal only when they are of exactly the same type
    and hold equal ``value`` attributes; the hash is derived from ``value``,
    so equal instances hash alike.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # An exact type match is required; anything else is simply unequal.
        return False if type(other) != type(self) else other.value == self.value

    def __ne__(self, other):
        # Python 2 does not derive __ne__ from __eq__, so spell it out.
        equal = self.__eq__(other)
        return not equal

    def __hash__(self):
        key = (SomeClass, self.value)
        return hash(key)
32
class RefCycle(object):
    """Object that holds a strong reference to itself.

    The self-reference forms a reference cycle, so (in CPython) instances are
    not reclaimed by reference counting alone; the length tests in this
    module call gc.collect() to get rid of them.
    """

    def __init__(self):
        # Deliberate self-loop.
        self.cycle = self
36
class TestWeakSet(unittest.TestCase):
    """Behavioural tests for weakref.WeakSet.

    Most tests compare a WeakSet against the equivalent operation on a plain
    set/dict/list, and several also check that elements disappear from a
    WeakSet once the last strong reference to them is dropped.
    """

    def setUp(self):
        # The WeakSets below hold only weak references, so every element
        # must also be kept alive by one of these attributes for the whole
        # test.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.ab_items = [SomeClass(c) for c in 'ab']
        self.abcde_items = [SomeClass(c) for c in 'abcde']
        self.def_items = [SomeClass(c) for c in 'def']
        self.ab_weakset = WeakSet(self.ab_items)
        self.abcde_weakset = WeakSet(self.abcde_items)
        self.def_weakset = WeakSet(self.def_items)
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must also exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        # Dropping the only strong reference removes the element.
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            # Drop the strong container so it cannot keep items2 alive.
            del c
        # The result holds only weak references: dropping an element of
        # items2 must shrink it.
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        self.assertTrue(self.ab_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset <= self.abcde_weakset)
        self.assertTrue(self.abcde_weakset >= self.ab_weakset)
        self.assertFalse(self.abcde_weakset <= self.def_weakset)
        self.assertFalse(self.abcde_weakset >= self.def_weakset)
        # The named methods must agree with the operators and accept any
        # iterable of weakly referenceable elements (not just WeakSets).
        self.assertTrue(self.ab_weakset.issubset(self.abcde_items))
        self.assertTrue(self.abcde_weakset.issuperset(self.ab_items))
        self.assertFalse(self.abcde_weakset.issubset(self.def_items))
        self.assertFalse(self.ab_weakset.issuperset(self.def_items))

    def test_lt(self):
        self.assertTrue(self.ab_weakset < self.abcde_weakset)
        self.assertFalse(self.abcde_weakset < self.def_weakset)
        self.assertFalse(self.ab_weakset < self.ab_weakset)
        self.assertFalse(WeakSet() < WeakSet())

    def test_gt(self):
        self.assertTrue(self.abcde_weakset > self.ab_weakset)
        self.assertFalse(self.abcde_weakset > self.def_weakset)
        self.assertFalse(self.ab_weakset > self.ab_weakset)
        self.assertFalse(WeakSet() > WeakSet())

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        # NOTE(review): the Foo() instances have no other strong reference,
        # so under CPython refcounting they are presumably freed as soon as
        # they are added and the loop body may never run -- confirm intent.
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Calling __init__ again must reset the contents.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertIsNot(s, t)

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertIsNot(self.s, dup)

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # A temporary Foo() dies immediately, so it never stays in the set.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding an absent element is a silent no-op.
        self.s.discard(q)
        self.assertNotIn(q, self.s)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)
        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                next(it)
                # Schedule an item for removal and recreate it: popping drops
                # the last strong reference, and u is a *distinct* object
                # that compares equal to the dead one.  (str() of a
                # SomeClass is its repr, so the value must be copied
                # explicitly for u to be equal.)
                u = SomeClass(items.pop().value)
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)
444
Michael Foorde6410c52010-03-29 20:04:23 +0000445
def test_main(verbose=None):
    """Run the TestWeakSet suite through ``test_support.run_unittest``.

    ``verbose`` is accepted but not used by this function.
    """
    suite_class = TestWeakSet
    test_support.run_unittest(suite_class)
448
if __name__ == "__main__":
    # Executed directly as a script: run the suite in verbose mode.
    test_main(verbose=True)