blob: 944b17dcc71a1d92fd81f8daaba27ec09176f17f [file] [log] [blame]
Guido van Rossum0b191782002-08-02 18:29:53 +00001"""Unittests for heapq."""
2
Raymond Hettinger33ecffb2004-06-10 05:03:17 +00003from heapq import heappush, heappop, heapify, heapreplace, nlargest, nsmallest
Guido van Rossum0b191782002-08-02 18:29:53 +00004import random
Raymond Hettingerbce036b2004-06-10 05:07:18 +00005import unittest
6from test import test_support
Guido van Rossum0b191782002-08-02 18:29:53 +00007
Guido van Rossum0b191782002-08-02 18:29:53 +00008
def heapiter(heap):
    """Yield the items of *heap* smallest-first, popping each one off.

    *heap* is expected to already satisfy the heap invariant (e.g. built with
    heappush or heapify).  It is consumed in place and is empty once the
    generator has been exhausted.
    """
    # Test for emptiness up front instead of catching IndexError around the
    # yield: the old try/except also swallowed an IndexError thrown into the
    # generator by its consumer (gen.throw), silently ending iteration.
    while heap:
        yield heappop(heap)
Tim Petersaa7d2432002-08-03 02:11:26 +000016
Raymond Hettingerbce036b2004-06-10 05:07:18 +000017class TestHeap(unittest.TestCase):
Tim Petersaa7d2432002-08-03 02:11:26 +000018
Raymond Hettingerbce036b2004-06-10 05:07:18 +000019 def test_push_pop(self):
20 # 1) Push 256 random numbers and pop them off, verifying all's OK.
21 heap = []
22 data = []
23 self.check_invariant(heap)
24 for i in range(256):
25 item = random.random()
26 data.append(item)
27 heappush(heap, item)
28 self.check_invariant(heap)
29 results = []
30 while heap:
31 item = heappop(heap)
32 self.check_invariant(heap)
33 results.append(item)
34 data_sorted = data[:]
35 data_sorted.sort()
36 self.assertEqual(data_sorted, results)
37 # 2) Check that the invariant holds for a sorted array
38 self.check_invariant(results)
39
40 def check_invariant(self, heap):
41 # Check the heap invariant.
42 for pos, item in enumerate(heap):
43 if pos: # pos 0 has no parent
44 parentpos = (pos-1) >> 1
45 self.assert_(heap[parentpos] <= item)
46
47 def test_heapify(self):
48 for size in range(30):
49 heap = [random.random() for dummy in range(size)]
50 heapify(heap)
51 self.check_invariant(heap)
52
53 def test_naive_nbest(self):
54 data = [random.randrange(2000) for i in range(1000)]
55 heap = []
56 for item in data:
57 heappush(heap, item)
58 if len(heap) > 10:
59 heappop(heap)
60 heap.sort()
61 self.assertEqual(heap, sorted(data)[-10:])
62
63 def test_nbest(self):
64 # Less-naive "N-best" algorithm, much faster (if len(data) is big
65 # enough <wink>) than sorting all of data. However, if we had a max
66 # heap instead of a min heap, it could go faster still via
67 # heapify'ing all of data (linear time), then doing 10 heappops
68 # (10 log-time steps).
69 data = [random.randrange(2000) for i in range(1000)]
70 heap = data[:10]
71 heapify(heap)
72 for item in data[10:]:
73 if item > heap[0]: # this gets rarer the longer we run
74 heapreplace(heap, item)
75 self.assertEqual(list(heapiter(heap)), sorted(data)[-10:])
76
77 def test_heapsort(self):
78 # Exercise everything with repeated heapsort checks
79 for trial in xrange(100):
80 size = random.randrange(50)
81 data = [random.randrange(25) for i in range(size)]
82 if trial & 1: # Half of the time, use heapify
83 heap = data[:]
84 heapify(heap)
85 else: # The rest of the time, use heappush
86 heap = []
87 for item in data:
88 heappush(heap, item)
89 heap_sorted = [heappop(heap) for i in range(size)]
90 self.assertEqual(heap_sorted, sorted(data))
91
92 def test_nsmallest(self):
93 data = [random.randrange(2000) for i in range(1000)]
94 self.assertEqual(nsmallest(data, 400), sorted(data)[:400])
95
96 def test_largest(self):
97 data = [random.randrange(2000) for i in range(1000)]
98 self.assertEqual(nlargest(data, 400), sorted(data, reverse=True)[:400])
Tim Petersaa7d2432002-08-03 02:11:26 +000099
def test_main():
    """Run every TestHeap case through test_support.run_unittest."""
    case_class = TestHeap
    test_support.run_unittest(case_class)
Guido van Rossum0b191782002-08-02 18:29:53 +0000102
if __name__ == "__main__":
    # Executed only when this file is run directly, not when it is imported.
    test_main()
Raymond Hettingerbce036b2004-06-10 05:07:18 +0000105