blob: 1f82a7dda4dde81601dd086f21b85671a4a79fe7 [file] [log] [blame]
Michael Foorde6410c52010-03-29 20:04:23 +00001import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Minimal attribute bag used as a WeakSet element in the tests below.

    Instances accept arbitrary attributes and can be weakly referenced.
    """
18
class SomeClass(object):
    """Hashable, weak-referenceable wrapper around a single ``value``.

    Two instances compare equal only when they have exactly the same type
    and equal ``value``; the hash is derived from ``value`` so equal
    instances hash alike (as required for set/WeakSet membership).
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Exact type match required: subclasses and foreign types never compare equal.
        return type(other) == type(self) and other.value == self.value

    def __ne__(self, other):
        # Python 2 does not derive != from ==, so spell it out explicitly.
        equal = self.__eq__(other)
        return not equal

    def __hash__(self):
        key = (SomeClass, self.value)
        return hash(key)
32
class RefCycle(object):
    """Object that refers to itself, forming a reference cycle.

    Once every external reference is gone, only the cyclic garbage
    collector can reclaim an instance.
    """

    def __init__(self):
        # Deliberate self-reference: the refcount never reaches zero on its own.
        self.cycle = self
36
class TestWeakSet(unittest.TestCase):
    """Behavioral tests for weakref.WeakSet.

    Most tests compare a WeakSet against a dict or list built from the same
    SomeClass instances. Several also check that an entry disappears once
    the last strong reference to its element is dropped; those rely on
    CPython reclaiming objects as soon as their refcount hits zero, with
    gc.collect() used where cycles may be involved.
    """

    def setUp(self):
        # need to keep references to them
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        # self.obj is the only strong reference to the sole member of self.fs
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must also exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)
            del c
        self.assertEqual(len(u), len(self.items) + len(self.items2))
        # u holds only weak references: dropping an input element shrinks it
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(u), len(self.items) + len(self.items2))

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        s = WeakSet(self.letters)
        i = s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.items2 and c in self.letters)
        self.assertEqual(s, WeakSet(self.letters))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(i.intersection(C(self.items)), x)
        self.assertEqual(len(i), len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items2))

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])
        self.assertEqual(len(i), len(self.items) + len(self.items2))
        self.items2.pop()
        gc.collect()
        self.assertEqual(len(i), len(self.items) + len(self.items2))

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        pl, ql, rl = map(lambda s: [SomeClass(c) for c in s], ['ab', 'abcde', 'def'])
        p, q, r = map(WeakSet, (pl, ql, rl))
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        # The named WeakSet methods must agree with the operators above
        # (the set(...) checks below only exercise the builtin set).
        self.assertTrue(p.issubset(q))
        self.assertTrue(q.issubset(q))
        self.assertTrue(q.issuperset(p))
        self.assertFalse(q.issubset(r))
        self.assertFalse(q.issuperset(r))
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # Keep strong references in a list: a WeakSet built from a generator
        # of temporaries would lose every Foo immediately, and the loop
        # below would never execute.
        items = [Foo() for i in range(1000)]
        s = WeakSet(items)
        self.assertEqual(len(s), len(items))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The unreferenced Foo() is reclaimed at once, so the size stays 1.
        self.fs.add(Foo())
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)  # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()  # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                # Start the iterator, keeping only the value so that no
                # strong reference to the yielded element survives here.
                yielded = next(it).value
                # Schedule an item for removal and recreate an equal one.
                # str() of a SomeClass is its default repr, not its value,
                # so the value is copied explicitly.
                u = SomeClass(items.pop().value)
                if yielded == u.value:
                    # The iterator still holds a reference to the removed
                    # item; advance it so that item can die (issue #20006).
                    next(it)
                gc.collect()  # just in case
                yield u
            finally:
                it = None  # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)
432
Michael Foorde6410c52010-03-29 20:04:23 +0000433
def test_main(verbose=None):
    """Run the TestWeakSet case through test_support.run_unittest.

    ``verbose`` is accepted but not used by this function.
    """
    cases = (TestWeakSet,)
    test_support.run_unittest(*cases)
436
if __name__ == "__main__":
    # Executed directly as a script: run the suite with verbose=True.
    test_main(verbose=True)