blob: 2d44c286665be0eb14a6c1cba46fca56c5a89bd8 [file] [log] [blame]
Guido van Rossumd6cf3af2002-08-19 16:19:15 +00001"""Classes to represent arbitrary sets (including sets of sets).
2
3This module implements sets using dictionaries whose values are
4ignored. The usual operations (union, intersection, deletion, etc.)
5are provided as both methods and operators.
6
7The following classes are provided:
8
9BaseSet -- All the operations common to both mutable and immutable
10 sets. This is an abstract class, not meant to be directly
11 instantiated.
12
13Set -- Mutable sets, subclass of BaseSet; not hashable.
14
15ImmutableSet -- Immutable sets, subclass of BaseSet; hashable.
16 An iterable argument is mandatory to create an ImmutableSet.
17
18_TemporarilyImmutableSet -- Not a subclass of BaseSet: just a wrapper
19 around a Set, hashable, giving the same hash value as the
20 immutable set equivalent would have. Do not use this class
21 directly.
22
23Only hashable objects can be added to a Set. In particular, you cannot
24really add a Set as an element to another Set; if you try, what is
25actually added is an ImmutableSet built from it (it compares equal to
26the one you tried adding).
27
28When you ask if `x in y' where x is a Set and y is a Set or
29ImmutableSet, x is wrapped into a _TemporarilyImmutableSet z, and
30what's tested is actually `z in y'.
31
32"""
33
34# Code history:
35#
36# - Greg V. Wilson wrote the first version, using a different approach
37# to the mutable/immutable problem, and inheriting from dict.
38#
39# - Alex Martelli modified Greg's version to implement the current
40# Set/ImmutableSet approach, and make the data an attribute.
41#
42# - Guido van Rossum rewrote much of the code, made some API changes,
43# and cleaned up the docstrings.
44
45
# Names exported by ``from <module> import *``.
__all__ = ["BaseSet", "Set", "ImmutableSet"]
47
48
class BaseSet(object):
    """Common base class for mutable and immutable sets.

    The elements of a set are stored as the keys of the ``_data``
    dictionary; the associated values are always True and are ignored.
    """

    __slots__ = ['_data']

    # Constructor

    def __init__(self, seq=None):
        """Construct a set, optionally initializing it from a sequence.

        'seq' may be any iterable of hashable elements; when it is None
        (the default) the set starts out empty.
        """
        self._data = {}
        if seq is not None:
            # I don't know a faster way to do this in pure Python.
            # Custom code written in C only did it 65% faster,
            # preallocating the dict to len(seq); without
            # preallocation it was only 25% faster.  So the speed of
            # this Python code is respectable.  Just copying True into
            # a local variable is responsible for a 7-8% speedup.
            data = self._data
            value = True
            for key in seq:
                data[key] = value

    # Standard protocols: __len__, __repr__, __str__, __iter__

    def __len__(self):
        """Return the number of elements of a set."""
        return len(self._data)

    def __repr__(self):
        """Return string representation of a set.

        This looks like 'Set([<list of elements>])'.
        """
        return self._repr()

    # __str__ is the same as __repr__
    __str__ = __repr__

    def _repr(self, sorted=False):
        """Return the representation, with the elements sorted on request.

        The result has the form '<class name>(<list of elements>)'.
        """
        # Build a real list: it can be sorted in place and its repr is
        # the bracketed form used in the result.
        elements = list(self._data)
        if sorted:
            elements.sort()
        return '%s(%r)' % (self.__class__.__name__, elements)

    def __iter__(self):
        """Return an iterator over the elements of a set.

        This is the keys iterator for the underlying dict.
        """
        return iter(self._data)

    # Comparisons.  Ordering is determined by the ordering of the
    # underlying dicts (which is consistent though unpredictable).
    # NOTE(review): the ordering operators delegate to dict ordering,
    # so they do NOT express subset/superset relations; use issubset()
    # and issuperset() for that.

    def __lt__(self, other):
        """Order two sets by their underlying dicts; TypeError for non-sets."""
        self._binary_sanity_check(other)
        return self._data < other._data

    def __le__(self, other):
        """Order two sets by their underlying dicts; TypeError for non-sets."""
        self._binary_sanity_check(other)
        return self._data <= other._data

    def __eq__(self, other):
        """Report whether two sets have the same elements.

        For an operand that is not a BaseSet, NotImplemented is
        returned so that the other operand gets a chance to do the
        comparison; if it declines too, the objects compare unequal.
        Raising here instead would break 'set == None' and the key
        comparisons a dict performs while looking up a set wrapper.
        """
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self._data == other._data

    def __ne__(self, other):
        """Report whether two sets differ; the inverse of __eq__."""
        if not isinstance(other, BaseSet):
            return NotImplemented
        return self._data != other._data

    def __gt__(self, other):
        """Order two sets by their underlying dicts; TypeError for non-sets."""
        self._binary_sanity_check(other)
        return self._data > other._data

    def __ge__(self, other):
        """Order two sets by their underlying dicts; TypeError for non-sets."""
        self._binary_sanity_check(other)
        return self._data >= other._data

    # Copying operations

    def copy(self):
        """Return a shallow copy of a set."""
        return self.__class__(self)

    __copy__ = copy # For the copy module

    def __deepcopy__(self, memo):
        """Return a deep copy of a set; used by copy module."""
        # This pre-creates the result and inserts it in the memo
        # early, in case the deep copy recurses into another reference
        # to this same set.  A set can't be an element of itself, but
        # it can certainly contain an object that has a reference to
        # itself.
        from copy import deepcopy
        result = self.__class__([])
        memo[id(self)] = result
        data = result._data
        value = True
        for elt in self:
            data[deepcopy(elt, memo)] = value
        return result

    # Standard set operations: union, intersection, both differences

    def union(self, other):
        """Return the union of two sets as a new set.

        (I.e. all elements that are in either set.)

        The result has the class of self.  Raise TypeError if other is
        not a set.
        """
        self._binary_sanity_check(other)
        result = self.__class__(self._data)
        result._data.update(other._data)
        return result

    __or__ = union

    def intersection(self, other):
        """Return the intersection of two sets as a new set.

        (I.e. all elements that are in both sets.)

        The result has the class of self.  Raise TypeError if other is
        not a set.
        """
        self._binary_sanity_check(other)
        # Loop over the smaller set and probe the bigger one.
        if len(self) <= len(other):
            little, big = self, other
        else:
            little, big = other, self
        result = self.__class__([])
        data = result._data
        value = True
        for elt in little:
            if elt in big:
                data[elt] = value
        return result

    __and__ = intersection

    def symmetric_difference(self, other):
        """Return the symmetric difference of two sets as a new set.

        (I.e. all elements that are in exactly one of the sets.)

        The result has the class of self.  Raise TypeError if other is
        not a set.
        """
        self._binary_sanity_check(other)
        result = self.__class__([])
        data = result._data
        value = True
        for elt in self:
            if elt not in other:
                data[elt] = value
        for elt in other:
            if elt not in self:
                data[elt] = value
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        """Return the difference of two sets as a new set.

        (I.e. all elements that are in this set and not in the other.)

        The result has the class of self.  Raise TypeError if other is
        not a set.
        """
        self._binary_sanity_check(other)
        result = self.__class__([])
        data = result._data
        value = True
        for elt in self:
            if elt not in other:
                data[elt] = value
        return result

    __sub__ = difference

    # Membership test

    def __contains__(self, element):
        """Report whether an element is a member of a set.

        (Called in response to the expression `element in self'.)
        """
        # An element that offers _as_temporarily_immutable() is looked
        # up through the object that method returns.
        try:
            transform = element._as_temporarily_immutable
        except AttributeError:
            pass
        else:
            element = transform()
        return element in self._data

    # Subset and superset test

    def issubset(self, other):
        """Report whether another set contains this set.

        Raise TypeError if other is not a set.
        """
        self._binary_sanity_check(other)
        if len(self) > len(other):
            # A set cannot fit inside a smaller one.
            return False
        for elt in self:
            if elt not in other:
                return False
        return True

    def issuperset(self, other):
        """Report whether this set contains another set.

        Raise TypeError if other is not a set.
        """
        self._binary_sanity_check(other)
        if len(self) < len(other):
            # A set cannot contain a bigger one.
            return False
        for elt in other:
            if elt not in self:
                return False
        return True

    # Assorted helpers

    def _binary_sanity_check(self, other):
        # Check that the other argument to a binary operation is also
        # a set, raising a TypeError otherwise.
        if not isinstance(other, BaseSet):
            raise TypeError("Binary operation only permitted between sets")

    def _compute_hash(self):
        # Calculate hash code for a set by xor'ing the hash codes of
        # the elements.  This algorithm ensures that the hash code
        # does not depend on the order in which elements are added to
        # the set.  This is not called __hash__ because a BaseSet
        # should not be hashable; only an ImmutableSet is hashable.
        result = 0
        for elt in self:
            result ^= hash(elt)
        return result
