blob: 6350717be71f7ea47b6b6284013158cd8143f926 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00003
4import os
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00005import tempfile
6import unittest
7import threading
8from contextlib import * # Tests __all__
Collin Winterc2898c52007-04-25 17:29:52 +00009from test import test_support
Raymond Hettinger822b87f2009-05-29 01:46:48 +000010import warnings
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000011
class ContextManagerTestCase(unittest.TestCase):
    """Tests for the @contextmanager decorator."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # 'as', and code after the yield runs when the body exits normally.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A finally clause around the yield runs even when the body raises,
        # and the exception still propagates out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception: it reports the
        # exception as unhandled by returning a false value instead.
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after the exception has been
        # thrown into it is misbehaving; __exit__ signals RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # A generator that catches the exception suppresses it: nothing
        # propagates out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            # 'as' form (valid since Python 2.6) instead of the
            # Python-2-only 'except ZeroDivisionError, e:' syntax.
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_attribs(self):
        # The decorated factory keeps the wrapped function's name, docstring
        # and any custom attributes set on it.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')
        self.assertEqual(baz.__doc__, "Whee!")
101
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested(), which enters several managers at once."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's __enter__ value lands, in order, in the 'as' tuple.
        @contextmanager
        def one():
            yield 1
        @contextmanager
        def two():
            yield 2
        @contextmanager
        def three():
            yield 3
        with nested(one(), two(), three()) as (first, second, third):
            self.assertEqual(first, 1)
            self.assertEqual(second, 2)
            self.assertEqual(third, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, and
        # both exits run even though the body raises.
        log = []
        @contextmanager
        def outer():
            log.append(1)
            try:
                yield 2
            finally:
                log.append(3)
        @contextmanager
        def inner():
            log.append(4)
            try:
                yield 5
            finally:
                log.append(6)
        try:
            with nested(outer(), inner()) as (x, y):
                log.extend((x, y))
                1/0
        except ZeroDivisionError:
            self.assertEqual(log, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_right_exception(self):
        # An exception raised and handled inside an __exit__ must not take
        # the place of the ZeroDivisionError coming out of the body.
        @contextmanager
        def outer():
            yield 1
        class Inner(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                # Deliberately raise and swallow an unrelated exception.
                try:
                    raise Exception()
                except:
                    pass
        try:
            with nested(outer(), Inner()) as (x, y):
                1/0
        except ZeroDivisionError:
            self.assertEqual((x, y), (1, 2))
        except Exception:
            self.fail("Reraised wrong exception")
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # When the inner manager suppresses the exception, nothing escapes
        # the with statement.
        @contextmanager
        def outer():
            yield
        @contextmanager
        def swallower():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(outer(), swallower()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # 'break' inside the body leaves the enclosing loop immediately.
        @contextmanager
        def noop():
            yield
        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # 'continue' inside the body skips the rest of the loop iteration.
        @contextmanager
        def noop():
            yield
        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # 'return' inside the body returns from the enclosing function.
        @contextmanager
        def tolerant():
            try:
                yield
            except:
                pass
        def run():
            with nested(tolerant(), tolerant()):
                return 1
            return 10
        self.assertEqual(run(), 1)
222
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # closing() hands back the wrapped object and calls its close()
        # once the block finishes normally.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        resource = Closeable()
        self.assertEqual(calls, [])
        with closing(resource) as bound:
            self.assertEqual(resource, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() still runs when the block raises, and the exception is
        # not suppressed.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        resource = Closeable()
        self.assertEqual(calls, [])
        try:
            with closing(resource) as bound:
                self.assertEqual(resource, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
253
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects returned by open() work with 'with'."""

    def testWithOpen(self):
        # mkstemp() atomically creates the file under a unique name.  The
        # deprecated mktemp() only returns a name, leaving a window in which
        # something else could create that path first.
        fd, tfn = tempfile.mkstemp()
        try:
            # Only the path is needed; the file is reopened by name below.
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            # Leaving the with block closes the file.
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # The file is closed even though the block raised.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            try:
                os.remove(tfn)
            except os.error:
                pass
279
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives work with 'with'."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside a with block and released after.

        *locked* is a zero-argument callable reporting whether *lock* is
        currently held.  Release is checked both after a normal exit and
        after an exit caused by an exception.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphore_locked(sem):
        """Return a zero-argument 'is held' predicate for semaphore *sem*.

        The predicate probes with a non-blocking acquire: if that succeeds,
        the slot is released again immediately and False is returned;
        otherwise True is returned.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # _is_owned is a private method, used here as the 'is held' check.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        # The bound predicate can be passed directly; no wrapper is needed.
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))
329
# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Run every TestCase in this module through test_support.

    Only DeprecationWarning is silenced (presumably for the deprecated
    contextlib.nested() exercised by NestedTestCase -- confirm), so that
    other warnings raised while the tests run remain visible instead of
    being hidden by a blanket 'ignore' filter.
    """
    with warnings.catch_warnings():
        warnings.simplefilter('ignore', DeprecationWarning)
        test_support.run_unittest(__name__)

if __name__ == "__main__":
    test_main()