blob: c05f37b0ba4e3cb7dc1961d15638cf04d04ab2b5 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00003
Thomas Wouters49fd7fa2006-04-21 10:40:58 +00004import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00005import os
6import decimal
7import tempfile
8import unittest
9import threading
10from contextlib import * # Tests __all__
Benjamin Petersonee8712c2008-05-20 21:35:26 +000011from test import support
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000012
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, code after it runs on exit,
        # and the yielded value is bound by ``as``.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # The generator's finally clause must run even when the body raises,
        # and the body's exception must still propagate to the caller.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception; a generator
        # that lets the thrown exception escape reports a false value.
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        @contextmanager
        def whoo():
            try:
                yield
            except TypeError:
                # Yielding again after the thrown exception violates the
                # protocol; __exit__ is expected to raise RuntimeError.
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # An exception raised in the body is delivered into the generator,
        # which may catch it and thereby suppress it.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_attribs(self):
        # The decorated function keeps its name, docstring and any custom
        # attributes set by inner decorators.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')
        self.assertEqual(baz.__doc__, "Whee!")
102
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested()."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's yielded value lands in the matching tuple slot.
        @contextmanager
        def first():
            yield 1
        @contextmanager
        def second():
            yield 2
        @contextmanager
        def third():
            yield 3
        with nested(first(), second(), third()) as (one, two, three):
            self.assertEqual(one, 1)
            self.assertEqual(two, 2)
            self.assertEqual(three, 3)

    def test_nested_cleanup(self):
        # Managers are entered left-to-right and exited right-to-left, even
        # when the body raises.
        log = []
        def recorder(on_enter, value, on_exit):
            @contextmanager
            def manager():
                log.append(on_enter)
                try:
                    yield value
                finally:
                    log.append(on_exit)
            return manager()
        try:
            with nested(recorder(1, 2, 3), recorder(4, 5, 6)) as (x, y):
                log.extend((x, y))
                1/0
        except ZeroDivisionError:
            self.assertEqual(log, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_right_exception(self):
        # An exception raised and handled inside an __exit__ must not replace
        # the body's original exception.
        @contextmanager
        def outer():
            yield 1
        class Inner(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except Exception:
                    pass
        try:
            with nested(outer(), Inner()) as (x, y):
                1/0
        except ZeroDivisionError:
            self.assertEqual((x, y), (1, 2))
        except Exception:
            self.fail("Reraised wrong exception")
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # If an inner manager suppresses the exception, the outer ones must
        # not see it either.
        @contextmanager
        def passthrough():
            yield
        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(passthrough(), swallower()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # ``break`` inside the body leaves the loop immediately.
        @contextmanager
        def noop():
            yield
        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # ``continue`` inside the body skips the rest of the loop iteration.
        @contextmanager
        def noop():
            yield
        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # ``return`` inside the body returns from the enclosing function.
        @contextmanager
        def guarded():
            try:
                yield
            except:
                pass
        def early_exit():
            with nested(guarded(), guarded()):
                return 1
            return 10
        self.assertEqual(early_exit(), 1)
224
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    @staticmethod
    def _make_closeable(log):
        """Return an object whose close() method appends 1 to *log*."""
        class Closeable:
            def close(self):
                log.append(1)
        return Closeable()

    def test_closing(self):
        calls = []
        resource = self._make_closeable(calls)
        self.assertEqual(calls, [])
        with closing(resource) as managed:
            # closing() should yield the wrapped object itself.
            self.assertEqual(resource, managed)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        calls = []
        resource = self._make_closeable(calls)
        self.assertEqual(calls, [])
        try:
            with closing(resource) as managed:
                self.assertEqual(resource, managed)
                1/0
        except ZeroDivisionError:
            # close() must have run even though the body raised.
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
255
class FileContextTestCase(unittest.TestCase):
    """Check that file objects returned by open() work with ``with``."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically, avoiding the name race that
        # makes mktemp() unsafe; only the path is needed, so close the fd.
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # The file must be closed even though the body raised.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Best-effort cleanup of the temporary file.
            try:
                os.remove(tfn)
            except OSError:
                pass
281
class LockContextTestCase(unittest.TestCase):
    """Check that threading's lock-like objects work with ``with``."""

    def boilerPlate(self, lock, locked):
        """Use *lock* as a context manager, both normally and with an error.

        *locked* is a zero-argument callable that reports whether *lock* is
        currently held.  The lock must be held inside the body and released
        both when the body exits normally and when it raises.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphore_locked(sem):
        """Return a ``locked()`` probe for a semaphore with one slot.

        Semaphores offer no "is held" query, so the probe tries a
        non-blocking acquire and immediately releases on success.  It
        assumes the semaphore was created with the default value of 1.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # NOTE: _is_owned() is a private helper of threading's RLock.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))
331
# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Entry point for regrtest: run every test case in this module."""
    this_module = __name__
    support.run_unittest(this_module)

if __name__ == "__main__":
    # Allow running the module directly as a script.
    test_main()