blob: 958e523b6c701dbc3c2a9889424175bc2127d412 [file] [log] [blame]
Tor Norbye3a2425a2013-11-04 10:16:08 -08001__all__ = ['Counter', 'deque', 'defaultdict', 'namedtuple', 'OrderedDict']
2# For bootstrapping reasons, the collection ABCs are defined in _abcoll.py.
3# They should however be considered an integral part of collections.py.
4from _abcoll import *
5import _abcoll
6__all__ += _abcoll.__all__
7
8from _collections import deque, defaultdict
9from operator import itemgetter as _itemgetter
10from keyword import iskeyword as _iskeyword
11import sys as _sys
12import heapq as _heapq
13from itertools import repeat as _repeat, chain as _chain, starmap as _starmap
14
15try:
16 from thread import get_ident as _get_ident
17except ImportError:
18 from dummy_thread import get_ident as _get_ident
19
20
21################################################################################
22### OrderedDict
23################################################################################
24
class OrderedDict(dict):
    'Dictionary subclass that iterates over its keys in insertion order'
    # Key -> value storage is delegated to the inherited dict, which supplies
    # __getitem__, __len__, __contains__, and get unchanged.
    # Every other method defined here is order-aware.
    #
    # Ordering is tracked by a circular doubly linked list that begins and
    # ends at a sentinel link (self.__root).  The sentinel is never deleted,
    # so an empty dictionary is simply a sentinel linked to itself.
    # Each link is a three-item list: [PREV, NEXT, KEY].
    # self.__map maps every key to its link so that a key can be unlinked
    # without walking the list.

    def __init__(self, *args, **kwds):
        '''Set up the ordering structures and load any initial items.

        Takes the same arguments as dict().  Keyword arguments are not
        recommended because their insertion order is arbitrary.  Raises
        TypeError when more than one positional argument is given.

        '''
        if len(args) > 1:
            raise TypeError('expected at most 1 arguments, got %d' % len(args))
        try:
            # Already present when __init__ is re-run on a live instance;
            # in that case the existing list and map are kept.
            self.__root
        except AttributeError:
            self.__root = sentinel = []
            sentinel[:] = [sentinel, sentinel, None]
            self.__map = {}
        self.__update(*args, **kwds)

    def __setitem__(self, key, value, PREV=0, NEXT=1, dict_setitem=dict.__setitem__):
        'od.__setitem__(i, y) <==> od[i]=y'
        # Only a key that is not yet present gets a new link, appended just
        # before the sentinel (i.e. at the end).  An existing key keeps its
        # position.  The inherited dict always receives the value.
        if key not in self:
            sentinel = self.__root
            tail = sentinel[PREV]
            link = [tail, sentinel, key]
            tail[NEXT] = link
            sentinel[PREV] = link
            self.__map[key] = link
        dict_setitem(self, key, value)

    def __delitem__(self, key, PREV=0, NEXT=1, dict_delitem=dict.__delitem__):
        'od.__delitem__(y) <==> del od[y]'
        # The dict deletion runs first so a missing key raises KeyError
        # before any link is touched.  The link is then found through
        # self.__map and spliced out by joining its two neighbours.
        dict_delitem(self, key)
        before, after, _ = self.__map.pop(key)
        before[NEXT] = after
        after[PREV] = before

    def __iter__(self):
        'od.__iter__() <==> iter(od)'
        # Follow the NEXT pointers from the sentinel until it comes round again.
        NEXT = 1
        KEY = 2
        sentinel = self.__root
        link = sentinel[NEXT]
        while link is not sentinel:
            yield link[KEY]
            link = link[NEXT]

    def __reversed__(self):
        'od.__reversed__() <==> reversed(od)'
        # Follow the PREV pointers from the sentinel until it comes round again.
        PREV = 0
        KEY = 2
        sentinel = self.__root
        link = sentinel[PREV]
        while link is not sentinel:
            yield link[KEY]
            link = link[PREV]

    def clear(self):
        'od.clear() -> None.  Remove all items from od.'
        # Empty every link first so the links stop referencing each other.
        for link in self.__map.itervalues():
            del link[:]
        sentinel = self.__root
        sentinel[:] = [sentinel, sentinel, None]
        self.__map.clear()
        dict.clear(self)

    # -- the methods below rely only on __iter__, __getitem__, etc., not on
    # -- the linked-list internals

    def keys(self):
        'od.keys() -> list of the keys in od, in order'
        return list(self)

    def values(self):
        'od.values() -> list of the values in od, in key order'
        return [self[k] for k in self]

    def items(self):
        'od.items() -> list of (key, value) pairs in od, in order'
        return [(k, self[k]) for k in self]

    def iterkeys(self):
        'od.iterkeys() -> an iterator over the keys in od'
        return iter(self)

    def itervalues(self):
        'od.itervalues() -> an iterator over the values in od'
        for key in self:
            yield self[key]

    def iteritems(self):
        'od.iteritems() -> an iterator over the (key, value) pairs in od'
        for key in self:
            yield (key, self[key])

    update = MutableMapping.update

    # Private alias used by __init__, so a subclass can override update()
    # without changing how construction loads the initial items.
    __update = update

    # Unique sentinel that lets pop() tell "no default given" from default=None.
    __marker = object()

    def pop(self, key, default=__marker):
        '''od.pop(k[,d]) -> v, remove key k and return its value.  When k is
        absent, return d if it was supplied; otherwise raise KeyError.

        '''
        if key not in self:
            if default is self.__marker:
                raise KeyError(key)
            return default
        value = self[key]
        del self[key]
        return value

    def setdefault(self, key, default=None):
        'od.setdefault(k[,d]) -> od.get(k,d), also set od[k]=d if k not in od'
        if key not in self:
            self[key] = default
            return default
        return self[key]

    def popitem(self, last=True):
        '''od.popitem() -> (k, v), remove and return one (key, value) pair.
        The newest pair is taken when last is true (LIFO), the oldest when
        it is false (FIFO).  Raises KeyError if od is empty.

        '''
        if not self:
            raise KeyError('dictionary is empty')
        if last:
            ordering = reversed(self)
        else:
            ordering = iter(self)
        key = next(ordering)
        return key, self.pop(key)

    def __repr__(self, _repr_running={}):
        'od.__repr__() <==> repr(od)'
        # _repr_running is deliberately a shared mutable default: it records
        # the (object id, thread id) pairs whose repr is in progress so that
        # a dictionary containing itself prints '...' instead of recursing.
        call_key = id(self), _get_ident()
        if call_key in _repr_running:
            return '...'
        _repr_running[call_key] = 1
        try:
            cls_name = self.__class__.__name__
            if self:
                return '%s(%r)' % (cls_name, self.items())
            return '%s()' % (cls_name,)
        finally:
            del _repr_running[call_key]

    def __reduce__(self):
        'Return state information for pickling'
        pairs = [[key, self[key]] for key in self]
        # Keep only the instance attributes that a brand-new OrderedDict
        # would not have (this drops the private root and map).
        extra_state = vars(self).copy()
        for attr in vars(OrderedDict()):
            extra_state.pop(attr, None)
        if not extra_state:
            return self.__class__, (pairs,)
        return (self.__class__, (pairs,), extra_state)

    def copy(self):
        'od.copy() -> a shallow copy of od'
        return self.__class__(self)

    @classmethod
    def fromkeys(cls, iterable, value=None):
        '''OD.fromkeys(S[, v]) -> new ordered dictionary whose keys come
        from S, each mapped to v (None when v is not given).

        '''
        new_od = cls()
        for key in iterable:
            new_od[key] = value
        return new_od

    def __eq__(self, other):
        '''od.__eq__(y) <==> od==y.  Two ordered dictionaries are equal only
        if their items match in order; against any other object the plain
        dict comparison is used, which ignores order.

        '''
        if not isinstance(other, OrderedDict):
            return dict.__eq__(self, other)
        return len(self) == len(other) and self.items() == other.items()

    def __ne__(self, other):
        'od.__ne__(y) <==> od!=y'
        return not (self == other)

    # -- python 3.x style dictionary views --

    def viewkeys(self):
        "od.viewkeys() -> a set-like object providing a view on od's keys"
        return KeysView(self)

    def viewvalues(self):
        "od.viewvalues() -> an object providing a view on od's values"
        return ValuesView(self)

    def viewitems(self):
        "od.viewitems() -> a set-like object providing a view on od's items"
        return ItemsView(self)
