blob: f981bddb19337bfc15100be1b93ea57d8d10c493 [file] [log] [blame]
Michael Foorde6410c52010-03-29 20:04:23 +00001import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Empty helper class with no attributes or methods of its own."""
18
class SomeClass(object):
    """Hashable wrapper around ``value``; equality is decided by type and value."""

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Objects of a different type never compare equal; otherwise the
        # wrapped values decide.
        if type(other) == type(self):
            return other.value == self.value
        return False

    def __ne__(self, other):
        # Defined explicitly as the negation of __eq__.
        is_equal = self.__eq__(other)
        return not is_equal

    def __hash__(self):
        # Hash on the (class, value) pair so that instances which compare
        # equal also hash equal.
        key = (SomeClass, self.value)
        return hash(key)
32
class RefCycle(object):
    """Object that refers to itself through its ``cycle`` attribute."""

    def __init__(self):
        # The instance stores a reference to itself, forming a reference cycle.
        setattr(self, 'cycle', self)
36
class TestWeakSet(unittest.TestCase):
    """Tests for ``weakref.WeakSet``.

    ``setUp`` builds the shared fixtures:

    * ``items`` / ``items2`` / ``letters`` -- lists holding strong references
      to ``SomeClass`` instances,
    * ``s`` -- a ``WeakSet`` of ``items``,
    * ``d`` -- a dict keyed by ``items`` (used as a strong-reference oracle),
    * ``obj`` / ``fs`` -- a single object and a ``WeakSet`` containing only it.

    Test methods are described with ``#`` comments rather than docstrings so
    that unittest's verbose output keeps showing the method names.
    """

    def setUp(self):
        # need to keep references to them
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the builtin set must exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        # Do not depend on the object being finalised the instant its last
        # reference is dropped.
        gc.collect()
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        # Do not depend on the object being finalised the instant its last
        # reference is dropped.
        gc.collect()
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        i = self.s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c in self.items2)
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(self.s.intersection(C(self.items2)), x)

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        # pl, ql and rl hold the strong references backing p, q and r.
        pl, ql, rl = map(lambda s: [SomeClass(c) for c in s],
                         ['ab', 'abcde', 'def'])
        p, q, r = map(WeakSet, (pl, ql, rl))
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        # Named-method forms on WeakSet itself.
        self.assertTrue(p.issubset(q))
        self.assertTrue(q.issuperset(p))
        self.assertFalse(q.issubset(r))
        self.assertFalse(q.issuperset(r))
        # Same relations on the builtin set, for reference.
        self.assertTrue(set('a').issubset('abc'))
        self.assertTrue(set('abc').issuperset('a'))
        self.assertFalse(set('a').issubset('cbs'))
        self.assertFalse(set('cbs').issuperset('a'))

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # Strong references are held in a list: a WeakSet built straight from
        # a generator of temporaries does not keep them alive, so the loop
        # below could end up with nothing to iterate over.
        items = [Foo() for i in range(1000)]
        s = WeakSet(items)
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        # Adding an element that is already present must change nothing.
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # Nothing else refers to the temporary Foo(), so it must not count.
        self.fs.add(Foo())
        # Do not depend on the temporary being finalised immediately.
        gc.collect()
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding an absent element must not raise.
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for i in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        # The symmetric variant is checked here as well.
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators where both operands are the same WeakSet.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        # The == operator is spelled out on purpose: it is what is under test.
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            # Keeps an iterator over s open for the duration of the with-body.
            try:
                it = iter(s)
                next(it)
                # Schedule an item for removal and recreate it
                u = SomeClass(str(items.pop()))
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)

    def test_len_cycles(self):
        N = 20
        items = [RefCycle() for i in range(N)]
        s = WeakSet(items)
        # From here on the items are only kept alive by their own cycles.
        del items
        it = iter(s)
        try:
            next(it)
        except StopIteration:
            pass
        gc.collect()
        n1 = len(s)
        del it
        gc.collect()
        n2 = len(s)
        # one item may be kept alive inside the iterator
        self.assertIn(n1, (0, 1))
        self.assertEqual(n2, 0)

    def test_len_race(self):
        # Extended sanity checks for len() in the face of cyclic collection
        # Restore the collector thresholds once the test is over.
        self.addCleanup(gc.set_threshold, *gc.get_threshold())
        for th in range(1, 100):
            N = 20
            gc.collect(0)
            gc.set_threshold(th, th, th)
            items = [RefCycle() for i in range(N)]
            s = WeakSet(items)
            del items
            # All items will be collected at next garbage collection pass
            it = iter(s)
            try:
                next(it)
            except StopIteration:
                pass
            n1 = len(s)
            del it
            n2 = len(s)
            self.assertGreaterEqual(n1, 0)
            self.assertLessEqual(n1, N)
            self.assertGreaterEqual(n2, 0)
            self.assertLessEqual(n2, n1)
418
Michael Foorde6410c52010-03-29 20:04:23 +0000419
def test_main(verbose=None):
    """Run the TestWeakSet test case through ``test_support.run_unittest``.

    The ``verbose`` argument is accepted but not used by this function.
    """
    test_classes = (TestWeakSet,)
    test_support.run_unittest(*test_classes)
422
# Run the tests verbosely when this file is executed as a script.
if __name__ == "__main__":
    test_main(verbose=True)