blob: 22a7c4460d4b102c66045eef1ee8649fa6902ff9 [file] [log] [blame]
# Test iterators.
2
3import unittest
4from test_support import run_unittest, TESTFN, unlink
5
# Expected result of a triple loop over range(3) (too big to inline in each
# test).  It is deliberately spelled out as a literal rather than computed,
# so it stays an independent oracle for the loops and comprehensions that
# the tests compare against it.
TRIPLETS = ([(0, 0, 0), (0, 0, 1), (0, 0, 2),
             (0, 1, 0), (0, 1, 1), (0, 1, 2),
             (0, 2, 0), (0, 2, 1), (0, 2, 2)] +
            [(1, 0, 0), (1, 0, 1), (1, 0, 2),
             (1, 1, 0), (1, 1, 1), (1, 1, 2),
             (1, 2, 0), (1, 2, 1), (1, 2, 2)] +
            [(2, 0, 0), (2, 0, 1), (2, 0, 2),
             (2, 1, 0), (2, 1, 1), (2, 1, 2),
             (2, 2, 0), (2, 2, 1), (2, 2, 2)])
18
# Helper classes
20
class BasicIterClass:
    # Minimal hand-written iterator producing 0, 1, ..., n-1 from next().
    # Once exhausted it keeps raising StopIteration (self.i stops at n).
    # The iterator protocol (PEP 234) also requires __iter__ returning the
    # iterator itself; without it, iter(obj) and "for x in obj" on an
    # instance fail, since the class has neither __iter__ nor __getitem__.
    def __init__(self, n):
        self.n = n      # number of values to produce
        self.i = 0      # next value to hand out

    def next(self):
        res = self.i
        if res >= self.n:
            raise StopIteration
        self.i = res + 1
        return res

    def __iter__(self):
        return self
31
class IteratingSequenceClass:
    # An iterable (not itself an iterator): every __iter__ call builds a
    # brand-new BasicIterClass over range(n), so separate loops over the
    # same instance each start from 0.
    def __init__(self, n):
        self.n = n

    def __iter__(self):
        fresh = BasicIterClass(self.n)
        return fresh
37
class SequenceClass:
    # Old-style sequence with no __iter__: s[i] == i for 0 <= i < n, and any
    # other index raises IndexError, which ends getitem-based iteration.
    def __init__(self, n):
        self.n = n

    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError
        return i
