blob: b0d442e0f901f5475b8ebe179b0bc22908dfde18 [file] [log] [blame]
Guido van Rossum7dab2422002-04-26 19:40:56 +00001from __future__ import generators
2import unittest
3
4import test_support
5
# Shared fixture: the input sequence and the (index, item) pairs that the
# tests below expect to obtain when enumerating it.
seq, res = 'abc', [(0,'a'), (1,'b'), (2,'c')]
7
class G:
    """Sequence exposing its items through __getitem__ only."""

    def __init__(self, seqn):
        # Underlying sequence that every index lookup is delegated to.
        self.seqn = seqn

    def __getitem__(self, i):
        # Any error raised by the wrapped sequence (for example an
        # out-of-range index) propagates unchanged to the caller.
        wrapped = self.seqn
        return wrapped[i]
14
class I:
    """Sequence implementing the iterator protocol (__iter__ plus next)."""

    def __init__(self, seqn):
        self.seqn = seqn
        # Position of the item that the following next() call hands out.
        self.i = 0

    def __iter__(self):
        # The object is its own iterator.
        return self

    def next(self):
        index = self.i
        if index >= len(self.seqn):
            raise StopIteration
        # Fetch the item before advancing, so a failing lookup leaves the
        # position untouched.
        item = self.seqn[index]
        self.i = index + 1
        return item
27
class Ig:
    """Sequence whose __iter__ is written as a generator."""

    def __init__(self, seqn):
        self.seqn = seqn
        # Not read by __iter__; kept so the instance has the same attributes
        # as the other helper classes.
        self.i = 0

    def __iter__(self):
        # Each call creates a new generator over the wrapped sequence.
        for item in self.seqn:
            yield item
36
class X:
    """Object with a next() method but neither __getitem__ nor __iter__."""

    def __init__(self, seqn):
        self.seqn = seqn
        # Position of the item that the following next() call hands out.
        self.i = 0

    def next(self):
        cursor = self.i
        if cursor >= len(self.seqn):
            raise StopIteration
        # Look the item up first so the position only moves on success.
        value = self.seqn[cursor]
        self.i = cursor + 1
        return value
47
class E:
    """Iterator whose next() always fails, for exception-propagation tests."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        # The object is its own iterator.
        return self

    def next(self):
        # Deliberate division by zero: every call raises ZeroDivisionError.
        3 / 0
57
class N:
    """Ill-formed iterator: __iter__ returns self but next() is not defined."""

    def __init__(self, seqn):
        self.seqn = seqn
        self.i = 0

    def __iter__(self):
        # Returns an object that lacks next(), which is what makes this
        # iterable ill-formed.
        return self
65
class EnumerateTestCase(unittest.TestCase):

    # Callable under test. It is looked up through self.enum everywhere so
    # that a subclass can rebind it and reuse every test unchanged.
    enum = enumerate

    def _check_wrapped_source(self, wrapper):
        # Shared body of the tests that feed a wrapped sequence to self.enum:
        # the wrapped fixture must enumerate to the expected pairs, and a
        # wrapped empty sequence must be exhausted on the first next() call.
        pairs = list(self.enum(wrapper(seq)))
        self.assertEqual(pairs, res)
        exhausted = self.enum(wrapper(''))
        self.assertRaises(StopIteration, exhausted.next)

    def test_basicfunction(self):
        # Calling the type under test yields an instance of that same type.
        produced_type = type(self.enum(seq))
        self.assertEqual(produced_type, self.enum)
        # The resulting object is its own iterator.
        enumerator = self.enum(seq)
        self.assertEqual(iter(enumerator), enumerator)
        pairs = list(self.enum(seq))
        self.assertEqual(pairs, res)
        # No assertion here: only the attribute lookup has to succeed.
        self.enum.__doc__

    def test_getitemseqn(self):
        # Source that is iterable through __getitem__ only.
        self._check_wrapped_source(G)

    def test_iteratorseqn(self):
        # Source implementing the iterator protocol directly.
        self._check_wrapped_source(I)

    def test_iteratorgenerator(self):
        # Source whose __iter__ is a generator.
        self._check_wrapped_source(Ig)

    def test_noniterable(self):
        # An object without __iter__ and __getitem__ is rejected when the
        # enumerator is constructed.
        not_iterable = X(seq)
        self.assertRaises(TypeError, self.enum, not_iterable)

    def test_illformediterable(self):
        # Construction happens outside assertRaises; the TypeError is
        # expected only once the enumerator is consumed.
        ill_formed = self.enum(N(seq))
        self.assertRaises(TypeError, list, ill_formed)

    def test_exception_propagation(self):
        # An error raised by the underlying next() must surface unchanged.
        failing = self.enum(E(seq))
        self.assertRaises(ZeroDivisionError, list, failing)
100
class MyEnum(enumerate):
    """Trivial enumerate subclass that adds and overrides nothing."""
103
class SubclassTestCase(EnumerateTestCase):

    # Re-run every inherited test with the MyEnum subclass as the callable
    # under test instead of the built-in enumerate.
    enum = MyEnum
107
def suite():
    """Build the test suite for this module.

    Returns:
        A unittest.TestSuite holding the tests of EnumerateTestCase followed
        by those of SubclassTestCase.
    """
    # unittest.makeSuite() is deprecated and absent from newer unittest
    # releases; asking the default loader directly collects the same tests
    # (methods whose names start with "test") and works on old and new
    # versions alike.
    load_case = unittest.defaultTestLoader.loadTestsFromTestCase
    tests = unittest.TestSuite()
    for case in (EnumerateTestCase, SubclassTestCase):
        tests.addTest(load_case(case))
    return tests
113
def test_main():
    """Build this module's suite and run it through test_support."""
    all_tests = suite()
    test_support.run_suite(all_tests)
116
# Run the tests when this file is executed directly as a script.
if __name__ == "__main__":
    test_main()