blob: 19f88c7f0210b96e5a3c7fd26e94d2167cdc597d [file] [log] [blame]
Thomas Wouters73e5a5b2006-06-08 15:35:45 +00001"""functools.py - Tools for working with functions and callable objects
Thomas Wouters4d70c3d2006-06-08 14:42:34 +00002"""
3# Python module wrapper for _functools C module
4# to allow utilities written in Python to be added
5# to the functools module.
Łukasz Langa6f692512013-06-05 12:20:24 +02006# Written by Nick Coghlan <ncoghlan at gmail.com>,
7# Raymond Hettinger <python at rcn.com>,
8# and Łukasz Langa <lukasz at langa.pl>.
9# Copyright (C) 2006-2013 Python Software Foundation.
Thomas Wouters73e5a5b2006-06-08 15:35:45 +000010# See C source code for _functools credits/copyright
Thomas Wouters4d70c3d2006-06-08 14:42:34 +000011
Georg Brandl2e7346a2010-07-31 18:09:23 +000012__all__ = ['update_wrapper', 'wraps', 'WRAPPER_ASSIGNMENTS', 'WRAPPER_UPDATES',
Łukasz Langa6f692512013-06-05 12:20:24 +020013 'total_ordering', 'cmp_to_key', 'lru_cache', 'reduce', 'partial',
14 'singledispatch']
Georg Brandl2e7346a2010-07-31 18:09:23 +000015
Antoine Pitroub5b37142012-11-13 21:35:40 +010016try:
17 from _functools import reduce
Brett Cannoncd171c82013-07-04 17:43:24 -040018except ImportError:
Antoine Pitroub5b37142012-11-13 21:35:40 +010019 pass
Łukasz Langa6f692512013-06-05 12:20:24 +020020from abc import get_cache_token
Raymond Hettingerec0e9102012-03-16 01:16:31 -070021from collections import namedtuple
Łukasz Langa6f692512013-06-05 12:20:24 +020022from types import MappingProxyType
23from weakref import WeakKeyDictionary
try:
    from _thread import RLock
except ImportError:
    # Only reached when the interpreter has no _thread module.  Catching
    # ImportError (instead of a bare except) keeps KeyboardInterrupt,
    # SystemExit and unexpected errors from being silently converted into
    # a lock that does nothing, and matches the other optional imports here.
    class RLock:
        'Dummy reentrant lock for builds without threads'
        def __enter__(self): pass
        def __exit__(self, exctype, excinst, exctb): pass
Thomas Wouters4d70c3d2006-06-08 14:42:34 +000031
Raymond Hettingerbc8e81d2012-03-17 00:24:09 -070032
33################################################################################
34### update_wrapper() and wraps() decorator
35################################################################################
36
Thomas Wouters73e5a5b2006-06-08 15:35:45 +000037# update_wrapper() and wraps() are tools to help write
38# wrapper functions that can handle naive introspection
39
# Attributes copied straight from the wrapped function onto the wrapper.
WRAPPER_ASSIGNMENTS = ('__module__', '__name__', '__qualname__', '__doc__',
                       '__annotations__')
# Attributes of the wrapper that are dict-updated from the wrapped function.
WRAPPER_UPDATES = ('__dict__',)
def update_wrapper(wrapper,
                   wrapped,
                   assigned = WRAPPER_ASSIGNMENTS,
                   updated = WRAPPER_UPDATES):
    """Make *wrapper* look like *wrapped* and return *wrapper*.

    Each attribute named in *assigned* (default
    functools.WRAPPER_ASSIGNMENTS) that *wrapped* has is copied onto
    *wrapper*; names missing on *wrapped* are skipped.  For each name in
    *updated* (default functools.WRAPPER_UPDATES), the wrapper's attribute
    is updated in place from the wrapped function's attribute of the same
    name (an empty dict if *wrapped* lacks it).  Finally
    wrapper.__wrapped__ is set to *wrapped*.  Returning *wrapper* allows
    use as a decorator through partial().
    """
    absent = object()  # sentinel meaning "wrapped has no such attribute"
    for name in assigned:
        value = getattr(wrapped, name, absent)
        if value is not absent:
            setattr(wrapper, name, value)
    for name in updated:
        target = getattr(wrapper, name)
        target.update(getattr(wrapped, name, {}))
    # Issue #17482: __wrapped__ is assigned only after __dict__ has been
    # merged, so a stale __wrapped__ in wrapped.__dict__ cannot overwrite it.
    wrapper.__wrapped__ = wrapped
    return wrapper
