blob: c8a45916332cd43d0c3e2c0694a2596f39bbc7cb [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +00003import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00004import os
5import decimal
6import tempfile
7import unittest
8import threading
9from contextlib import * # Tests __all__
Collin Winterc2898c52007-04-25 17:29:52 +000010from test import test_support
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000011
class ContextManagerTestCase(unittest.TestCase):
    """Tests for the contextlib.contextmanager decorator."""

    def test_contextmanager_plain(self):
        # Setup before the yield runs on __enter__, the yielded value is
        # bound by "as", and the code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A "finally" around the yield runs even when the with-body raises,
        # and the exception still propagates out of the with-statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # When the generator lets the thrown exception escape, __exit__ must
        # report "not handled" (a false value) rather than raise it itself.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after an exception is thrown into it
        # is misbehaving; __exit__ must turn that into a RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except TypeError:
                # The only exception thrown in below is TypeError("foo").
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # A generator that catches the body's exception suppresses it: the
        # with-statement completes without re-raising.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_attribs(self):
        # The decorated factory keeps the wrapped function's name, docstring
        # and extra function attributes.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')
        self.assertEqual(baz.__doc__, "Whee!")
101
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested, which enters several managers at once."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's yielded value lands in the matching "as" slot.
        @contextmanager
        def one():
            yield 1
        @contextmanager
        def two():
            yield 2
        @contextmanager
        def three():
            yield 3
        with nested(one(), two(), three()) as (x, y, z):
            for got, expected in zip((x, y, z), (1, 2, 3)):
                self.assertEqual(got, expected)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, and
        # every exit runs even though the body raises.
        log = []
        @contextmanager
        def outer():
            log.append(1)
            try:
                yield 2
            finally:
                log.append(3)
        @contextmanager
        def inner():
            log.append(4)
            try:
                yield 5
            finally:
                log.append(6)
        try:
            with nested(outer(), inner()) as (x, y):
                log.extend((x, y))
                1 // 0
        except ZeroDivisionError:
            self.assertEqual(log, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_right_exception(self):
        # An exception raised and handled inside an inner __exit__ must not
        # replace the exception propagating out of the with-body.
        @contextmanager
        def first():
            yield 1
        class Second(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass
        try:
            with nested(first(), Second()) as (x, y):
                1 // 0
        except ZeroDivisionError:
            self.assertEqual((x, y), (1, 2))
        except Exception:
            self.fail("Reraised wrong exception")
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # If an inner manager suppresses the exception, it must not be
        # re-raised through the outer manager.
        @contextmanager
        def passthrough():
            yield
        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(passthrough(), swallower()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # "break" inside the nested block leaves the enclosing loop.
        @contextmanager
        def noop():
            yield
        iterations = 0
        while True:
            iterations += 1
            with nested(noop(), noop()):
                break
            iterations += 10
        self.assertEqual(iterations, 1)

    def test_nested_continue(self):
        # "continue" inside the nested block skips the rest of the loop body.
        @contextmanager
        def noop():
            yield
        iterations = 0
        while iterations < 3:
            iterations += 1
            with nested(noop(), noop()):
                continue
            iterations += 10
        self.assertEqual(iterations, 3)

    def test_nested_return(self):
        # "return" inside the nested block returns from the enclosing function.
        @contextmanager
        def tolerant():
            try:
                yield
            except:
                pass
        def early_return():
            with nested(tolerant(), tolerant()):
                return 1
            return 10
        self.assertEqual(early_return(), 1)
223
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    # XXX This needs more work

    @staticmethod
    def _closeable(log):
        """Return a fresh object whose close() appends 1 to *log*."""
        class C:
            def close(self):
                log.append(1)
        return C()

    def test_closing(self):
        # close() runs on leaving the block, and "as" binds the object itself.
        calls = []
        thing = self._closeable(calls)
        self.assertEqual(calls, [])
        with closing(thing) as bound:
            self.assertEqual(thing, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() still runs when the with-body raises.
        calls = []
        thing = self._closeable(calls)
        self.assertEqual(calls, [])
        try:
            with closing(thing) as bound:
                self.assertEqual(thing, bound)
                1 // 0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
254
class FileContextTestCase(unittest.TestCase):
    """Tests that built-in file objects work as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically under a unique name, avoiding
        # the name-then-create race of the deprecated mktemp().  The OS-level
        # handle is closed immediately so the path can be reopened and
        # removed on every platform.
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            except ZeroDivisionError:
                # The file must be closed even though the body raised.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            try:
                os.remove(tfn)
            except OSError:
                pass
280
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives support "with"."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held only while inside a with-block.

        *locked* is a zero-argument callable reporting whether *lock* is
        currently held.  The lock must be released both on a normal exit
        and when the body raises.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1 // 0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphore_probe(sem):
        """Return a callable reporting whether *sem* is currently held.

        Semaphores have no "locked" query, so the probe attempts a
        non-blocking acquire.  Success means a slot was free, so it is
        released again at once.  This assumes *sem* was created with the
        default value of 1, so that holding it leaves no free slot.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # _is_owned is a private helper; there is no public ownership query.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))
330
def test_main():
    """Run every TestCase defined in this module.

    This hook is needed for the tests to actually run under regrtest.py.
    """
    module_name = __name__
    test_support.run_unittest(module_name)


if __name__ == "__main__":
    # Running the file directly executes the same suite.
    test_main()