271
272
class ImmutableSet(BaseSet):
    """Immutable set class; adds hashing to BaseSet."""

    __slots__ = ['_hashcode']

    # BaseSet + hashing

    def __init__(self, seq):
        """Construct an immutable set from a sequence."""
        # The hash is computed on first use and remembered here.
        self._hashcode = None
        # 'seq' has no default: the constructor is overridden to make
        # it a required argument.
        BaseSet.__init__(self, seq)

    def __hash__(self):
        """Return the hash of the set, computing it on the first call."""
        cached = self._hashcode
        if cached is None:
            cached = self._compute_hash()
            self._hashcode = cached
        return cached
290
291
class Set(BaseSet):
    """Mutable set class.

    Adds in-place set operations and element mutation to BaseSet.
    The in-place operations return self so that they can also serve
    as the augmented assignment operators (|=, &=, ^=, -=).
    """

    __slots__ = []

    # BaseSet + operations requiring mutability; no hashing

    # In-place union, intersection, differences

    def union_update(self, other):
        """Update a set with the union of itself and another.

        Return self.  Raise TypeError if other is not a set.
        """
        self._binary_sanity_check(other)
        self._data.update(other._data)
        return self

    __ior__ = union_update

    def intersection_update(self, other):
        """Update a set with the intersection of itself and another.

        Return self.  Raise TypeError if other is not a set.
        """
        self._binary_sanity_check(other)
        data = self._data
        # Walk over a snapshot of the keys: entries are deleted from
        # the dict inside the loop, which is not allowed while
        # iterating over the dict itself.
        for elt in list(data):
            if elt not in other:
                del data[elt]
        return self

    __iand__ = intersection_update

    def symmetric_difference_update(self, other):
        """Update a set with the symmetric difference of itself and another.

        Return self.  Raise TypeError if other is not a set.
        """
        self._binary_sanity_check(other)
        if other is self:
            # Every element is in both operands, so nothing is left.
            # Handled separately because the loop below would modify
            # the very dict it is iterating over.
            self.clear()
            return self
        data = self._data
        value = True
        for elt in other:
            if elt in data:
                del data[elt]
            else:
                data[elt] = value
        return self

    __ixor__ = symmetric_difference_update

    def difference_update(self, other):
        """Remove all elements of another set from this set.

        Return self.  Raise TypeError if other is not a set.
        """
        self._binary_sanity_check(other)
        if other is self:
            # Removing a set from itself empties it.  Handled
            # separately because the loop below would modify the very
            # dict it is iterating over.
            self.clear()
            return self
        data = self._data
        for elt in other:
            if elt in data:
                del data[elt]
        return self

    __isub__ = difference_update

    # Python dict-like mass mutations: update, clear

    def update(self, iterable):
        """Add all values from an iterable (such as a list or file)."""
        data = self._data
        value = True
        for elt in iterable:
            # An element that offers _as_immutable() is stored as the
            # object that method returns.
            try:
                transform = elt._as_immutable
            except AttributeError:
                pass
            else:
                elt = transform()
            data[elt] = value

    def clear(self):
        """Remove all elements from this set."""
        self._data.clear()

    # Single-element mutations: add, remove, discard

    def add(self, element):
        """Add an element to a set.

        This has no effect if the element is already present.
        """
        # An element that offers _as_immutable() is stored as the
        # object that method returns.
        try:
            transform = element._as_immutable
        except AttributeError:
            pass
        else:
            element = transform()
        self._data[element] = True

    def remove(self, element):
        """Remove an element from a set; it must be a member.

        If the element is not a member, raise a KeyError.
        """
        # An element that offers _as_temporarily_immutable() is looked
        # up through the object that method returns.
        try:
            transform = element._as_temporarily_immutable
        except AttributeError:
            pass
        else:
            element = transform()
        del self._data[element]

    def discard(self, element):
        """Remove an element from a set if it is a member.

        If the element is not a member, do nothing.
        """
        # Go through remove() so that the element gets the same
        # _as_temporarily_immutable() treatment as in remove() and in
        # membership tests.
        try:
            self.remove(element)
        except KeyError:
            pass

    def popitem(self):
        """Remove and return an arbitrary set element.

        The element is whichever key the underlying dict's popitem()
        hands back.
        """
        return self._data.popitem()[0]

    def _as_immutable(self):
        # Return a copy of self as an immutable set
        return ImmutableSet(self)

    def _as_temporarily_immutable(self):
        # Return self wrapped in a temporarily immutable set
        return _TemporarilyImmutableSet(self)