72
def wraps(wrapped,
          assigned = WRAPPER_ASSIGNMENTS,
          updated = WRAPPER_UPDATES):
    """Decorator factory that applies update_wrapper() to the decorated function.

    The returned decorator calls update_wrapper() with the decorated
    function as *wrapper* and this call's arguments as *wrapped*,
    *assigned* and *updated*; the defaults are those of update_wrapper().
    It is simply partial(update_wrapper, ...) with those keywords bound.
    """
    bound = dict(wrapped=wrapped, assigned=assigned, updated=updated)
    return partial(update_wrapper, **bound)
Raymond Hettingerc50846a2010-04-05 18:56:31 +000086
Raymond Hettingerbc8e81d2012-03-17 00:24:09 -070087
88################################################################################
89### total_ordering class decorator
90################################################################################
91
def total_ordering(cls):
    """Class decorator that fills in missing ordering methods.

    *cls* must itself define at least one of __lt__, __le__, __gt__ or
    __ge__ (a method inherited unchanged from object does not count);
    otherwise ValueError is raised.  The root comparison is chosen in the
    preference order __lt__, __le__, __gt__, __ge__.  Each of the other
    three that *cls* does not define is derived from the root and __eq__.

    The derived methods call the root method directly (for example
    ``self.__lt__(other)``) rather than through the operator.  Whenever the
    root returns NotImplemented, they return NotImplemented too.  Python can
    then try the reflected operation or raise TypeError, instead of the
    derived method mistaking NotImplemented for a truthy comparison result.
    """
    def derive(root_name, combine):
        # Build a comparison method from the *root_name* method.  *combine*
        # receives (self, other, root_result) for a genuine root result.
        def method(self, other):
            op_result = getattr(self, root_name)(other)
            if op_result is NotImplemented:
                return NotImplemented
            return combine(self, other, op_result)
        return method

    # For each root: (missing method, formula in terms of root result r).
    convert = {
        '__lt__': [('__gt__', lambda self, other, r: not (r or self == other)),
                   ('__le__', lambda self, other, r: r or self == other),
                   ('__ge__', lambda self, other, r: not r)],
        '__le__': [('__ge__', lambda self, other, r: not r or self == other),
                   ('__lt__', lambda self, other, r: r and not self == other),
                   ('__gt__', lambda self, other, r: not r)],
        '__gt__': [('__lt__', lambda self, other, r: not (r or self == other)),
                   ('__ge__', lambda self, other, r: r or self == other),
                   ('__le__', lambda self, other, r: not r)],
        '__ge__': [('__le__', lambda self, other, r: (not r) or self == other),
                   ('__gt__', lambda self, other, r: r and not self == other),
                   ('__lt__', lambda self, other, r: not r)],
    }
    # Find user-defined comparisons (not those inherited from object).
    roots = [op for op in convert
             if getattr(cls, op, None) is not getattr(object, op, None)]
    if not roots:
        raise ValueError('must define at least one ordering operation: < > <= >=')
    root = max(roots)       # prefer __lt__ to __le__ to __gt__ to __ge__
    for opname, combine in convert[root]:
        if opname not in roots:
            opfunc = derive(root, combine)
            opfunc.__name__ = opname
            opfunc.__doc__ = getattr(int, opname).__doc__
            setattr(cls, opname, opfunc)
    return cls
119
Raymond Hettingerbc8e81d2012-03-17 00:24:09 -0700120
121################################################################################
122### cmp_to_key() function converter
123################################################################################
124
def cmp_to_key(mycmp):
    """Turn an old-style cmp function into a key function (e.g. for sorted()).

    *mycmp* takes two values and returns a number that is negative, zero
    or positive for less-than, equal and greater-than respectively.  The
    returned class wraps one value per instance, and its rich comparisons
    compare mycmp's result on the wrapped values against zero.  Instances
    are unhashable.
    """
    def compare(left, right):
        # Shared helper: run the user's cmp on the two wrapped values.
        return mycmp(left.obj, right.obj)

    class K(object):
        __slots__ = ['obj']

        def __init__(self, obj):
            self.obj = obj

        def __eq__(self, other):
            return compare(self, other) == 0

        def __ne__(self, other):
            return compare(self, other) != 0

        def __lt__(self, other):
            return compare(self, other) < 0

        def __le__(self, other):
            return compare(self, other) <= 0

        def __gt__(self, other):
            return compare(self, other) > 0

        def __ge__(self, other):
            return compare(self, other) >= 0

        __hash__ = None

    return K