46
# Main test suite
48
class TestCase(unittest.TestCase):
    # Exercises the iteration protocol: one- and two-argument iter(), for
    # loops and list comprehensions over iterators, and the builtins and
    # operators (list, tuple, filter, max/min, map, reduce, str.join, `in`,
    # operator.countOf) that accept arbitrary iterables.

    # Helper to check that an iterator returns a given sequence
    # (drives it by calling next() by hand until StopIteration).
    def check_iterator(self, it, seq):
        res = []
        while 1:
            try:
                val = it.next()
            except StopIteration:
                break
            res.append(val)
        self.assertEqual(res, seq)

    # Helper to check that a for loop generates a given sequence
    def check_for_loop(self, expr, seq):
        res = []
        for val in expr:
            res.append(val)
        self.assertEqual(res, seq)

    # Helper: write `data` to TESTFN, then return TESTFN reopened for
    # reading.  Callers pass the result to _close_testfile() in a finally
    # clause so the temporary file is removed even when a check fails.
    def _open_testfile(self, data):
        f = open(TESTFN, "w")
        try:
            f.write(data)
        finally:
            f.close()
        return open(TESTFN, "r")

    # Helper: close a file returned by _open_testfile() and delete TESTFN;
    # an OSError from the delete is ignored.
    def _close_testfile(self, f):
        f.close()
        try:
            unlink(TESTFN)
        except OSError:
            pass

    # Test basic use of iter() function
    def test_iter_basic(self):
        self.check_iterator(iter(range(10)), range(10))

    # Test that iter(iter(x)) is the same as iter(x)
    def test_iter_idempotency(self):
        seq = range(10)
        it = iter(seq)
        it2 = iter(it)
        self.assert_(it is it2)

    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        self.check_for_loop(iter(range(10)), range(10))

    # Test several independent iterators over the same list
    def test_iter_independence(self):
        seq = range(3)
        res = []
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)

    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        self.check_for_loop(IteratingSequenceClass(10), range(10))

    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        self.check_iterator(iter(IteratingSequenceClass(10)), range(10))

    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        self.check_for_loop(SequenceClass(10), range(10))

    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        self.check_iterator(iter(SequenceClass(10)), range(10))

    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        class C:
            def __init__(self):
                self.i = 0
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        self.check_iterator(iter(C(), 10), range(10))

    # Test two-argument iter() with function
    def test_iter_function(self):
        # The mutable default is a deliberate call counter; it starts fresh
        # because spam is redefined each time this test runs.
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), range(10))

    # Test two-argument iter() with function that raises StopIteration
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), range(10))

    # Test exception propagation through function iterator
    def test_exception_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), range(10))

    # Test a big range
    def test_iter_big_range(self):
        self.check_for_loop(iter(range(10000)), range(10000))

    # Test an empty list
    def test_iter_empty(self):
        self.check_for_loop(iter([]), [])

    # Test a tuple
    def test_iter_tuple(self):
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))

    # Test an xrange
    def test_iter_xrange(self):
        self.check_for_loop(iter(xrange(10)), range(10))

    # Test a string
    def test_iter_string(self):
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])

    # Test a Unicode string
    def test_iter_unicode(self):
        self.check_for_loop(iter(u"abcde"), [u"a", u"b", u"c", u"d", u"e"])

    # Test a dictionary (a for loop over it yields its keys)
    def test_iter_dict(self):
        d = {}
        for i in range(10):
            d[i] = None
        self.check_for_loop(d, d.keys())

    # Test a file
    def test_iter_file(self):
        f = self._open_testfile("0\n1\n2\n3\n4\n")
        try:
            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
            # A second loop over the same, already-exhausted file is empty.
            self.check_for_loop(f, [])
        finally:
            self._close_testfile(f)

    # Test list()'s use of iterators.
    def test_builtin_list(self):
        self.assertEqual(list(SequenceClass(5)), range(5))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])
        self.assertEqual(list(range(10, -1, -1)), range(10, -1, -1))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), d.keys())

        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        f = self._open_testfile("0\n1\n2\n3\n4\n")
        try:
            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
            f.seek(0, 0)
            self.assertEqual(list(f.xreadlines()),
                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
        finally:
            self._close_testfile(f)

    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        self.assertRaises(TypeError, tuple, list)
        self.assertRaises(TypeError, tuple, 42)

        f = self._open_testfile("0\n1\n2\n3\n4\n")
        try:
            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
            f.seek(0, 0)
            self.assertEqual(tuple(f.xreadlines()),
                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
        finally:
            self._close_testfile(f)

    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        # filter(None, ...) keeps only true items (so 0 is dropped); tuple
        # and string inputs produce a result of the same type.
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        # Objects whose truth value comes from __nonzero__.  They are
        # deliberately not named True/False so the builtins stay unshadowed.
        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        bTrue = Boolean(1)
        bFalse = Boolean(0)

        # Iterable whose __iter__ returns a separate iterator object.
        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def next(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        seq = Seq(*([bTrue, bFalse] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [bFalse]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [bFalse]*25)

    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.itervalues()), 3)
        self.assertEqual(min(iter(d.itervalues())), 1)

        f = self._open_testfile("medium line\n"
                                "xtra large line\n"
                                "itty-bitty line\n")
        try:
            self.assertEqual(min(f), "itty-bitty line\n")
            f.seek(0, 0)
            self.assertEqual(max(f), "xtra large line\n")
        finally:
            self._close_testfile(f)

    # Test map()'s use of iterators.
    def test_builtin_map(self):
        self.assertEqual(map(None, SequenceClass(5)), range(5))
        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(map(None, d), d.keys())
        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
        # With several arguments, map(None, ...) pads the shorter ones with
        # None up to the longest (SequenceClass(5)).  The `and/or` idiom is
        # safe here because every key is a non-empty (true) string.
        dkeys = d.keys()
        expected = [(i < len(d) and dkeys[i] or None,
                     i,
                     i < len(d) and dkeys[i] or None)
                    for i in range(5)]
        self.assertEqual(map(None, d,
                                   SequenceClass(5),
                                   iter(d.iterkeys())),
                         expected)

        # Line i has length 2*i+1: i copies of "xy" plus the newline.
        data = "".join(["xy" * i + "\n" for i in range(10)])
        f = self._open_testfile(data)
        try:
            self.assertEqual(map(len, f), range(1, 21, 2))
        finally:
            self._close_testfile(f)

    # Test reduce()'s use of iterators.
    def test_builtin_reduce(self):
        from operator import add
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))

    # Test str.join() over a non-restartable iterator that yields a Unicode
    # object partway through.
    def test_unicode_join_endcase(self):

        # This class inserts a Unicode object into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                if i == 2:
                    return u"fooled you!"
                return self.it.next()

        f = self._open_testfile("a\n" + "b\n" + "c\n")
        # Nasty: string.join(s) can't know whether unicode.join() is needed
        # until it's seen all of s's elements. But in this case, f's
        # iterator cannot be restarted. So what we're testing here is
        # whether string.join() can manage to remember everything it's seen
        # and pass that on to unicode.join().
        try:
            got = " - ".join(OhPhooey(f))
            self.assertEqual(got, u"a\n - b\n - fooled you! - c\n")
        finally:
            self._close_testfile(f)

    # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for i in range(5):
                self.assert_(i in sc5)
            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assert_(i not in sc5)
        del sc5

        # Right operands that cannot be iterated must raise TypeError.
        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)

        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for k in d:
            self.assert_(k in d)
            self.assert_(k not in d.itervalues())
        for v in d.values():
            self.assert_(v in d.itervalues())
            self.assert_(v not in d)
        for k, v in d.iteritems():
            self.assert_((k, v) in d.iteritems())
            self.assert_((v, k) not in d.iteritems())
        del d

        f = self._open_testfile("a\n" "b\n" "c\n")
        try:
            for chunk in "abc":
                # File items are whole lines including "\n", so only
                # chunk + "\n" matches.  Each search consumes the file,
                # hence the rewind before every membership test.
                f.seek(0, 0)
                self.assert_(chunk not in f)
                f.seek(0, 0)
                self.assert_((chunk + "\n") in f)
        finally:
            self._close_testfile(f)

    # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        from operator import countOf
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for k in d:
            self.assertEqual(countOf(d, k), 1)
        self.assertEqual(countOf(d.itervalues(), 3), 3)
        self.assertEqual(countOf(d.itervalues(), 2j), 1)
        self.assertEqual(countOf(d.itervalues(), 1j), 0)

        f = self._open_testfile("a\n" "b\n" "c\n" "b\n")
        try:
            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                f.seek(0, 0)
                self.assertEqual(countOf(f, letter + "\n"), count)
        finally:
            self._close_testfile(f)
552
# Runs the whole suite as a module-level statement, i.e. whenever this
# module is executed or imported.
# NOTE(review): presumably the regression-test driver relies on this
# import-time execution; confirm before moving it under an
# `if __name__ == "__main__":` guard.
run_unittest(TestCase)