blob: 55845879a581c62b6aa43994375b0ffc03990805 [file] [log] [blame]
Guido van Rossum8b48cf92001-04-21 13:33:54 +00001# Test iterators.
2
3import unittest
4from test_support import run_unittest, TESTFN, unlink
5
# Expected output of a triple nested loop over range(3): every (i, j, k)
# with i, j, k drawn from 0..2, in lexicographic order.  Kept as one shared
# literal because several tests below compare against it.
TRIPLETS = [
    (0, 0, 0), (0, 0, 1), (0, 0, 2),
    (0, 1, 0), (0, 1, 1), (0, 1, 2),
    (0, 2, 0), (0, 2, 1), (0, 2, 2),
    # i == 1
    (1, 0, 0), (1, 0, 1), (1, 0, 2),
    (1, 1, 0), (1, 1, 1), (1, 1, 2),
    (1, 2, 0), (1, 2, 1), (1, 2, 2),
    # i == 2
    (2, 0, 0), (2, 0, 1), (2, 0, 2),
    (2, 1, 0), (2, 1, 1), (2, 1, 2),
    (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
18
19# Helper classes
20
class BasicIterClass:
    """Hand-written iterator producing 0, 1, ..., n-1.

    Each call to next() returns the current counter and advances it.
    Once the counter reaches n, every further call raises StopIteration,
    so an exhausted iterator stays exhausted.
    """
    def __init__(self, n):
        self.n = n   # exclusive upper bound on the values produced
        self.i = 0   # value the next call to next() will return
    def next(self):
        res = self.i
        if res >= self.n:
            raise StopIteration
        self.i = res + 1
        return res
    # Same method under the name later Pythons look up, so the class also
    # works with the next() builtin and for-loops there.
    __next__ = next
    def __iter__(self):
        # The iterator protocol requires an iterator to return itself from
        # __iter__, so iter(it) and "for x in it" work on an existing
        # iterator instead of failing.
        return self
31
class IteratingSequenceClass:
    """Iterable whose __iter__ hands out a new BasicIterClass over 0..n-1."""
    def __init__(self, n):
        self.n = n
    def __iter__(self):
        # A fresh iterator per call, so separate loops do not share state.
        count = self.n
        return BasicIterClass(count)
37
class SequenceClass:
    """Sequence-protocol-only container (no __iter__) holding 0..n-1.

    Indexing with i in [0, n) returns i itself; any other index raises
    IndexError, which is what ends iteration over such an object.
    """
    def __init__(self, n):
        self.n = n
    def __getitem__(self, i):
        # Guard clause: reject every index outside [0, n), negatives included.
        if not (0 <= i < self.n):
            raise IndexError
        return i
46
47# Main test suite
48
class TestCase(unittest.TestCase):
    """Tests for iter(), the iterator protocol, and builtins that consume iterables."""

    # Helper to check that an iterator returns a given sequence
    def check_iterator(self, it, seq):
        """Drain *it* via it.next() until StopIteration; assert the values equal *seq*."""
        res = []
        while 1:
            try:
                val = it.next()
            except StopIteration:
                break
            res.append(val)
        self.assertEqual(res, seq)

    # Helper to check that a for loop generates a given sequence
    def check_for_loop(self, expr, seq):
        """Iterate *expr* with a for loop; assert the collected values equal *seq*."""
        res = []
        for val in expr:
            res.append(val)
        self.assertEqual(res, seq)

    # Helpers for the tests that iterate over a real file
    def _write_testfile(self, lines):
        """Create TESTFN containing exactly the strings in *lines*, then close it."""
        f = open(TESTFN, "w")
        try:
            f.writelines(lines)
        finally:
            f.close()

    def _remove_testfile(self):
        """Delete TESTFN, ignoring OSError (e.g. if it was never created)."""
        try:
            unlink(TESTFN)
        except OSError:
            pass

    # Test basic use of iter() function
    def test_iter_basic(self):
        self.check_iterator(iter(range(10)), range(10))

    # Test that iter(iter(x)) is the same as iter(x)
    def test_iter_idempotency(self):
        seq = range(10)
        it = iter(seq)
        it2 = iter(it)
        self.assert_(it is it2)

    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        self.check_for_loop(iter(range(10)), range(10))

    # Test several independent iterators over the same list
    def test_iter_independence(self):
        seq = range(3)
        res = []
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)

    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        self.check_for_loop(IteratingSequenceClass(10), range(10))

    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        self.check_iterator(iter(IteratingSequenceClass(10)), range(10))

    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        self.check_for_loop(SequenceClass(10), range(10))

    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        self.check_iterator(iter(SequenceClass(10)), range(10))

    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        class C:
            def __init__(self):
                self.i = 0
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        self.check_iterator(iter(C(), 10), range(10))

    # Test two-argument iter() with function
    def test_iter_function(self):
        # The mutable default is a deliberate call counter; spam is
        # redefined on every run of this test, so the list is never shared.
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), range(10))

    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), range(10))

    # Test exception propagation through function iterator
    def test_exception_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), range(10))

    # Test a big range
    def test_iter_big_range(self):
        self.check_for_loop(iter(range(10000)), range(10000))

    # Test an empty list
    def test_iter_empty(self):
        self.check_for_loop(iter([]), [])

    # Test a tuple
    def test_iter_tuple(self):
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))

    # Test an xrange
    def test_iter_xrange(self):
        self.check_for_loop(iter(xrange(10)), range(10))

    # Test a string
    def test_iter_string(self):
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])

    # Test a Unicode string
    def test_iter_unicode(self):
        self.check_for_loop(iter(u"abcde"), [u"a", u"b", u"c", u"d", u"e"])

    # Test a dictionary
    def test_iter_dict(self):
        d = {}
        for i in range(10):
            d[i] = None
        self.check_for_loop(d, d.keys())

    # Test a file
    def test_iter_file(self):
        # The outer try/finally removes TESTFN even if writing or re-opening fails.
        try:
            self._write_testfile(["%d\n" % i for i in range(5)])
            f = open(TESTFN, "r")
            try:
                self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
                # A second pass over the already-consumed file yields nothing.
                self.check_for_loop(f, [])
            finally:
                f.close()
        finally:
            self._remove_testfile()

    # Test list()'s use of iterators.
    def test_builtin_list(self):
        self.assertEqual(list(SequenceClass(5)), range(5))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])
        self.assertEqual(list(range(10, -1, -1)), range(10, -1, -1))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), d.keys())

        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        try:
            self._write_testfile(["%d\n" % i for i in range(5)])
            f = open(TESTFN, "r")
            try:
                self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
                f.seek(0, 0)
                self.assertEqual(list(f.xreadlines()),
                                 ["0\n", "1\n", "2\n", "3\n", "4\n"])
            finally:
                f.close()
        finally:
            self._remove_testfile()

    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        # Object whose truth value comes from __nonzero__.
        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        # Named so as not to shadow the True/False builtins.
        truthy = Boolean(1)
        falsy = Boolean(0)

        # Iterable whose __iter__ returns a separate iterator object.
        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def next(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        seq = Seq(*([truthy, falsy] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [falsy]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [falsy]*25)

    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.itervalues()), 3)
        self.assertEqual(min(iter(d.itervalues())), 1)

        try:
            self._write_testfile(["medium line\n",
                                  "xtra large line\n",
                                  "itty-bitty line\n"])
            f = open(TESTFN, "r")
            try:
                self.assertEqual(min(f), "itty-bitty line\n")
                f.seek(0, 0)
                self.assertEqual(max(f), "xtra large line\n")
            finally:
                f.close()
        finally:
            self._remove_testfile()

    # Test map()'s use of iterators.
    def test_builtin_map(self):
        self.assertEqual(map(None, SequenceClass(5)), range(5))
        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(map(None, d), d.keys())
        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
        dkeys = d.keys()
        # map(None, a, b, c) pads the shorter arguments with None.
        # "cond and x or None" stands in for a conditional expression; it is
        # safe here only because every key of d is a non-empty (true) string.
        expected = [(i < len(d) and dkeys[i] or None,
                     i,
                     i < len(d) and dkeys[i] or None)
                    for i in range(5)]
        self.assertEqual(map(None, d,
                                   SequenceClass(5),
                                   iter(d.iterkeys())),
                         expected)

        try:
            # line i has len 2*i+1
            self._write_testfile(["xy" * i + "\n" for i in range(10)])
            f = open(TESTFN, "r")
            try:
                self.assertEqual(map(len, f), range(1, 21, 2))
            finally:
                f.close()
        finally:
            self._remove_testfile()

    # Test reduce()'s use of iterators.
    def test_builtin_reduce(self):
        from operator import add
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))
400
# Runs the TestCase suite as a side effect of importing this module.
# NOTE(review): presumably the regression-test driver relies on import-time
# execution here -- confirm before moving this under a __main__ guard.
run_unittest(TestCase)