# Use the C implementation when the _functools accelerator is available;
# the pure-Python definition above remains the fallback.
try:
    from _functools import cmp_to_key
except ImportError:
    pass
150
Raymond Hettingerbc8e81d2012-03-17 00:24:09 -0700151
152################################################################################
Antoine Pitroub5b37142012-11-13 21:35:40 +0100153### partial() argument application
154################################################################################
155
def partial(func, *args, **keywords):
    """Return a callable that calls *func* with some arguments pre-supplied.

    Positional arguments given at call time are appended after *args*.
    Keyword arguments given at call time override those in *keywords*.
    The result exposes the frozen pieces as .func, .args and .keywords.
    """
    def newfunc(*fargs, **fkeywords):
        call_keywords = dict(keywords)
        call_keywords.update(fkeywords)   # call-time keywords take priority
        return func(*(args + fargs), **call_keywords)
    for attr_name, attr_value in (('func', func),
                                  ('args', args),
                                  ('keywords', keywords)):
        setattr(newfunc, attr_name, attr_value)
    return newfunc

# Use the C implementation when the _functools accelerator is available;
# the pure-Python definition above remains the fallback.
try:
    from _functools import partial
except ImportError:
    pass
173
174
175################################################################################
Raymond Hettingerbc8e81d2012-03-17 00:24:09 -0700176### LRU Cache function decorator
177################################################################################
178
Raymond Hettingerdce583e2012-03-16 22:12:20 -0700179_CacheInfo = namedtuple("CacheInfo", ["hits", "misses", "maxsize", "currsize"])
Nick Coghlan234515a2010-11-30 06:19:46 +0000180
Raymond Hettinger0c9050c2012-06-04 00:21:14 -0700181class _HashedSeq(list):
Raymond Hettingerf96b2b02013-03-08 21:11:55 -0700182 """ This class guarantees that hash() will be called no more than once
183 per element. This is important because the lru_cache() will hash
184 the key multiple times on a cache miss.
185
186 """
187
Raymond Hettinger9acbb602012-04-30 22:32:16 -0700188 __slots__ = 'hashvalue'
189
Raymond Hettinger0c9050c2012-06-04 00:21:14 -0700190 def __init__(self, tup, hash=hash):
191 self[:] = tup
192 self.hashvalue = hash(tup)
Raymond Hettinger9acbb602012-04-30 22:32:16 -0700193
194 def __hash__(self):
195 return self.hashvalue
196
Raymond Hettinger0c9050c2012-06-04 00:21:14 -0700197def _make_key(args, kwds, typed,
198 kwd_mark = (object(),),
199 fasttypes = {int, str, frozenset, type(None)},
200 sorted=sorted, tuple=tuple, type=type, len=len):
Raymond Hettingerf96b2b02013-03-08 21:11:55 -0700201 """Make a cache key from optionally typed positional and keyword arguments
202
203 The key is constructed in a way that is flat as possible rather than
204 as a nested structure that would take more memory.
205
206 If there is only a single argument and its data type is known to cache
207 its hash value, then that argument is returned without a wrapper. This
208 saves space and improves lookup speed.
209
210 """
Raymond Hettinger0c9050c2012-06-04 00:21:14 -0700211 key = args
212 if kwds:
213 sorted_items = sorted(kwds.items())
214 key += kwd_mark
215 for item in sorted_items:
216 key += item
217 if typed:
218 key += tuple(type(v) for v in args)
219 if kwds:
220 key += tuple(type(v) for k, v in sorted_items)
221 elif len(key) == 1 and type(key[0]) in fasttypes:
222 return key[0]
223 return _HashedSeq(key)
224
def lru_cache(maxsize=128, typed=False):
    """Least-recently-used cache decorator.

    If *maxsize* is set to None, the LRU features are disabled and the cache
    can grow without bound.  If *maxsize* is 0, nothing is cached and every
    call is counted as a miss.

    If *typed* is True, arguments of different types will be cached separately.
    For example, f(3.0) and f(3) will be treated as distinct calls with
    distinct results.

    Arguments to the cached function must be hashable.

    View the cache statistics named tuple (hits, misses, maxsize, currsize)
    with f.cache_info().  Clear the cache and statistics with f.cache_clear().
    Access the underlying function with f.__wrapped__.

    See:  http://en.wikipedia.org/wiki/Cache_algorithms#Least_Recently_Used

    """

    # Users should only access the lru_cache through its public API:
    #       cache_info, cache_clear, and f.__wrapped__
    # The internals of the lru_cache are encapsulated for thread safety and
    # to allow the implementation to change (including a possible C version).

    # Constants shared by all lru cache instances:
    sentinel = object()          # unique object used to signal cache misses
    make_key = _make_key         # build a key from the function arguments
    PREV, NEXT, KEY, RESULT = 0, 1, 2, 3   # names for the link fields

    def decorating_function(user_function):
        # Per-decorated-function state, shared by wrapper, cache_info and
        # cache_clear through closures.
        cache = {}               # key -> result, or key -> link when bounded
        hits = misses = 0
        full = False             # bounded mode: True once len(cache) >= maxsize
        cache_get = cache.get    # bound method to lookup a key or return None
        lock = RLock()           # because linkedlist updates aren't threadsafe
        # Each link is a 4-item list [PREV, NEXT, KEY, RESULT].  root is a
        # sentinel link whose KEY/RESULT are None; new links are inserted
        # just before it (root[PREV]), so root[NEXT] is the oldest entry.
        root = []                # root of the circular doubly linked list
        root[:] = [root, root, None, None]     # initialize by pointing to self

        if maxsize == 0:

            def wrapper(*args, **kwds):
                # No caching -- just a statistics update after a successful call
                # NOTE(review): the counter is updated without taking `lock`.
                nonlocal misses
                result = user_function(*args, **kwds)
                misses += 1
                return result

        elif maxsize is None:

            def wrapper(*args, **kwds):
                # Simple caching without ordering or size limit
                # NOTE(review): neither the dict writes nor the counters take
                # `lock` here; concurrent callers may presumably compute the
                # same key twice or lose counter increments.
                nonlocal hits, misses
                key = make_key(args, kwds, typed)
                result = cache_get(key, sentinel)
                if result is not sentinel:
                    hits += 1
                    return result
                result = user_function(*args, **kwds)
                cache[key] = result
                misses += 1
                return result

        else:

            def wrapper(*args, **kwds):
                # Size limited caching that tracks accesses by recency
                # NOTE(review): a negative maxsize also lands here; `full`
                # then becomes True after the first insertion, so the cache
                # keeps only a single entry.
                nonlocal root, hits, misses, full
                key = make_key(args, kwds, typed)
                with lock:
                    link = cache_get(key)
                    if link is not None:
                        # Move the link to the front of the circular queue
                        link_prev, link_next, _key, result = link
                        link_prev[NEXT] = link_next
                        link_next[PREV] = link_prev
                        last = root[PREV]
                        last[NEXT] = root[PREV] = link
                        link[PREV] = last
                        link[NEXT] = root
                        hits += 1
                        return result
                # The user function runs outside the lock, so another call
                # may insert the same key meanwhile (handled below).
                result = user_function(*args, **kwds)
                with lock:
                    if key in cache:
                        # Getting here means that this same key was added to the
                        # cache while the lock was released.  Since the link
                        # update is already done, we need only return the
                        # computed result and update the count of misses.
                        pass
                    elif full:
                        # Use the old root to store the new key and result.
                        oldroot = root
                        oldroot[KEY] = key
                        oldroot[RESULT] = result
                        # Empty the oldest link and make it the new root.
                        # Keep a reference to the old key and old result to
                        # prevent their ref counts from going to zero during the
                        # update. That will prevent potentially arbitrary object
                        # clean-up code (i.e. __del__) from running while we're
                        # still adjusting the links.
                        root = oldroot[NEXT]
                        oldkey = root[KEY]
                        oldresult = root[RESULT]
                        root[KEY] = root[RESULT] = None
                        # Now update the cache dictionary.
                        del cache[oldkey]
                        # Save the potentially reentrant cache[key] assignment
                        # for last, after the root and links have been put in
                        # a consistent state.
                        cache[key] = oldroot
                    else:
                        # Put result in a new link at the front of the queue.
                        last = root[PREV]
                        link = [last, root, key, result]
                        last[NEXT] = root[PREV] = cache[key] = link
                        full = (len(cache) >= maxsize)
                    misses += 1
                return result

        def cache_info():
            """Report cache statistics"""
            with lock:
                return _CacheInfo(hits, misses, maxsize, len(cache))

        def cache_clear():
            """Clear the cache and cache statistics"""
            nonlocal hits, misses, full
            with lock:
                cache.clear()
                # Reset the current root link back to an empty circular list.
                root[:] = [root, root, None, None]
                hits = misses = 0
                full = False

        wrapper.cache_info = cache_info
        wrapper.cache_clear = cache_clear
        return update_wrapper(wrapper, user_function)

    return decorating_function
