blob: ddc58a7cfc727a4aa2cadc74c512de4511ebdaa6 [file] [log] [blame]
# Test iterators.
2
3import unittest
4from test_support import run_unittest, TESTFN, unlink
5
# Expected output of the triple i/j/k loop over range(3).  It is spelled
# out literally (rather than generated) so the loop and comprehension tests
# below do not build their oracle with the very constructs they exercise,
# and it lives at module level because it is too big to inline in each test.
TRIPLETS = [
    # i == 0
    (0, 0, 0), (0, 0, 1), (0, 0, 2),
    (0, 1, 0), (0, 1, 1), (0, 1, 2),
    (0, 2, 0), (0, 2, 1), (0, 2, 2),
    # i == 1
    (1, 0, 0), (1, 0, 1), (1, 0, 2),
    (1, 1, 0), (1, 1, 1), (1, 1, 2),
    (1, 2, 0), (1, 2, 1), (1, 2, 2),
    # i == 2
    (2, 0, 0), (2, 0, 1), (2, 0, 2),
    (2, 1, 0), (2, 1, 1), (2, 1, 2),
    (2, 2, 0), (2, 2, 1), (2, 2, 2),
]
18
# Helper classes

class BasicIterClass:
    """Bare-bones iterator producing 0, 1, ..., n-1.

    Only next() is defined.  Once n values have been produced, every
    further call raises StopIteration (the counter is not advanced).
    """
    def __init__(self, n):
        self.n = n
        self.i = 0
    def next(self):
        if self.i >= self.n:
            raise StopIteration
        value = self.i
        self.i = value + 1
        return value
31
class IteratingSequenceClass:
    """Iterable whose __iter__ hands out a new BasicIterClass over 0..n-1."""
    def __init__(self, n):
        self.n = n
    def __iter__(self):
        fresh = BasicIterClass(self.n)
        return fresh
37
class SequenceClass:
    """Sequence without __iter__: __getitem__ maps i to i for 0 <= i < n.

    Any index outside that range raises IndexError, which ends iteration
    under the old __getitem__-based sequence protocol.
    """
    def __init__(self, n):
        self.n = n
    def __getitem__(self, i):
        if not 0 <= i < self.n:
            raise IndexError
        return i
46
# Main test suite

