blob: b4774ceaec1237e9020773d3464021783c7a4316 [file] [log] [blame]
Michael Foorde6410c52010-03-29 20:04:23 +00001import unittest
2from test import test_support
3from weakref import proxy, ref, WeakSet
4import operator
5import copy
6import string
7import os
8from random import randrange, shuffle
9import sys
10import warnings
11import collections
12import gc
13import contextlib
14
15
class Foo:
    """Minimal throwaway object used as an element of a WeakSet in the tests."""
18
class SomeClass(object):
    """Value wrapper used as a weakly-referenceable, hashable WeakSet element.

    Two instances are equal (and hash equally) when they have exactly the
    same type and equal ``value`` attributes; any other object is unequal.
    """

    def __init__(self, value):
        self.value = value

    def __eq__(self, other):
        # Objects of a different type are simply unequal (no NotImplemented).
        return type(other) is type(self) and other.value == self.value

    def __ne__(self, other):
        # Defined explicitly as the negation of __eq__; Python 2 does not
        # derive ``!=`` from ``__eq__`` automatically.
        return not self.__eq__(other)

    def __hash__(self):
        # Mix the class into the hash so it stays consistent with __eq__.
        key = (SomeClass, self.value)
        return hash(key)
32
class TestWeakSet(unittest.TestCase):
    """Tests for weakref.WeakSet.

    Most checks build a WeakSet and an ordinary container (dict, set or
    list) from the same SomeClass instances and verify that the WeakSet
    agrees with the ordinary container.
    """

    def setUp(self):
        # WeakSet holds only weak references, so the fixture keeps strong
        # references to every element for the duration of each test.
        self.items = [SomeClass(c) for c in ('a', 'b', 'c')]
        self.items2 = [SomeClass(c) for c in ('x', 'y', 'z')]
        self.letters = [SomeClass(c) for c in string.ascii_letters]
        self.s = WeakSet(self.items)
        self.d = dict.fromkeys(self.items)
        self.obj = SomeClass('F')
        self.fs = WeakSet([self.obj])

    def test_methods(self):
        # Every public method of the built-in set must also exist on WeakSet.
        weaksetmethods = dir(WeakSet)
        for method in dir(set):
            if method == 'test_c_api' or method.startswith('_'):
                continue
            self.assertIn(method, weaksetmethods,
                          "WeakSet missing method " + method)

    def test_new_or_init(self):
        self.assertRaises(TypeError, WeakSet, [], 2)

    def test_len(self):
        self.assertEqual(len(self.s), len(self.d))
        self.assertEqual(len(self.fs), 1)
        del self.obj
        # Collect explicitly so the check does not rely on the object being
        # freed the instant its last strong reference disappears.
        gc.collect()
        self.assertEqual(len(self.fs), 0)

    def test_contains(self):
        for c in self.letters:
            self.assertEqual(c in self.s, c in self.d)
        # A list cannot be weakly referenced; lookup is expected to raise.
        self.assertRaises(TypeError, self.s.__contains__, [[]])
        self.assertIn(self.obj, self.fs)
        del self.obj
        gc.collect()
        # An equal but distinct object must not be found once the original
        # element has been collected.
        self.assertNotIn(SomeClass('F'), self.fs)

    def test_union(self):
        u = self.s.union(self.items2)
        for c in self.letters:
            self.assertEqual(c in u, c in self.d or c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(u), WeakSet)
        self.assertRaises(TypeError, self.s.union, [[]])
        # union() must accept arbitrary iterables, not just sets.
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet(self.items + self.items2)
            c = C(self.items2)
            self.assertEqual(self.s.union(c), x)

    def test_or(self):
        i = self.s.union(self.items2)
        self.assertEqual(self.s | set(self.items2), i)
        self.assertEqual(self.s | frozenset(self.items2), i)

    def test_intersection(self):
        i = self.s.intersection(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        for C in set, frozenset, dict.fromkeys, list, tuple:
            x = WeakSet([])
            self.assertEqual(self.s.intersection(C(self.items2)), x)

    def test_isdisjoint(self):
        self.assertTrue(self.s.isdisjoint(WeakSet(self.items2)))
        # self.letters holds objects equal to those in self.items.
        self.assertFalse(self.s.isdisjoint(WeakSet(self.letters)))

    def test_and(self):
        i = self.s.intersection(self.items2)
        self.assertEqual(self.s & set(self.items2), i)
        self.assertEqual(self.s & frozenset(self.items2), i)

    def test_difference(self):
        i = self.s.difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, c in self.d and c not in self.items2)
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.difference, [[]])

    def test_sub(self):
        i = self.s.difference(self.items2)
        self.assertEqual(self.s - set(self.items2), i)
        self.assertEqual(self.s - frozenset(self.items2), i)

    def test_symmetric_difference(self):
        i = self.s.symmetric_difference(self.items2)
        for c in self.letters:
            self.assertEqual(c in i, (c in self.d) ^ (c in self.items2))
        self.assertEqual(self.s, WeakSet(self.items))
        self.assertEqual(type(i), WeakSet)
        self.assertRaises(TypeError, self.s.symmetric_difference, [[]])

    def test_xor(self):
        i = self.s.symmetric_difference(self.items2)
        self.assertEqual(self.s ^ set(self.items2), i)
        self.assertEqual(self.s ^ frozenset(self.items2), i)

    def test_sub_and_super(self):
        # pl, ql, rl keep the elements alive while p, q, r are in use.
        pl, ql, rl = map(lambda s: [SomeClass(c) for c in s], ['ab', 'abcde', 'def'])
        p, q, r = map(WeakSet, (pl, ql, rl))
        self.assertTrue(p < q)
        self.assertTrue(p <= q)
        self.assertTrue(q <= q)
        self.assertTrue(q > p)
        self.assertTrue(q >= p)
        self.assertFalse(q < r)
        self.assertFalse(q <= r)
        self.assertFalse(q > r)
        self.assertFalse(q >= r)
        # Exercise the named methods on WeakSet itself (previously these
        # assertions only exercised the built-in set type).
        self.assertTrue(p.issubset(q))
        self.assertTrue(q.issuperset(p))
        self.assertFalse(q.issubset(r))
        self.assertFalse(q.issuperset(r))

    def test_gc(self):
        # Create a nest of cycles to exercise overall ref count check
        s = WeakSet(Foo() for i in range(1000))
        for elem in s:
            elem.cycle = s
            elem.sub = elem
            elem.set = WeakSet([elem])

    def test_subclass_with_custom_hash(self):
        # Bug #1257731
        class H(WeakSet):
            def __hash__(self):
                return int(id(self) & 0x7fffffff)
        s = H()
        f = set()
        f.add(s)
        self.assertIn(s, f)
        f.remove(s)
        f.add(s)
        f.discard(s)

    def test_init(self):
        # Calling __init__ again on an existing WeakSet must reset its contents.
        s = WeakSet()
        s.__init__(self.items)
        self.assertEqual(s, self.s)
        s.__init__(self.items2)
        self.assertEqual(s, WeakSet(self.items2))
        self.assertRaises(TypeError, s.__init__, s, 2)
        self.assertRaises(TypeError, s.__init__, 1)

    def test_constructor_identity(self):
        s = WeakSet(self.items)
        t = WeakSet(s)
        self.assertNotEqual(id(s), id(t))

    def test_hash(self):
        self.assertRaises(TypeError, hash, self.s)

    def test_clear(self):
        self.s.clear()
        self.assertEqual(self.s, WeakSet([]))
        self.assertEqual(len(self.s), 0)

    def test_copy(self):
        dup = self.s.copy()
        self.assertEqual(self.s, dup)
        self.assertNotEqual(id(self.s), id(dup))

    def test_add(self):
        x = SomeClass('Q')
        self.s.add(x)
        self.assertIn(x, self.s)
        dup = self.s.copy()
        self.s.add(x)
        self.assertEqual(self.s, dup)
        self.assertRaises(TypeError, self.s.add, [])
        # The temporary Foo() has no strong reference left; collect it so
        # the length check does not depend on immediate refcount freeing.
        self.fs.add(Foo())
        gc.collect()
        self.assertEqual(len(self.fs), 1)
        self.fs.add(self.obj)
        self.assertEqual(len(self.fs), 1)

    def test_remove(self):
        x = SomeClass('a')
        self.s.remove(x)
        self.assertNotIn(x, self.s)
        self.assertRaises(KeyError, self.s.remove, x)
        self.assertRaises(TypeError, self.s.remove, [])

    def test_discard(self):
        a, q = SomeClass('a'), SomeClass('Q')
        self.s.discard(a)
        self.assertNotIn(a, self.s)
        # Discarding an absent element must not raise.
        self.s.discard(q)
        self.assertRaises(TypeError, self.s.discard, [])

    def test_pop(self):
        for _ in range(len(self.s)):
            elem = self.s.pop()
            self.assertNotIn(elem, self.s)
        self.assertRaises(KeyError, self.s.pop)

    def test_update(self):
        retval = self.s.update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)
        self.assertRaises(TypeError, self.s.update, [[]])

    def test_update_set(self):
        self.s.update(set(self.items2))
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_ior(self):
        self.s |= set(self.items2)
        for c in (self.items + self.items2):
            self.assertIn(c, self.s)

    def test_intersection_update(self):
        retval = self.s.intersection_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.intersection_update, [[]])

    def test_iand(self):
        self.s &= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items2 and c in self.items:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_difference_update(self):
        retval = self.s.difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.difference_update, [[]])

    def test_isub(self):
        self.s -= set(self.items2)
        for c in (self.items + self.items2):
            if c in self.items and c not in self.items2:
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_symmetric_difference_update(self):
        retval = self.s.symmetric_difference_update(self.items2)
        self.assertEqual(retval, None)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)
        self.assertRaises(TypeError, self.s.symmetric_difference_update, [[]])

    def test_ixor(self):
        self.s ^= set(self.items2)
        for c in (self.items + self.items2):
            if (c in self.items) ^ (c in self.items2):
                self.assertIn(c, self.s)
            else:
                self.assertNotIn(c, self.s)

    def test_inplace_on_self(self):
        # In-place operators with the same set on both sides.
        t = self.s.copy()
        t |= t
        self.assertEqual(t, self.s)
        t &= t
        self.assertEqual(t, self.s)
        t -= t
        self.assertEqual(t, WeakSet())
        t = self.s.copy()
        t ^= t
        self.assertEqual(t, WeakSet())

    def test_eq(self):
        # issue 5964
        self.assertTrue(self.s == self.s)
        self.assertTrue(self.s == WeakSet(self.items))
        self.assertFalse(self.s == set(self.items))
        self.assertFalse(self.s == list(self.items))
        self.assertFalse(self.s == tuple(self.items))
        self.assertFalse(self.s == 1)

    def test_weak_destroy_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        # Create new items to be sure no-one else holds a reference
        items = [SomeClass(c) for c in ('a', 'b', 'c')]
        s = WeakSet(items)
        it = iter(s)
        next(it)             # Trigger internal iteration
        # Destroy an item
        del items[-1]
        gc.collect()    # just in case
        # We have removed either the first consumed items, or another one
        self.assertIn(len(list(it)), [len(items), len(items) - 1])
        del it
        # The removal has been committed
        self.assertEqual(len(s), len(items))

    def test_weak_destroy_and_mutate_while_iterating(self):
        # Issue #7105: iterators shouldn't crash when a key is implicitly removed
        items = [SomeClass(c) for c in string.ascii_letters]
        s = WeakSet(items)

        @contextlib.contextmanager
        def testcontext():
            # Keep a live iterator over s while an element is dropped, then
            # release the iterator on exit.
            try:
                it = iter(s)
                next(it)
                # Schedule an item for removal and recreate it
                u = SomeClass(str(items.pop()))
                gc.collect()      # just in case
                yield u
            finally:
                it = None           # should commit all removals

        with testcontext() as u:
            self.assertNotIn(u, s)
        with testcontext() as u:
            self.assertRaises(KeyError, s.remove, u)
        self.assertNotIn(u, s)
        with testcontext() as u:
            s.add(u)
        self.assertIn(u, s)
        t = s.copy()
        with testcontext() as u:
            s.update(t)
        self.assertEqual(len(s), len(t))
        with testcontext() as u:
            s.clear()
        self.assertEqual(len(s), 0)
370
371
def test_main(verbose=None):
    """Run the TestWeakSet cases through test_support.run_unittest.

    ``verbose`` is accepted but not used by this function.
    """
    cases = (TestWeakSet,)
    test_support.run_unittest(*cases)


if __name__ == "__main__":
    # Allow running this module directly as a script.
    test_main(verbose=True)