Łukasz Langa6f692512013-06-05 12:20:24 +0200364
365
366################################################################################
367### singledispatch() - single-dispatch generic function decorator
368################################################################################
369
Łukasz Langa3720c772013-07-01 16:00:38 +0200370def _c3_merge(sequences):
371 """Merges MROs in *sequences* to a single MRO using the C3 algorithm.
372
373 Adapted from http://www.python.org/download/releases/2.3/mro/.
374
375 """
376 result = []
377 while True:
378 sequences = [s for s in sequences if s] # purge empty sequences
379 if not sequences:
380 return result
381 for s1 in sequences: # find merge candidates among seq heads
382 candidate = s1[0]
383 for s2 in sequences:
384 if candidate in s2[1:]:
385 candidate = None
386 break # reject the current head, it appears later
387 else:
388 break
389 if not candidate:
390 raise RuntimeError("Inconsistent hierarchy")
391 result.append(candidate)
392 # remove the chosen candidate
393 for seq in sequences:
394 if seq[0] == candidate:
395 del seq[0]
396
397def _c3_mro(cls, abcs=None):
398 """Computes the method resolution order using extended C3 linearization.
399
400 If no *abcs* are given, the algorithm works exactly like the built-in C3
401 linearization used for method resolution.
402
403 If given, *abcs* is a list of abstract base classes that should be inserted
404 into the resulting MRO. Unrelated ABCs are ignored and don't end up in the
405 result. The algorithm inserts ABCs where their functionality is introduced,
406 i.e. issubclass(cls, abc) returns True for the class itself but returns
407 False for all its direct base classes. Implicit ABCs for a given class
408 (either registered or inferred from the presence of a special method like
409 __len__) are inserted directly after the last ABC explicitly listed in the
410 MRO of said class. If two implicit ABCs end up next to each other in the
411 resulting MRO, their ordering depends on the order of types in *abcs*.
412
413 """
414 for i, base in enumerate(reversed(cls.__bases__)):
415 if hasattr(base, '__abstractmethods__'):
416 boundary = len(cls.__bases__) - i
417 break # Bases up to the last explicit ABC are considered first.
418 else:
419 boundary = 0
420 abcs = list(abcs) if abcs else []
421 explicit_bases = list(cls.__bases__[:boundary])
422 abstract_bases = []
423 other_bases = list(cls.__bases__[boundary:])
424 for base in abcs:
425 if issubclass(cls, base) and not any(
426 issubclass(b, base) for b in cls.__bases__
427 ):
428 # If *cls* is the class that introduces behaviour described by
429 # an ABC *base*, insert said ABC to its MRO.
430 abstract_bases.append(base)
431 for base in abstract_bases:
432 abcs.remove(base)
433 explicit_c3_mros = [_c3_mro(base, abcs=abcs) for base in explicit_bases]
434 abstract_c3_mros = [_c3_mro(base, abcs=abcs) for base in abstract_bases]
435 other_c3_mros = [_c3_mro(base, abcs=abcs) for base in other_bases]
436 return _c3_merge(
437 [[cls]] +
438 explicit_c3_mros + abstract_c3_mros + other_c3_mros +
439 [explicit_bases] + [abstract_bases] + [other_bases]
440 )
441
442def _compose_mro(cls, types):
443 """Calculates the method resolution order for a given class *cls*.
444
445 Includes relevant abstract base classes (with their respective bases) from
446 the *types* iterable. Uses a modified C3 linearization algorithm.
Łukasz Langa6f692512013-06-05 12:20:24 +0200447
448 """
449 bases = set(cls.__mro__)
Łukasz Langa3720c772013-07-01 16:00:38 +0200450 # Remove entries which are already present in the __mro__ or unrelated.
451 def is_related(typ):
452 return (typ not in bases and hasattr(typ, '__mro__')
453 and issubclass(cls, typ))
454 types = [n for n in types if is_related(n)]
455 # Remove entries which are strict bases of other entries (they will end up
456 # in the MRO anyway.
457 def is_strict_base(typ):
458 for other in types:
459 if typ != other and typ in other.__mro__:
460 return True
461 return False
462 types = [n for n in types if not is_strict_base(n)]
463 # Subclasses of the ABCs in *types* which are also implemented by
464 # *cls* can be used to stabilize ABC ordering.
465 type_set = set(types)
466 mro = []
467 for typ in types:
468 found = []
469 for sub in typ.__subclasses__():
470 if sub not in bases and issubclass(cls, sub):
471 found.append([s for s in sub.__mro__ if s in type_set])
472 if not found:
473 mro.append(typ)
474 continue
475 # Favor subclasses with the biggest number of useful bases
476 found.sort(key=len, reverse=True)
477 for sub in found:
478 for subcls in sub:
479 if subcls not in mro:
480 mro.append(subcls)
481 return _c3_mro(cls, abcs=mro)
Łukasz Langa6f692512013-06-05 12:20:24 +0200482
483def _find_impl(cls, registry):
Łukasz Langa3720c772013-07-01 16:00:38 +0200484 """Returns the best matching implementation from *registry* for type *cls*.
Łukasz Langa6f692512013-06-05 12:20:24 +0200485
Łukasz Langa3720c772013-07-01 16:00:38 +0200486 Where there is no registered implementation for a specific type, its method
487 resolution order is used to find a more generic implementation.
488
489 Note: if *registry* does not contain an implementation for the base
490 *object* type, this function may return None.
Łukasz Langa6f692512013-06-05 12:20:24 +0200491
492 """
493 mro = _compose_mro(cls, registry.keys())
494 match = None
495 for t in mro:
496 if match is not None:
Łukasz Langa3720c772013-07-01 16:00:38 +0200497 # If *match* is an implicit ABC but there is another unrelated,
498 # equally matching implicit ABC, refuse the temptation to guess.
499 if (t in registry and t not in cls.__mro__
500 and match not in cls.__mro__
501 and not issubclass(match, t)):
Łukasz Langa6f692512013-06-05 12:20:24 +0200502 raise RuntimeError("Ambiguous dispatch: {} or {}".format(
503 match, t))
504 break
505 if t in registry:
506 match = t
507 return registry.get(match)
508
def singledispatch(func):
    """Turn *func* into a single-dispatch generic function.

    The returned wrapper picks an implementation from the class of its
    first positional argument.  *func* is registered for ``object`` and
    therefore acts as the default.  More implementations are added with
    wrapper.register(cls[, func]).  The wrapper also exposes dispatch(cls),
    a read-only view of the registry as ``registry``, and ``_clear_cache``.
    """
    registry = {}
    # Resolved implementations keyed by class (weak keys).
    dispatch_cache = WeakKeyDictionary()
    # ABC cache token; stays None until a class with __abstractmethods__
    # is registered.
    cache_token = None

    def dispatch(cls):
        """generic_func.dispatch(cls) -> <function implementation>

        Return the implementation registered on *generic_func* that best
        matches *cls*, caching the answer.
        """
        nonlocal cache_token
        if cache_token is not None:
            # A changed ABC token means cached answers may be stale.
            latest_token = get_cache_token()
            if cache_token != latest_token:
                dispatch_cache.clear()
                cache_token = latest_token
        try:
            impl = dispatch_cache[cls]
        except KeyError:
            try:
                impl = registry[cls]
            except KeyError:
                impl = _find_impl(cls, registry)
            dispatch_cache[cls] = impl
        return impl

    def register(cls, func=None):
        """generic_func.register(cls, func) -> func

        Register *func* as the implementation for *cls*.  When called with
        only *cls*, return a decorator that performs the registration.
        """
        nonlocal cache_token
        if func is None:
            def decorator(f):
                return register(cls, f)
            return decorator
        registry[cls] = func
        if cache_token is None and hasattr(cls, '__abstractmethods__'):
            cache_token = get_cache_token()
        # Any previously cached resolution may now be wrong.
        dispatch_cache.clear()
        return func

    def wrapper(*args, **kw):
        target = dispatch(args[0].__class__)
        return target(*args, **kw)

    registry[object] = func
    wrapper.register = register
    wrapper.dispatch = dispatch
    wrapper.registry = MappingProxyType(registry)
    wrapper._clear_cache = dispatch_cache.clear
    # Copied last so func's metadata (and __dict__) is applied on top.
    update_wrapper(wrapper, func)
    return wrapper
570 return wrapper