Add key= argument to heapq.nsmallest() and heapq.nlargest().
diff --git a/Lib/heapq.py b/Lib/heapq.py
index b4ebb91..04725cd 100644
--- a/Lib/heapq.py
+++ b/Lib/heapq.py
@@ -129,7 +129,8 @@
__all__ = ['heappush', 'heappop', 'heapify', 'heapreplace', 'nlargest',
'nsmallest']
-from itertools import islice, repeat
+from itertools import islice, repeat, count, imap, izip, tee
+from operator import itemgetter, neg
import bisect
def heappush(heap, item):
@@ -307,6 +308,43 @@
except ImportError:
pass
+# Extend the implementations of nsmallest and nlargest to use a key= argument.
+# The previously bound versions are kept as _nsmallest/_nlargest and are
+# called directly when no key is given.
+_nsmallest = nsmallest
+def nsmallest(n, iterable, key=None):
+    """Find the n smallest elements in a dataset.
+
+    Equivalent to: sorted(iterable, key=key)[:n]
+    """
+    if key is None:
+        return _nsmallest(n, iterable)
+    in1, in2 = tee(iterable)
+    # Decorate as (key, index, item).  The ascending index resolves ties
+    # between equal keys in input order, matching sorted()'s stability,
+    # and ensures the items themselves are never compared.
+    it = izip(imap(key, in1), count(), in2)  # decorate
+    result = _nsmallest(n, it)
+    return map(itemgetter(2), result)  # undecorate
+
+_nlargest = nlargest
+def nlargest(n, iterable, key=None):
+    """Find the n largest elements in a dataset.
+
+    Equivalent to: sorted(iterable, key=key, reverse=True)[:n]
+    """
+    if key is None:
+        return _nlargest(n, iterable)
+    in1, in2 = tee(iterable)
+    # Decorate as (key, -index, item).  _nlargest() prefers larger tuples,
+    # so the negated index makes equal keys resolve to the earlier item
+    # first -- the order sorted(..., reverse=True) gives.  A plain ascending
+    # count() would favor later items instead.  The index also ensures the
+    # items themselves are never compared.
+    it = izip(imap(key, in1), imap(neg, count()), in2)  # decorate
+    result = _nlargest(n, it)
+    return map(itemgetter(2), result)  # undecorate
+
if __name__ == "__main__":
# Simple sanity test
heap = []
diff --git a/Lib/test/test_heapq.py b/Lib/test/test_heapq.py
index 68003e7..2da4f8c 100644
--- a/Lib/test/test_heapq.py
+++ b/Lib/test/test_heapq.py
@@ -105,13 +105,31 @@
    def test_nsmallest(self):
        data = [random.randrange(2000) for i in range(1000)]
+        f = lambda x: x * 547 % 2000
+        # f is one-to-one on range(2000), so it never yields equal keys for
+        # distinct items; pairs/g add heavy key ties with distinguishable
+        # items to check that ties keep input order exactly as sorted() does.
+        pairs = [(random.randrange(10), i) for i in range(1000)]
+        g = lambda p: p[0]
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertEqual(nsmallest(n, data), sorted(data)[:n])
+            self.assertEqual(nsmallest(n, data, key=f),
+                sorted(data, key=f)[:n])
+            self.assertEqual(nsmallest(n, pairs, key=g),
+                sorted(pairs, key=g)[:n])
-    def test_largest(self):
+    def test_nlargest(self):
        data = [random.randrange(2000) for i in range(1000)]
+        f = lambda x: x * 547 % 2000
+        # Heavy key ties with distinguishable items (see test_nsmallest).
+        pairs = [(random.randrange(10), i) for i in range(1000)]
+        g = lambda p: p[0]
        for n in (0, 1, 2, 10, 100, 400, 999, 1000, 1100):
            self.assertEqual(nlargest(n, data), sorted(data, reverse=True)[:n])
+            self.assertEqual(nlargest(n, data, key=f),
+                sorted(data, key=f, reverse=True)[:n])
+            self.assertEqual(nlargest(n, pairs, key=g),
+                sorted(pairs, key=g, reverse=True)[:n])
#==============================================================================