blob: bb4428064d1084858369b614177015cb1a3cb214 [file] [log] [blame]
"""Classes to represent arbitrary sets (including sets of sets).

This module implements sets using dictionaries whose values are
ignored. The usual operations (union, intersection, deletion, etc.)
are provided as both methods and operators.

Important: sets are not sequences! While they support 'x in s',
'len(s)', and 'for x in s', none of those operations are unique for
sequences; for example, mappings support all three as well. The
characteristic operation for sequences is subscripting with small
integers: s[i], for i in range(len(s)). Sets don't support
subscripting at all. Also, sequences allow multiple occurrences and
their elements have a definite order; sets on the other hand don't
record multiple occurrences and don't remember the order of element
insertion (which is why they don't support s[i]).

The following classes are provided:

BaseSet -- All the operations common to both mutable and immutable
    sets. This is an abstract class, not meant to be directly
    instantiated.

Set -- Mutable sets, subclass of BaseSet; not hashable.

ImmutableSet -- Immutable sets, subclass of BaseSet; hashable.
    The iterable argument is optional; without one, an empty
    ImmutableSet is created. Since ImmutableSet provides no methods
    for adding elements, its contents are normally supplied this way.

_TemporarilyImmutableSet -- Not a subclass of BaseSet: just a wrapper
    around a Set, hashable, giving the same hash value as the
    immutable set equivalent would have. Do not use this class
    directly.

Only hashable objects can be added to a Set. In particular, you cannot
really add a Set as an element to another Set; if you try, what is
actually added is an ImmutableSet built from it (it compares equal to
the one you tried adding).

When you ask if `x in y' where x is a Set and y is a Set or
ImmutableSet, x is wrapped into a _TemporarilyImmutableSet z, and
what's tested is actually `z in y'.

"""
43
# Code history:
#
# - Greg V. Wilson wrote the first version, using a different approach
#   to the mutable/immutable problem, and inheriting from dict.
#
# - Alex Martelli modified Greg's version to implement the current
#   Set/ImmutableSet approach, and make the data an attribute.
#
# - Guido van Rossum rewrote much of the code, made some API changes,
#   and cleaned up the docstrings.
#
# - Raymond Hettinger added a number of speedups and other
#   bugs^H^H^H^Himprovements.
Guido van Rossumd6cf3af2002-08-19 16:19:15 +000057
58
# Public API; _TemporarilyImmutableSet is an internal helper and is
# deliberately not exported.
__all__ = ["BaseSet", "Set", "ImmutableSet"]
60
61
class BaseSet(object):
    """Common base class for mutable and immutable sets."""

    __slots__ = ['_data']

    # Constructor

    def __init__(self):
        """This is an abstract class."""
        # Don't call this from a concrete subclass!
        if self.__class__ is BaseSet:
            raise TypeError("BaseSet is an abstract class. "
                            "Use Set or ImmutableSet.")

    # Standard protocols: __len__, __repr__, __str__, __iter__

    def __len__(self):
        """Return the number of elements of a set."""
        return len(self._data)

    def __repr__(self):
        """Return string representation of a set.

        This looks like 'Set([<list of elements>])'.
        """
        return self._repr()

    # __str__ is the same as __repr__
    __str__ = __repr__

    def _repr(self, sorted=False):
        # Build "ClassName([elements])".  When 'sorted' is true the
        # elements are listed in sorted order, giving a deterministic
        # string.  The parameter name is kept for keyword callers.
        elements = list(self._data)
        if sorted:
            elements.sort()
        return '%s(%r)' % (self.__class__.__name__, elements)

    def __iter__(self):
        """Return an iterator over the elements of a set.

        This is the keys iterator for the underlying dict.
        """
        return iter(self._data)

    # Comparisons.  Ordering is determined by the ordering of the
    # underlying dicts (which is consistent though unpredictable).
    # Ordering comparisons against a non-set raise TypeError; equality
    # and inequality against a non-set just report "not equal".

    def __lt__(self, other):
        self._binary_sanity_check(other)
        return self._data < other._data

    def __le__(self, other):
        self._binary_sanity_check(other)
        return self._data <= other._data

    def __eq__(self, other):
        """Report whether two sets have the same elements.

        A non-set is never equal to a set.  This must not raise:
        dict lookups compare keys whose hashes collide, and e.g. the
        empty ImmutableSet hashes to 0 just like the integer 0, so
        raising here would break `Set() in Set([0])'.
        """
        if isinstance(other, BaseSet):
            return self._data == other._data
        return False

    def __ne__(self, other):
        """Report whether two sets differ; a non-set always differs."""
        if isinstance(other, BaseSet):
            return self._data != other._data
        return True

    def __gt__(self, other):
        self._binary_sanity_check(other)
        return self._data > other._data

    def __ge__(self, other):
        self._binary_sanity_check(other)
        return self._data >= other._data

    # Copying operations

    def copy(self):
        """Return a shallow copy of a set."""
        return self.__class__(self)

    __copy__ = copy # For the copy module

    def __deepcopy__(self, memo):
        """Return a deep copy of a set; used by copy module."""
        # This pre-creates the result and inserts it in the memo
        # early, in case the deep copy recurses into another reference
        # to this same set.  A set can't be an element of itself, but
        # it can certainly contain an object that has a reference to
        # itself.
        from copy import deepcopy
        result = self.__class__([])
        memo[id(self)] = result
        data = result._data
        value = True
        for elt in self:
            data[deepcopy(elt, memo)] = value
        return result

    # Standard set operations: union, intersection, both differences

    def union(self, other):
        """Return the union of two sets as a new set.

        (I.e. all elements that are in either set.)
        """
        self._binary_sanity_check(other)
        # Passing self (a BaseSet) lets _update() copy the dict directly.
        result = self.__class__(self)
        result._data.update(other._data)
        return result

    __or__ = union

    def intersection(self, other):
        """Return the intersection of two sets as a new set.

        (I.e. all elements that are in both sets.)
        """
        self._binary_sanity_check(other)
        # Loop over the smaller set, probing the larger one.
        if len(self) <= len(other):
            little, big = self, other
        else:
            little, big = other, self
        result = self.__class__([])
        data = result._data
        value = True
        bigdata = big._data
        for elt in little._data:
            if elt in bigdata:
                data[elt] = value
        return result

    __and__ = intersection

    def symmetric_difference(self, other):
        """Return the symmetric difference of two sets as a new set.

        (I.e. all elements that are in exactly one of the sets.)
        """
        self._binary_sanity_check(other)
        result = self.__class__([])
        data = result._data
        value = True
        selfdata = self._data
        otherdata = other._data
        for elt in selfdata:
            if elt not in otherdata:
                data[elt] = value
        for elt in otherdata:
            if elt not in selfdata:
                data[elt] = value
        return result

    __xor__ = symmetric_difference

    def difference(self, other):
        """Return the difference of two sets as a new Set.

        (I.e. all elements that are in this set and not in the other.)
        """
        self._binary_sanity_check(other)
        result = self.__class__([])
        data = result._data
        value = True
        otherdata = other._data
        for elt in self._data:
            if elt not in otherdata:
                data[elt] = value
        return result

    __sub__ = difference

    # Membership test

    def __contains__(self, element):
        """Report whether an element is a member of a set.

        (Called in response to the expression `element in self'.)
        """
        try:
            return element in self._data
        except TypeError:
            # Unhashable element: a mutable Set can still be looked up
            # through its temporarily-immutable wrapper.
            transform = getattr(element, "_as_temporarily_immutable", None)
            if transform is None:
                raise # re-raise the TypeError exception we caught
            return transform() in self._data

    # Subset and superset test

    def issubset(self, other):
        """Report whether another set contains this set."""
        self._binary_sanity_check(other)
        if len(self) > len(other):  # Fast check for obvious cases
            return False
        otherdata = other._data
        for elt in self._data:
            if elt not in otherdata:
                return False
        return True

    def issuperset(self, other):
        """Report whether this set contains another set."""
        self._binary_sanity_check(other)
        if len(self) < len(other):  # Fast check for obvious cases
            return False
        selfdata = self._data
        for elt in other._data:
            if elt not in selfdata:
                return False
        return True

    # Assorted helpers

    def _binary_sanity_check(self, other):
        # Check that the other argument to a binary operation is also
        # a set, raising a TypeError otherwise.
        if not isinstance(other, BaseSet):
            raise TypeError("Binary operation only permitted between sets")

    def _compute_hash(self):
        # Calculate hash code for a set by xor'ing the hash codes of
        # the elements.  This algorithm ensures that the hash code
        # does not depend on the order in which elements are added to
        # the set.  This is not called __hash__ because a BaseSet
        # should not be hashable; only an ImmutableSet is hashable.
        result = 0
        for elt in self:
            result ^= hash(elt)
        return result

    def _update(self, iterable):
        # The main loop for update() and the subclass __init__() methods.
        data = self._data

        # Elements of another set are already hashable dict keys, so
        # they can be copied in a single step.
        if isinstance(iterable, BaseSet):
            data.update(iterable._data)
            return

        value = True
        for element in iterable:
            # Guard only the store: a TypeError raised by the iterator
            # itself must propagate rather than be mistaken for an
            # unhashable element.
            try:
                data[element] = value
            except TypeError:
                transform = getattr(element, "_as_immutable", None)
                if transform is None:
                    raise # re-raise the TypeError exception we caught
                data[transform()] = value