231
232
233################################################################################
234### namedtuple
235################################################################################
236
def namedtuple(typename, field_names, verbose=False, rename=False):
    """Returns a new subclass of tuple with named fields.

    >>> Point = namedtuple('Point', 'x y')
    >>> Point.__doc__ # docstring for the new class
    'Point(x, y)'
    >>> p = Point(11, y=22) # instantiate with positional args or keywords
    >>> p[0] + p[1] # indexable like a plain tuple
    33
    >>> x, y = p # unpack like a regular tuple
    >>> x, y
    (11, 22)
    >>> p.x + p.y # fields also accessible by name
    33
    >>> d = p._asdict() # convert to a dictionary
    >>> d['x']
    11
    >>> Point(**d) # convert from a dictionary
    Point(x=11, y=22)
    >>> p._replace(x=100) # _replace() is like str.replace() but targets named fields
    Point(x=100, y=22)

    typename is the name of the generated class.  field_names is either a
    sequence of names or a single string of names separated by whitespace
    and/or commas.  If verbose is true, the generated class source is
    printed before it is executed.  If rename is true, every invalid or
    duplicate field name is replaced by '_<index>' instead of being rejected.

    Raises ValueError when the type name or a field name is empty, contains
    characters other than alphanumerics and underscores, is a keyword, or
    starts with a digit; and when a field name is duplicated or (without
    rename) starts with an underscore.

    """

    # Parse and validate the field names.  Validation serves two purposes:
    # generating informative error messages and preventing template injection
    # attacks (the names are interpolated into source code that is exec'd).
    if isinstance(field_names, basestring):
        # names separated by whitespace and/or commas
        field_names = field_names.replace(',', ' ').split()
    field_names = tuple(map(str, field_names))
    if rename:
        names = list(field_names)
        seen = set()
        for i, name in enumerate(names):
            if (not all(c.isalnum() or c == '_' for c in name) or _iskeyword(name)
                    or not name or name[0].isdigit() or name.startswith('_')
                    or name in seen):
                names[i] = '_%d' % i
            seen.add(name)
        field_names = tuple(names)
    for name in (typename,) + field_names:
        if not all(c.isalnum() or c == '_' for c in name):
            raise ValueError('Type names and field names can only contain alphanumeric characters and underscores: %r' % name)
        if _iskeyword(name):
            raise ValueError('Type names and field names cannot be a keyword: %r' % name)
        if not name:
            # An empty name passes the two checks above; without this test
            # the name[0] lookup below would fail with an IndexError.
            raise ValueError('Type names and field names cannot be empty: %r' % name)
        if name[0].isdigit():
            raise ValueError('Type names and field names cannot start with a number: %r' % name)
    seen_names = set()
    for name in field_names:
        if name.startswith('_') and not rename:
            raise ValueError('Field names cannot start with an underscore: %r' % name)
        if name in seen_names:
            raise ValueError('Encountered duplicate field name: %r' % name)
        seen_names.add(name)

    # Create and fill-in the class template
    numfields = len(field_names)
    argtxt = repr(field_names).replace("'", "")[1:-1] # tuple repr without parens or quotes
    reprtxt = ', '.join('%s=%%r' % name for name in field_names)
    substitutions = dict(typename=typename, field_names=field_names,
                         numfields=numfields, argtxt=argtxt, reprtxt=reprtxt)
    template = '''class %(typename)s(tuple):
        '%(typename)s(%(argtxt)s)' \n
        __slots__ = () \n
        _fields = %(field_names)r \n
        def __new__(_cls, %(argtxt)s):
            'Create new instance of %(typename)s(%(argtxt)s)'
            return _tuple.__new__(_cls, (%(argtxt)s)) \n
        @classmethod
        def _make(cls, iterable, new=tuple.__new__, len=len):
            'Make a new %(typename)s object from a sequence or iterable'
            result = new(cls, iterable)
            if len(result) != %(numfields)d:
                raise TypeError('Expected %(numfields)d arguments, got %%d' %% len(result))
            return result \n
        def __repr__(self):
            'Return a nicely formatted representation string'
            return '%(typename)s(%(reprtxt)s)' %% self \n
        def _asdict(self):
            'Return a new OrderedDict which maps field names to their values'
            return OrderedDict(zip(self._fields, self)) \n
        __dict__ = property(_asdict) \n
        def _replace(_self, **kwds):
            'Return a new %(typename)s object replacing specified fields with new values'
            result = _self._make(map(kwds.pop, %(field_names)r, _self))
            if kwds:
                raise ValueError('Got unexpected field names: %%r' %% kwds.keys())
            return result \n
        def __getnewargs__(self):
            'Return self as a plain tuple. Used by copy and pickle.'
            return tuple(self) \n\n''' % substitutions
    # One read-only property per field, appended inside the class body.
    for i, name in enumerate(field_names):
        template += "        %s = _property(_itemgetter(%d), doc='Alias for field number %d')\n" % (name, i, i)
    if verbose:
        print(template)

    # Execute the template string in a temporary namespace and
    # support tracing utilities by setting a value for frame.f_globals['__name__']
    namespace = dict(_itemgetter=_itemgetter, __name__='namedtuple_%s' % typename,
                     OrderedDict=OrderedDict, _property=property, _tuple=tuple)
    try:
        # Tuple form of exec; equivalent to "exec template in namespace".
        exec(template, namespace)
    except SyntaxError as e:
        # Re-raise with the generated source attached to ease debugging.
        raise SyntaxError(e.message + ':\n' + template)
    result = namespace[typename]

    # For pickling to work, the __module__ variable needs to be set to the frame
    # where the named tuple is created.  Bypass this step in environments where
    # sys._getframe is not defined (Jython for example) or sys._getframe is not
    # defined for arguments greater than 0 (IronPython).
    try:
        result.__module__ = _sys._getframe(1).f_globals.get('__name__', '__main__')
    except (AttributeError, ValueError):
        pass

    return result
