blob: 89c2822b6ee2b5801e3300eb642ca539ac37bed3 [file] [log] [blame]
Michael Foorde6410c52010-03-29 20:04:23 +00001import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Featureless helper class; instances serve as throwaway set members."""
18
class SomeClass(object):
    """Value wrapper whose equality and hash are derived from ``value``.

    Two instances compare equal only when they have exactly the same type
    and equal ``value`` attributes.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Exact type match is required before the payloads are compared.
        return type(other) == type(self) and other.value == self.value

    def __ne__(self, other):
        # Defined explicitly as the negation of __eq__.
        equal = self.__eq__(other)
        return not equal

    def __hash__(self):
        # Hash on (class, value) so that equal instances hash alike.
        key = (SomeClass, self.value)
        return hash(key)
32
class TestWeakSet(unittest.TestCase):
    """Exercise WeakSet's set-like API and its weak-reference behaviour."""

    def setUp(self):
        # Strong references are stored on the test case so that the weakly
        # held members stay alive for the duration of each test.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public attribute of set must also be present on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        # Do not rely on immediate refcount-based finalisation.
        gc.collect()
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # 1 is not weakref'able, but that TypeError is caught by __contains__
        self.assertNotIn(1, self.s)
        self.assertIn(self.obj, self.fs)
        del self.obj
        # Do not rely on immediate refcount-based finalisation.
        gc.collect()
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        expected = WeakSet(self.items + self.items2)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            self.assertEqual(self.s.union(C(self.items2)), expected)

    def test_or(self):
        expected = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), expected)
        self.assertEqual(self.s | frozenset(self.items2), expected)

    def test_intersection(self):
        common = self.s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in common, c in self.d and c in self.items2)
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(common), WeakSet)
        empty = WeakSet([])
        for C in set, frozenset, dict.fromkeys, list, tuple:
            self.assertEqual(self.s.intersection(C(self.items2)), empty)

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        expected = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), expected)
        self.assertEqual(self.s & frozenset(self.items2), expected)

    def test_difference(self):
        diff = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in diff,
                             c in self.d and c not in self.items2)
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(diff), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        expected = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), expected)
        self.assertEqual(self.s - frozenset(self.items2), expected)

    def test_symmetric_difference(self):
        sym = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in sym, (c in self.d) ^ (c in self.items2))
        # The receiver must be left untouched.
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(sym), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])

    def test_xor(self):
        expected = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), expected)
        self.assertEqual(self.s ^ frozenset(self.items2), expected)

    def test_sub_and_super(self):
        # The lists hold the strong references backing the three WeakSets.
        pl, ql, rl = [[SomeClass(c) for c in chars]
                      for chars in ('ab', 'abcde', 'def')]
        p, q, r = map(WeakSet, (pl, ql, rl))
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        # Named-method forms, checked on WeakSet itself (the earlier version
        # of these four assertions exercised the builtin set type instead).
        self.assertTrue(p.issubset(ql))
        self.assertTrue(q.issuperset(pl))
        self.assertFalse(p.issubset(rl))
        self.assertFalse(r.issuperset(pl))

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check.
        # The members are kept alive in a list while the cycles are built;
        # a WeakSet fed only temporaries would leave nothing to iterate
        # over under refcounting.
        items = [Foo() for _ in range(1000)]
        s = WeakSet(items)
        self.assertEqual(len(s), len(items))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])
        # Drop every strong reference; each element is part of a reference
        # cycle (elem.sub), so the cyclic collector has to reclaim it.
        del items, elem
        gc.collect()
        self.assertEqual(len(s), 0)

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        # Adding the same element again must not change the set.
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # A member with no other reference must not stay in the set.
        self.fs.add(Foo())
        gc.collect()
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding an absent element must not raise.
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])
        self.assertRaises(TypeError,
                          self.s.symmetric_difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertIsNone(retval)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError,
                          self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly
        # removed.
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)    # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed item, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly
        # removed.
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            try:
                it = iter(s)
                next(it)
                # Schedule an item for removal and recreate it.
                # NOTE: SomeClass defines no __str__, so the new object's
                # value is the default textual form of the popped instance
                # rather than its original letter.
                u = SomeClass(str(items.pop()))
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)
371
372
def test_main(verbose=None):
    """Run the TestWeakSet case through ``test_support.run_unittest``.

    ``verbose`` is accepted but not used by this function.
    """
    cases = (TestWeakSet,)
    test_support.run_unittest(*cases)
375
if __name__ == "__main__":
    # Executed directly as a script: run the suite with verbose enabled.
    test_main(True)