class TestCase(unittest.TestCase):
    """Tests for the iterator protocol and the builtins that consume it.

    Covers iter() (one- and two-argument forms), for loops and list
    comprehensions, classes defining __iter__ or only __getitem__, and the
    iterator-aware consumers list, tuple, filter, max, min, map, zip,
    reduce, string join, the 'in' operator and operator.countOf.  Tests
    that need a real file write to, and afterwards remove, TESTFN.
    """

    # Helper to check that an iterator returns a given sequence
    def check_iterator(self, it, seq):
        res = []
        while 1:
            try:
                val = it.next()
            except StopIteration:
                break
            res.append(val)
        self.assertEqual(res, seq)

    # Helper to check that a for loop generates a given sequence
    def check_for_loop(self, expr, seq):
        res = []
        for val in expr:
            res.append(val)
        self.assertEqual(res, seq)

    # Helper to (re)create TESTFN so that it contains exactly `data`.
    def _write_testfile(self, data):
        f = open(TESTFN, "w")
        try:
            f.write(data)
        finally:
            f.close()

    # Helper to delete TESTFN.  It is called from `finally` clauses, so an
    # OSError (e.g. the file is already gone) is deliberately ignored.
    def _remove_testfile(self):
        try:
            unlink(TESTFN)
        except OSError:
            pass

    # Test basic use of iter() function
    def test_iter_basic(self):
        self.check_iterator(iter(range(10)), range(10))

    # Test that iter(iter(x)) is the same as iter(x)
    def test_iter_idempotency(self):
        seq = range(10)
        it = iter(seq)
        it2 = iter(it)
        self.assert_(it is it2)

    # Test that for loops over iterators work
    def test_iter_for_loop(self):
        self.check_for_loop(iter(range(10)), range(10))

    # Test several independent iterators over the same list
    def test_iter_independence(self):
        seq = range(3)
        res = []
        for i in iter(seq):
            for j in iter(seq):
                for k in iter(seq):
                    res.append((i, j, k))
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension using iterators
    def test_nested_comprehensions_iter(self):
        seq = range(3)
        res = [(i, j, k)
               for i in iter(seq) for j in iter(seq) for k in iter(seq)]
        self.assertEqual(res, TRIPLETS)

    # Test triple list comprehension without iterators
    def test_nested_comprehensions_for(self):
        seq = range(3)
        res = [(i, j, k) for i in seq for j in seq for k in seq]
        self.assertEqual(res, TRIPLETS)

    # Test a class with __iter__ in a for loop
    def test_iter_class_for(self):
        self.check_for_loop(IteratingSequenceClass(10), range(10))

    # Test a class with __iter__ with explicit iter()
    def test_iter_class_iter(self):
        self.check_iterator(iter(IteratingSequenceClass(10)), range(10))

    # Test for loop on a sequence class without __iter__
    def test_seq_class_for(self):
        self.check_for_loop(SequenceClass(10), range(10))

    # Test iter() on a sequence class without __iter__
    def test_seq_class_iter(self):
        self.check_iterator(iter(SequenceClass(10)), range(10))

    # Test two-argument iter() with callable instance
    def test_iter_callable(self):
        class C:
            def __init__(self):
                self.i = 0
            def __call__(self):
                i = self.i
                self.i = i + 1
                if i > 100:
                    raise IndexError # Emergency stop
                return i
        self.check_iterator(iter(C(), 10), range(10))

    # Test two-argument iter() with function.  The mutable default `state`
    # is deliberate: it is a call counter shared across calls to spam(),
    # and a fresh one is created each time this test defines spam().
    def test_iter_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 10), range(10))

    # Test two-argument iter() with function that raises StopIteration
    # (same deliberate mutable-default counter as above).
    def test_iter_function_stop(self):
        def spam(state=[0]):
            i = state[0]
            if i == 10:
                raise StopIteration
            state[0] = i+1
            return i
        self.check_iterator(iter(spam, 20), range(10))

    # Test exception propagation through function iterator
    def test_exception_function(self):
        def spam(state=[0]):
            i = state[0]
            state[0] = i+1
            if i == 10:
                raise RuntimeError
            return i
        res = []
        try:
            for x in iter(spam, 20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test exception propagation through sequence iterator
    def test_exception_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise RuntimeError
                return SequenceClass.__getitem__(self, i)
        res = []
        try:
            for x in MySequenceClass(20):
                res.append(x)
        except RuntimeError:
            self.assertEqual(res, range(10))
        else:
            self.fail("should have raised RuntimeError")

    # Test for StopIteration from __getitem__
    def test_stop_sequence(self):
        class MySequenceClass(SequenceClass):
            def __getitem__(self, i):
                if i == 10:
                    raise StopIteration
                return SequenceClass.__getitem__(self, i)
        self.check_for_loop(MySequenceClass(20), range(10))

    # Test a big range
    def test_iter_big_range(self):
        self.check_for_loop(iter(range(10000)), range(10000))

    # Test an empty list
    def test_iter_empty(self):
        self.check_for_loop(iter([]), [])

    # Test a tuple
    def test_iter_tuple(self):
        self.check_for_loop(iter((0,1,2,3,4,5,6,7,8,9)), range(10))

    # Test an xrange
    def test_iter_xrange(self):
        self.check_for_loop(iter(xrange(10)), range(10))

    # Test a string
    def test_iter_string(self):
        self.check_for_loop(iter("abcde"), ["a", "b", "c", "d", "e"])

    # Test a Unicode string
    def test_iter_unicode(self):
        self.check_for_loop(iter(u"abcde"), [u"a", u"b", u"c", u"d", u"e"])

    # Test a dictionary: iterating it yields its keys, in keys() order.
    def test_iter_dict(self):
        d = {}
        for i in range(10):
            d[i] = None
        self.check_for_loop(d, d.keys())

    # Test a file: iteration yields its lines, and an exhausted file
    # iterator yields nothing more.
    def test_iter_file(self):
        self._write_testfile("0\n1\n2\n3\n4\n")
        f = open(TESTFN, "r")
        try:
            self.check_for_loop(f, ["0\n", "1\n", "2\n", "3\n", "4\n"])
            self.check_for_loop(f, [])
        finally:
            f.close()
            self._remove_testfile()

    # Test list()'s use of iterators.
    def test_builtin_list(self):
        self.assertEqual(list(SequenceClass(5)), range(5))
        self.assertEqual(list(SequenceClass(0)), [])
        self.assertEqual(list(()), [])
        self.assertEqual(list(range(10, -1, -1)), range(10, -1, -1))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(list(d), d.keys())

        self.assertRaises(TypeError, list, list)
        self.assertRaises(TypeError, list, 42)

        self._write_testfile("0\n1\n2\n3\n4\n")
        f = open(TESTFN, "r")
        try:
            self.assertEqual(list(f), ["0\n", "1\n", "2\n", "3\n", "4\n"])
            f.seek(0, 0)
            self.assertEqual(list(f.xreadlines()),
                             ["0\n", "1\n", "2\n", "3\n", "4\n"])
        finally:
            f.close()
            self._remove_testfile()

    # Test tuple()'s use of iterators.
    def test_builtin_tuple(self):
        self.assertEqual(tuple(SequenceClass(5)), (0, 1, 2, 3, 4))
        self.assertEqual(tuple(SequenceClass(0)), ())
        self.assertEqual(tuple([]), ())
        self.assertEqual(tuple(()), ())
        self.assertEqual(tuple("abc"), ("a", "b", "c"))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(tuple(d), tuple(d.keys()))

        self.assertRaises(TypeError, tuple, list)
        self.assertRaises(TypeError, tuple, 42)

        self._write_testfile("0\n1\n2\n3\n4\n")
        f = open(TESTFN, "r")
        try:
            self.assertEqual(tuple(f), ("0\n", "1\n", "2\n", "3\n", "4\n"))
            f.seek(0, 0)
            self.assertEqual(tuple(f.xreadlines()),
                             ("0\n", "1\n", "2\n", "3\n", "4\n"))
        finally:
            f.close()
            self._remove_testfile()

    # Test filter()'s use of iterators.
    def test_builtin_filter(self):
        self.assertEqual(filter(None, SequenceClass(5)), range(1, 5))
        self.assertEqual(filter(None, SequenceClass(0)), [])
        self.assertEqual(filter(None, ()), ())
        self.assertEqual(filter(None, "abc"), "abc")

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(filter(None, d), d.keys())

        self.assertRaises(TypeError, filter, None, list)
        self.assertRaises(TypeError, filter, None, 42)

        # Objects whose truth value comes from __nonzero__.  They are bound
        # to true_obj/false_obj so they do not shadow the names True/False.
        class Boolean:
            def __init__(self, truth):
                self.truth = truth
            def __nonzero__(self):
                return self.truth
        true_obj = Boolean(1)
        false_obj = Boolean(0)

        # Iterable whose __iter__ returns a separate iterator object.
        class Seq:
            def __init__(self, *args):
                self.vals = args
            def __iter__(self):
                class SeqIter:
                    def __init__(self, vals):
                        self.vals = vals
                        self.i = 0
                    def __iter__(self):
                        return self
                    def next(self):
                        i = self.i
                        self.i = i + 1
                        if i < len(self.vals):
                            return self.vals[i]
                        else:
                            raise StopIteration
                return SeqIter(self.vals)

        seq = Seq(*([true_obj, false_obj] * 25))
        self.assertEqual(filter(lambda x: not x, seq), [false_obj]*25)
        self.assertEqual(filter(lambda x: not x, iter(seq)), [false_obj]*25)

    # Test max() and min()'s use of iterators.
    def test_builtin_max_min(self):
        self.assertEqual(max(SequenceClass(5)), 4)
        self.assertEqual(min(SequenceClass(5)), 0)
        self.assertEqual(max(8, -1), 8)
        self.assertEqual(min(8, -1), -1)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(max(d), "two")
        self.assertEqual(min(d), "one")
        self.assertEqual(max(d.itervalues()), 3)
        self.assertEqual(min(iter(d.itervalues())), 1)

        self._write_testfile("medium line\n"
                             "xtra large line\n"
                             "itty-bitty line\n")
        f = open(TESTFN, "r")
        try:
            self.assertEqual(min(f), "itty-bitty line\n")
            f.seek(0, 0)
            self.assertEqual(max(f), "xtra large line\n")
        finally:
            f.close()
            self._remove_testfile()

    # Test map()'s use of iterators.
    def test_builtin_map(self):
        self.assertEqual(map(None, SequenceClass(5)), range(5))
        self.assertEqual(map(lambda x: x+1, SequenceClass(5)), range(1, 6))

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(map(None, d), d.keys())
        self.assertEqual(map(lambda k, d=d: (k, d[k]), d), d.items())
        # map(None, ...) pads the shorter arguments with None up to the
        # length of the longest one (SequenceClass(5)).  Build the padding
        # explicitly rather than with the fragile `cond and a or b` idiom.
        dkeys = d.keys()
        padded = dkeys + [None] * (5 - len(dkeys))
        expected = [(padded[i], i, padded[i]) for i in range(5)]
        self.assertEqual(map(None, d,
                                   SequenceClass(5),
                                   iter(d.iterkeys())),
                         expected)

        # Line i is "xy" repeated i times plus "\n", so it has length 2*i+1.
        self._write_testfile("".join(["xy" * i + "\n" for i in range(10)]))
        f = open(TESTFN, "r")
        try:
            self.assertEqual(map(len, f), range(1, 21, 2))
        finally:
            f.close()
            self._remove_testfile()

    # Test zip()'s use of iterators.
    def test_builtin_zip(self):
        self.assertRaises(TypeError, zip)
        self.assertRaises(TypeError, zip, None)
        self.assertRaises(TypeError, zip, range(10), 42)
        self.assertRaises(TypeError, zip, range(10), zip)

        self.assertEqual(zip(IteratingSequenceClass(3)),
                         [(0,), (1,), (2,)])
        self.assertEqual(zip(SequenceClass(3)),
                         [(0,), (1,), (2,)])

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(d.items(), zip(d, d.itervalues()))

        # Generate all ints starting at constructor arg.
        class IntsFrom:
            def __init__(self, start):
                self.i = start

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                return i

        # zip() must stop at the shortest argument (the 3-line file) even
        # though both IntsFrom iterators never raise StopIteration.
        self._write_testfile("a\n" "bbb\n" "cc\n")
        f = open(TESTFN, "r")
        try:
            self.assertEqual(zip(IntsFrom(0), f, IntsFrom(-100)),
                             [(0, "a\n", -100),
                              (1, "bbb\n", -99),
                              (2, "cc\n", -98)])
        finally:
            f.close()
            self._remove_testfile()

    # Test reduce()'s use of iterators.
    def test_builtin_reduce(self):
        from operator import add
        self.assertEqual(reduce(add, SequenceClass(5)), 10)
        self.assertEqual(reduce(add, SequenceClass(5), 42), 52)
        self.assertRaises(TypeError, reduce, add, SequenceClass(0))
        self.assertEqual(reduce(add, SequenceClass(0), 42), 42)
        self.assertEqual(reduce(add, SequenceClass(1)), 0)
        self.assertEqual(reduce(add, SequenceClass(1), 42), 42)

        d = {"one": 1, "two": 2, "three": 3}
        self.assertEqual(reduce(add, d), "".join(d.keys()))

    # Test that string join() hands off correctly to unicode join() when a
    # Unicode item appears part-way through a one-shot iterator.
    def test_unicode_join_endcase(self):

        # This class inserts a Unicode object into its argument's natural
        # iteration, in the 3rd position.
        class OhPhooey:
            def __init__(self, seq):
                self.it = iter(seq)
                self.i = 0

            def __iter__(self):
                return self

            def next(self):
                i = self.i
                self.i = i+1
                if i == 2:
                    return u"fooled you!"
                return self.it.next()

        self._write_testfile("a\n" + "b\n" + "c\n")

        f = open(TESTFN, "r")
        # Nasty: string.join(s) can't know whether unicode.join() is needed
        # until it's seen all of s's elements.  But in this case, f's
        # iterator cannot be restarted.  So what we're testing here is
        # whether string.join() can manage to remember everything it's seen
        # and pass that on to unicode.join().
        try:
            got = " - ".join(OhPhooey(f))
            self.assertEqual(got, u"a\n - b\n - fooled you! - c\n")
        finally:
            f.close()
            self._remove_testfile()

    # Test iterators with 'x in y' and 'x not in y'.
    def test_in_and_not_in(self):
        for sc5 in IteratingSequenceClass(5), SequenceClass(5):
            for i in range(5):
                self.assert_(i in sc5)
            for i in "abc", -1, 5, 42.42, (3, 4), [], {1: 1}, 3-12j, sc5:
                self.assert_(i not in sc5)

        self.assertRaises(TypeError, lambda: 3 in 12)
        self.assertRaises(TypeError, lambda: 3 not in map)

        d = {"one": 1, "two": 2, "three": 3, 1j: 2j}
        for k in d:
            self.assert_(k in d)
            self.assert_(k not in d.itervalues())
        for v in d.values():
            self.assert_(v in d.itervalues())
            self.assert_(v not in d)
        for k, v in d.iteritems():
            self.assert_((k, v) in d.iteritems())
            self.assert_((v, k) not in d.iteritems())

        # Membership tests advance the file iterator, so rewind the file
        # before every check.
        self._write_testfile("a\n" "b\n" "c\n")
        f = open(TESTFN, "r")
        try:
            for chunk in "abc":
                f.seek(0, 0)
                self.assert_(chunk not in f)
                f.seek(0, 0)
                self.assert_((chunk + "\n") in f)
        finally:
            f.close()
            self._remove_testfile()

    # Test iterators with operator.countOf (PySequence_Count).
    def test_countOf(self):
        from operator import countOf
        self.assertEqual(countOf([1,2,2,3,2,5], 2), 3)
        self.assertEqual(countOf((1,2,2,3,2,5), 2), 3)
        self.assertEqual(countOf("122325", "2"), 3)
        self.assertEqual(countOf("122325", "6"), 0)

        self.assertRaises(TypeError, countOf, 42, 1)
        self.assertRaises(TypeError, countOf, countOf, countOf)

        d = {"one": 3, "two": 3, "three": 3, 1j: 2j}
        for k in d:
            self.assertEqual(countOf(d, k), 1)
        self.assertEqual(countOf(d.itervalues(), 3), 3)
        self.assertEqual(countOf(d.itervalues(), 2j), 1)
        self.assertEqual(countOf(d.itervalues(), 1j), 0)

        self._write_testfile("a\n" "b\n" "c\n" "b\n")
        f = open(TESTFN, "r")
        try:
            for letter, count in ("a", 1), ("b", 2), ("c", 1), ("d", 0):
                f.seek(0, 0)
                self.assertEqual(countOf(f, letter + "\n"), count)
        finally:
            f.close()
            self._remove_testfile()
596
# Run the whole TestCase suite as a side effect of executing this module.
# NOTE(review): there is no `if __name__ == "__main__":` guard, presumably
# because the regression-test driver runs this file by importing it —
# confirm before adding one, or the suite may silently stop running.
run_unittest(TestCase)