412
413
class _TemporarilyImmutableSet(object):
    """Wrap a mutable set as if it was temporarily immutable.

    Only hashing and equality comparisons are supplied; both are
    delegated to the wrapped set.
    """

    # Class-level default; replaced on the instance once the hash has
    # been computed.
    _hashcode = None

    def __init__(self, set):
        """Remember the set being wrapped."""
        self._set = set

    def __hash__(self):
        """Return the wrapped set's computed hash, remembering it."""
        cached = self._hashcode
        if cached is None:
            cached = self._set._compute_hash()
            self._hashcode = cached
        return cached

    def __eq__(self, other):
        """Compare the wrapped set with other for equality."""
        wrapped = self._set
        return wrapped == other

    def __ne__(self, other):
        """Compare the wrapped set with other for inequality."""
        wrapped = self._set
        return wrapped != other
433
434
435# Rudimentary self-tests
436
def _test():
    """Run the rudimentary self-tests.

    Each check is a plain assert, so the first failure raises
    AssertionError; "All tests passed" is printed when none fails.
    """

    # Empty set
    red = Set()
    assert repr(red) == "Set([])", "Empty set: %s" % repr(red)

    # Unit set
    green = Set((0,))
    assert repr(green) == "Set([0])", "Unit set: %s" % repr(green)

    # 3-element set
    blue = Set([0, 1, 2])
    assert blue._repr(True) == "Set([0, 1, 2])", "3-element set: %s" % repr(blue)

    # 2-element set with other values
    black = Set([0, 5])
    assert black._repr(True) == "Set([0, 5])", "2-element set: %s" % repr(black)

    # All elements from all sets
    white = Set([0, 1, 2, 5])
    assert white._repr(True) == "Set([0, 1, 2, 5])", "4-element set: %s" % repr(white)

    # Add element to empty set
    red.add(9)
    assert repr(red) == "Set([9])", "Add to empty set: %s" % repr(red)

    # Remove element from unit set
    red.remove(9)
    assert repr(red) == "Set([])", "Remove from unit set: %s" % repr(red)

    # Remove element from empty set
    try:
        red.remove(0)
    except LookupError:
        pass
    else:
        # Only reached when remove() failed to raise.
        assert 0, "Remove element from empty set: %s" % repr(red)

    # Length
    assert len(red) == 0, "Length of empty set"
    assert len(green) == 1, "Length of unit set"
    assert len(blue) == 3, "Length of 3-element set"

    # Compare
    assert green == Set([0]), "Equality failed"
    assert green != Set([1]), "Inequality failed"

    # Union
    assert blue | red == blue, "Union non-empty with empty"
    assert red | blue == blue, "Union empty with non-empty"
    assert green | blue == blue, "Union non-empty with non-empty"
    assert blue | black == white, "Enclosing union"

    # Intersection
    assert blue & red == red, "Intersect non-empty with empty"
    assert red & blue == red, "Intersect empty with non-empty"
    assert green & blue == green, "Intersect non-empty with non-empty"
    assert blue & black == green, "Enclosing intersection"

    # Symmetric difference
    assert red ^ green == green, "Empty symdiff non-empty"
    assert green ^ blue == Set([1, 2]), "Non-empty symdiff"
    assert white ^ white == red, "Self symdiff"

    # Difference
    assert red - green == red, "Empty - non-empty"
    assert blue - red == blue, "Non-empty - empty"
    assert white - black == Set([1, 2]), "Non-empty - non-empty"

    # In-place union
    orange = Set([])
    orange |= Set([1])
    assert orange == Set([1]), "In-place union"

    # In-place intersection
    orange = Set([1, 2])
    orange &= Set([2])
    assert orange == Set([2]), "In-place intersection"

    # In-place difference
    orange = Set([1, 2, 3])
    orange -= Set([2, 4])
    assert orange == Set([1, 3]), "In-place difference"

    # In-place symmetric difference
    orange = Set([1, 2, 3])
    orange ^= Set([3, 4])
    assert orange == Set([1, 2, 4]), "In-place symmetric difference"

    # Parenthesised form: prints the same text whether print is a
    # statement or a function.
    print("All tests passed")


# Run the self-tests when the module is executed as a script.
if __name__ == "__main__":
    _test()