295 data[transform()] = value
296
Guido van Rossumd6cf3af2002-08-19 16:19:15 +0000297
class ImmutableSet(BaseSet):
    """Hashable set class; it offers no methods that add or remove elements."""

    __slots__ = ['_hashcode']

    # BaseSet + hashing

    def __init__(self, iterable=None):
        """Build an immutable set, optionally seeded from an iterable.

        Passing no argument (or None) yields an empty set.  The hash
        code is not computed here; __hash__ computes and caches it.
        """
        self._data = {}
        self._hashcode = None
        if iterable is None:
            return
        self._update(iterable)

    def __hash__(self):
        # Compute the order-independent hash on first use, then reuse
        # the cached value on every later call.
        cached = self._hashcode
        if cached is None:
            cached = self._hashcode = self._compute_hash()
        return cached
316
317
class Set(BaseSet):
    """Mutable set class."""

    __slots__ = []

    # BaseSet + operations requiring mutability; no hashing

    def __init__(self, iterable=None):
        """Construct a set from an optional iterable."""
        self._data = {}
        if iterable is not None:
            self._update(iterable)

    def __hash__(self):
        """A Set cannot be hashed."""
        # We inherit object.__hash__, so we must deny this explicitly
        raise TypeError("Can't hash a Set, only an ImmutableSet.")

    # In-place union, intersection, differences.  Each returns self so
    # it can double as the augmented-assignment operator (|=, &=, ^=, -=).

    def union_update(self, other):
        """Update a set with the union of itself and another."""
        self._binary_sanity_check(other)
        self._data.update(other._data)
        return self

    __ior__ = union_update

    def intersection_update(self, other):
        """Update a set with the intersection of itself and another."""
        self._binary_sanity_check(other)
        data = self._data
        otherdata = other._data
        # Iterate over a snapshot of the keys: entries are deleted from
        # the dict inside the loop.
        for elt in list(data):
            if elt not in otherdata:
                del data[elt]
        return self

    __iand__ = intersection_update

    def symmetric_difference_update(self, other):
        """Update a set with the symmetric difference of itself and another."""
        self._binary_sanity_check(other)
        if other is self:
            # s ^ s is empty; iterating over self while mutating it
            # would raise "dictionary changed size during iteration".
            self._data.clear()
            return self
        data = self._data
        value = True
        for elt in other:
            if elt in data:
                del data[elt]
            else:
                data[elt] = value
        return self

    __ixor__ = symmetric_difference_update

    def difference_update(self, other):
        """Remove all elements of another set from this set."""
        self._binary_sanity_check(other)
        if other is self:
            # s - s is empty; iterating over self while deleting from it
            # would raise "dictionary changed size during iteration".
            self._data.clear()
            return self
        data = self._data
        for elt in other:
            if elt in data:
                del data[elt]
        return self

    __isub__ = difference_update

    # Python dict-like mass mutations: update, clear

    def update(self, iterable):
        """Add all values from an iterable (such as a list or file)."""
        self._update(iterable)

    def clear(self):
        """Remove all elements from this set."""
        self._data.clear()

    # Single-element mutations: add, remove, discard

    def add(self, element):
        """Add an element to a set.

        This has no effect if the element is already present.  A
        mutable Set element is stored as an ImmutableSet copy.
        """
        try:
            self._data[element] = True
        except TypeError:
            transform = getattr(element, "_as_immutable", None)
            if transform is None:
                raise # re-raise the TypeError exception we caught
            self._data[transform()] = True

    def remove(self, element):
        """Remove an element from a set; it must be a member.

        If the element is not a member, raise a KeyError.
        """
        try:
            del self._data[element]
        except TypeError:
            # Unhashable element: look a mutable Set up through its
            # temporarily-immutable wrapper.
            transform = getattr(element, "_as_temporarily_immutable", None)
            if transform is None:
                raise # re-raise the TypeError exception we caught
            del self._data[transform()]

    def discard(self, element):
        """Remove an element from a set if it is a member.

        If the element is not a member, do nothing.  Like remove(),
        this accepts a mutable Set and removes its immutable equivalent.
        """
        try:
            self.remove(element)
        except KeyError:
            pass

    def pop(self):
        """Remove and return a randomly-chosen set element."""
        return self._data.popitem()[0]

    def _as_immutable(self):
        # Return a copy of self as an immutable set
        return ImmutableSet(self)

    def _as_temporarily_immutable(self):
        # Return self wrapped in a temporarily immutable set
        return _TemporarilyImmutableSet(self)