350
351
352########################################################################
353### Counter
354########################################################################
355
class Counter(dict):
    '''Dict subclass that tallies hashable items; also known as a bag or
    multiset.  Each distinct element is a dictionary key and its tally is
    the corresponding dictionary value.

    >>> c = Counter('abcdeabcdabcaba') # count elements from a string

    >>> c.most_common(3) # three most common elements
    [('a', 5), ('b', 4), ('c', 3)]
    >>> sorted(c) # list all unique elements
    ['a', 'b', 'c', 'd', 'e']
    >>> ''.join(sorted(c.elements())) # list elements with repetitions
    'aaaaabbbbcccdde'
    >>> sum(c.values()) # total of all counts
    15

    >>> c['a'] # count of letter 'a'
    5
    >>> for elem in 'shazam': # update counts from an iterable
    ... c[elem] += 1 # by adding 1 to each element's count
    >>> c['a'] # now there are seven 'a'
    7
    >>> del c['b'] # remove all 'b'
    >>> c['b'] # now there are zero 'b'
    0

    >>> d = Counter('simsalabim') # make another counter
    >>> c.update(d) # add in the second counter
    >>> c['a'] # now there are nine 'a'
    9

    >>> c.clear() # empty the counter
    >>> c
    Counter()

    Note: A count that is set to zero or reduced to zero stays in the
    counter until the entry is deleted or the counter is cleared:

    >>> c = Counter('aaabbc')
    >>> c['b'] -= 2 # reduce the count of 'b' by two
    >>> c.most_common() # 'b' is still in, but its count is zero
    [('a', 3), ('c', 1), ('b', 0)]

    '''
    # References:
    # http://en.wikipedia.org/wiki/Multiset
    # http://www.gnu.org/software/smalltalk/manual-base/html_node/Bag.html
    # http://www.demo2s.com/Tutorial/Cpp/0380__set-multiset/Catalog0380__set-multiset.htm
    # http://code.activestate.com/recipes/259174/
    # Knuth, TAOCP Vol. II section 4.6.3

    def __init__(self, iterable=None, **kwds):
        '''Build a Counter, optionally seeded with counts.  The seed may be
        an iterable of elements to tally, a mapping of elements to counts,
        and/or keyword arguments naming elements and their counts.

        >>> c = Counter() # a new, empty counter
        >>> c = Counter('gallahad') # a new counter from an iterable
        >>> c = Counter({'a': 4, 'b': 2}) # a new counter from a mapping
        >>> c = Counter(a=4, b=2) # a new counter from keyword args

        '''
        super(Counter, self).__init__()
        self.update(iterable, **kwds)

    def __missing__(self, key):
        'Report a count of zero for any element that has not been stored.'
        # Makes self[missing_item] return 0 rather than raise KeyError.
        return 0

    def most_common(self, n=None):
        '''Return a list of (element, count) pairs ordered from the highest
        count to the lowest.  Only the first n pairs are returned unless n
        is None, in which case every pair is listed.

        >>> Counter('abcdeabcdabcaba').most_common(3)
        [('a', 5), ('b', 4), ('c', 3)]

        '''
        # Emulate Bag.sortedByCount from Smalltalk
        by_count = _itemgetter(1)
        if n is not None:
            return _heapq.nlargest(n, self.iteritems(), key=by_count)
        return sorted(self.iteritems(), key=by_count, reverse=True)

    def elements(self):
        '''Return an iterator that yields each element as many times as its
        count.

        >>> c = Counter('ABCABC')
        >>> sorted(c.elements())
        ['A', 'A', 'B', 'B', 'C', 'C']

        # Knuth's example for prime factors of 1836: 2**2 * 3**3 * 17**1
        >>> prime_factors = Counter({2: 2, 3: 3, 17: 1})
        >>> product = 1
        >>> for factor in prime_factors.elements(): # loop over factors
        ... product *= factor # and multiply them
        >>> product
        1836

        Note, an element whose count is zero or negative is skipped by
        elements().

        '''
        # Emulate Bag.do from Smalltalk and Multiset.begin from C++.
        # Each (elem, count) pair becomes repeat(elem, count); the runs are
        # then chained into a single stream.
        runs = _starmap(_repeat, self.iteritems())
        return _chain.from_iterable(runs)

    # Override dict methods where necessary

    @classmethod
    def fromkeys(cls, iterable, v=None):
        # Deliberately unsupported: a single value v shared by every key
        # (e.g. v=1) could never express a count greater than one.
        raise NotImplementedError(
            'Counter.fromkeys() is undefined. Use Counter(iterable) instead.')

    def update(self, iterable=None, **kwds):
        '''Like dict.update() except that counts are added to the existing
        ones instead of replacing them.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.update('witch') # add elements from another iterable
        >>> d = Counter('watch')
        >>> c.update(d) # add elements from another counter
        >>> c['h'] # four 'h' in which, witch, and watch
        4

        '''
        # Replacing counts, as dict.update() would, mixes untouched old
        # counts with new ones and gives a mishmash with no clear meaning in
        # most counting contexts, so counts are simply added.  Both the
        # inputs and the outputs may contain zero and negative counts.

        if iterable is not None:
            if not isinstance(iterable, Mapping):
                # Plain iterable: each occurrence is worth one.
                tally = self.get
                for elem in iterable:
                    self[elem] = tally(elem, 0) + 1
            elif self:
                tally = self.get
                for elem, count in iterable.iteritems():
                    self[elem] = tally(elem, 0) + count
            else:
                # fast path when counter is empty: nothing to add to yet
                super(Counter, self).update(iterable)
        if kwds:
            self.update(kwds)

    def subtract(self, iterable=None, **kwds):
        '''Like dict.update() except that counts are subtracted from the
        existing ones instead of replacing them.  Counts may drop below
        zero; both inputs and outputs may hold zero and negative counts.

        Source can be an iterable, a dictionary, or another Counter instance.

        >>> c = Counter('which')
        >>> c.subtract('witch') # subtract elements from another iterable
        >>> c.subtract(Counter('watch')) # subtract elements from another counter
        >>> c['h'] # 2 in which, minus 1 in witch, minus 1 in watch
        0
        >>> c['w'] # 1 in which, minus 1 in witch, minus 1 in watch
        -1

        '''
        if iterable is not None:
            tally = self.get
            if not isinstance(iterable, Mapping):
                for elem in iterable:
                    self[elem] = tally(elem, 0) - 1
            else:
                for elem, count in iterable.items():
                    self[elem] = tally(elem, 0) - count
        if kwds:
            self.subtract(kwds)

    def copy(self):
        'Return a shallow copy.'
        return self.__class__(self)

    def __reduce__(self):
        # Pickle as "rebuild the class from a plain dict of the counts".
        return self.__class__, (dict(self),)

    def __delitem__(self, elem):
        'Like dict.__delitem__() but silently ignores an element that is not stored.'
        if elem in self:
            super(Counter, self).__delitem__(elem)

    def __repr__(self):
        cls_name = self.__class__.__name__
        if not self:
            return '%s()' % cls_name
        # Entries are shown from the most common to the least.
        body = ', '.join('%r: %r' % pair for pair in self.most_common())
        return '%s({%s})' % (cls_name, body)

    # Multiset-style mathematical operations discussed in:
    # Knuth TAOCP Volume II section 4.6.3 exercise 19
    # and at http://en.wikipedia.org/wiki/Multiset
    #
    # Outputs guaranteed to only include positive counts.
    #
    # To strip negative and zero counts, add-in an empty counter:
    # c += Counter()

    def __add__(self, other):
        '''Add counts from two counters.

        >>> Counter('abbb') + Counter('bcc')
        Counter({'b': 4, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        combined = Counter()
        for elem, count in self.items():
            total = count + other[elem]
            if total > 0:
                combined[elem] = total
        # Elements that appear only in other.
        for elem, count in other.items():
            if elem not in self and count > 0:
                combined[elem] = count
        return combined

    def __sub__(self, other):
        ''' Subtract count, but keep only results with positive counts.

        >>> Counter('abbbc') - Counter('bccd')
        Counter({'b': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        difference = Counter()
        for elem, count in self.items():
            remaining = count - other[elem]
            if remaining > 0:
                difference[elem] = remaining
        # Subtracting a negative count held only by other gives a positive result.
        for elem, count in other.items():
            if elem not in self and count < 0:
                difference[elem] = 0 - count
        return difference

    def __or__(self, other):
        '''Union is the maximum of value in either of the input counters.

        >>> Counter('abbb') | Counter('bcc')
        Counter({'b': 3, 'c': 2, 'a': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        union = Counter()
        for elem, count in self.items():
            other_count = other[elem]
            if count < other_count:
                larger = other_count
            else:
                larger = count
            if larger > 0:
                union[elem] = larger
        # Elements that appear only in other.
        for elem, count in other.items():
            if elem not in self and count > 0:
                union[elem] = count
        return union

    def __and__(self, other):
        ''' Intersection is the minimum of corresponding counts.

        >>> Counter('abbb') & Counter('bcc')
        Counter({'b': 1})

        '''
        if not isinstance(other, Counter):
            return NotImplemented
        intersection = Counter()
        for elem, count in self.items():
            other_count = other[elem]
            if count < other_count:
                smaller = count
            else:
                smaller = other_count
            if smaller > 0:
                intersection[elem] = smaller
        return intersection
633
634
if __name__ == '__main__':
    # Self-test / demo, run only when the module is executed as a script.

    # Check that a namedtuple instance survives a pickle round trip.
    # (verbose=True also prints the generated class source.)
    from cPickle import loads, dumps
    Point = namedtuple('Point', 'x, y', True)
    p = Point(x=10, y=20)
    assert p == loads(dumps(p))

    # Show that a subclass can add a computed property and override methods.
    class Point(namedtuple('Point', 'x y')):
        __slots__ = ()

        @property
        def hypot(self):
            return (self.x ** 2 + self.y ** 2) ** 0.5

        def __str__(self):
            return 'Point: x=%6.3f y=%6.3f hypot=%6.3f' % (self.x, self.y, self.hypot)

    for p in (Point(3, 4), Point(14, 5/7.)):
        print(p)

    class Point(namedtuple('Point', 'x y')):
        'Point class with optimized _make() and _replace() without error-checking'
        __slots__ = ()
        _make = classmethod(tuple.__new__)

        def _replace(self, _map=map, **kwds):
            return self._make(_map(kwds.get, ('x', 'y'), self))

    print(Point(11, 22)._replace(x=100))

    # A new type can be built by extending another type's field tuple.
    Point3D = namedtuple('Point3D', Point._fields + ('z',))
    print(Point3D.__doc__)

    # Run the doctests embedded in this module and report the totals.
    import doctest
    TestResults = namedtuple('TestResults', 'failed attempted')
    print(TestResults(*doctest.testmod()))