440
441
442class _TemporarilyImmutableSet(object):
443 # Wrap a mutable set as if it was temporarily immutable.
444 # This only supplies hashing and equality comparisons.
445
446 _hashcode = None
447
448 def __init__(self, set):
449 self._set = set
450
451 def __hash__(self):
452 if self._hashcode is None:
453 self._hashcode = self._set._compute_hash()
454 return self._hashcode
455
456 def __eq__(self, other):
457 return self._set == other
458
459 def __ne__(self, other):
460 return self._set != other
461
462
# Rudimentary self-tests

def _test():
    """Run a few sanity checks on Set; raise AssertionError on failure."""

    # Empty set
    red = Set()
    assert repr(red) == "Set([])", "Empty set: %r" % red

    # Unit set
    green = Set((0,))
    assert repr(green) == "Set([0])", "Unit set: %r" % green

    # 3-element set
    blue = Set([0, 1, 2])
    assert blue._repr(True) == "Set([0, 1, 2])", "3-element set: %r" % blue

    # 2-element set with other values
    black = Set([0, 5])
    assert black._repr(True) == "Set([0, 5])", "2-element set: %r" % black

    # All elements from all sets
    white = Set([0, 1, 2, 5])
    assert white._repr(True) == "Set([0, 1, 2, 5])", "4-element set: %r" % white

    # Add element to empty set
    red.add(9)
    assert repr(red) == "Set([9])", "Add to empty set: %r" % red

    # Remove element from unit set
    red.remove(9)
    assert repr(red) == "Set([])", "Remove from unit set: %r" % red

    # Remove element from empty set: must raise (KeyError is a LookupError)
    try:
        red.remove(0)
    except LookupError:
        pass
    else:
        assert False, "Remove element from empty set: %r" % red

    # Length
    assert len(red) == 0, "Length of empty set"
    assert len(green) == 1, "Length of unit set"
    assert len(blue) == 3, "Length of 3-element set"

    # Compare
    assert green == Set([0]), "Equality failed"
    assert green != Set([1]), "Inequality failed"

    # Union
    assert blue | red == blue, "Union non-empty with empty"
    assert red | blue == blue, "Union empty with non-empty"
    assert green | blue == blue, "Union non-empty with non-empty"
    assert blue | black == white, "Enclosing union"

    # Intersection
    assert blue & red == red, "Intersect non-empty with empty"
    assert red & blue == red, "Intersect empty with non-empty"
    assert green & blue == green, "Intersect non-empty with non-empty"
    assert blue & black == green, "Enclosing intersection"

    # Symmetric difference
    assert red ^ green == green, "Empty symdiff non-empty"
    assert green ^ blue == Set([1, 2]), "Non-empty symdiff"
    assert white ^ white == red, "Self symdiff"

    # Difference
    assert red - green == red, "Empty - non-empty"
    assert blue - red == blue, "Non-empty - empty"
    assert white - black == Set([1, 2]), "Non-empty - non-empty"

    # In-place union
    orange = Set([])
    orange |= Set([1])
    assert orange == Set([1]), "In-place union"

    # In-place intersection
    orange = Set([1, 2])
    orange &= Set([2])
    assert orange == Set([2]), "In-place intersection"

    # In-place difference
    orange = Set([1, 2, 3])
    orange -= Set([2, 4])
    assert orange == Set([1, 3]), "In-place difference"

    # In-place symmetric difference
    orange = Set([1, 2, 3])
    orange ^= Set([3, 4])
    assert orange == Set([1, 2, 4]), "In-place symmetric difference"

    print("All tests passed")
554
555
if __name__ == '__main__':
    # Executed as a script rather than imported: run the self-tests.
    _test()