blob: 80ba3e807a7a7ce49e0143360c5893d28549a414 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00003
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +00004import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00005import os
6import decimal
7import tempfile
8import unittest
9import threading
10from contextlib import * # Tests __all__
Collin Winterc2898c52007-04-25 17:29:52 +000011from test import test_support
Raymond Hettinger822b87f2009-05-29 01:46:48 +000012import warnings
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000013
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # 'as', and code after the yield runs when the block exits normally.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A 'finally' clause around the yield still runs when the block
        # raises, and the exception propagates out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # If the generator lets the thrown exception escape unchanged,
        # __exit__ returns a false value instead of raising it itself.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields again after an exception was thrown into it
        # is misbehaving; __exit__ must report that as a RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                # Deliberately broad: catch whatever __exit__ throws in and
                # then (incorrectly) yield a second time.
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # An exception handled inside the generator is suppressed: the with
        # statement completes normally and the handler's effects are kept.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            # The "except E as e" form is valid from Python 2.6 on and is the
            # only form Python 3 accepts ("except E, e" is a SyntaxError).
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_attribs(self):
        # @contextmanager must preserve the wrapped function's __name__,
        # __doc__ and any custom attributes set on it.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')
        self.assertEqual(baz.__doc__, "Whee!")
103
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested()."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's __enter__ value lands in the matching position of
        # the tuple bound by 'as'.
        @contextmanager
        def a():
            yield 1
        @contextmanager
        def b():
            yield 2
        @contextmanager
        def c():
            yield 3
        with nested(a(), b(), c()) as (x, y, z):
            self.assertEqual(x, 1)
            self.assertEqual(y, 2)
            self.assertEqual(z, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, and
        # every exit step runs even though the body raises.
        state = []
        @contextmanager
        def a():
            state.append(1)
            try:
                yield 2
            finally:
                state.append(3)
        @contextmanager
        def b():
            state.append(4)
            try:
                yield 5
            finally:
                state.append(6)
        try:
            with nested(a(), b()) as (x, y):
                state.append(x)
                state.append(y)
                1/0
        except ZeroDivisionError:
            self.assertEqual(state, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_right_exception(self):
        # b.__exit__ raises and swallows an unrelated exception of its own;
        # nested() must still re-raise the body's ZeroDivisionError rather
        # than that inner exception.
        @contextmanager
        def a():
            yield 1
        class b(object):
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except Exception:
                    pass
        try:
            with nested(a(), b()) as (x, y):
                1/0
        except ZeroDivisionError:
            self.assertEqual((x, y), (1, 2))
        except Exception:
            self.fail("Reraised wrong exception")
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # When the inner manager suppresses the exception, the outer one
        # must not resurrect it.
        @contextmanager
        def a():
            yield
        @contextmanager
        def b():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(a(), b()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # 'break' inside the nested block leaves the loop immediately.
        @contextmanager
        def a():
            yield
        state = 0
        while True:
            state += 1
            with nested(a(), a()):
                break
            state += 10
        self.assertEqual(state, 1)

    def test_nested_continue(self):
        # 'continue' inside the nested block skips the rest of the loop body.
        @contextmanager
        def a():
            yield
        state = 0
        while state < 3:
            state += 1
            with nested(a(), a()):
                continue
            state += 10
        self.assertEqual(state, 3)

    def test_nested_return(self):
        # 'return' inside the nested block returns that value, not the one
        # after the with statement.
        @contextmanager
        def a():
            try:
                yield
            except:
                pass
        def foo():
            with nested(a(), a()):
                return 1
            return 10
        self.assertEqual(foo(), 1)
225
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # closing() hands back the wrapped object and calls its close()
        # exactly once when the block finishes normally.
        calls = []
        class Resource:
            def close(self):
                calls.append(1)
        resource = Resource()
        self.assertEqual(calls, [])
        with closing(resource) as managed:
            self.assertEqual(resource, managed)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() must still be called when the block exits via an exception.
        calls = []
        class Resource:
            def close(self):
                calls.append(1)
        resource = Resource()
        self.assertEqual(calls, [])
        raised = False
        try:
            with closing(resource) as managed:
                self.assertEqual(resource, managed)
                1/0
        except ZeroDivisionError:
            raised = True
        if not raised:
            self.fail("Didn't raise ZeroDivisionError")
        self.assertEqual(calls, [1])
256
class FileContextTestCase(unittest.TestCase):
    """Tests that built-in file objects behave as context managers."""

    def testWithOpen(self):
        # A file opened in a with statement is closed when the block exits,
        # both normally and when the block raises.
        #
        # mkstemp() creates the file atomically under a fresh name, avoiding
        # the name-reuse race that makes tempfile.mktemp() unsafe.  Only the
        # path is needed, so the low-level descriptor is closed right away.
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Best-effort cleanup: failing to delete the temp file must not
            # mask the test's own result.
            try:
                os.remove(tfn)
            except os.error:
                pass
282
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives work with 'with'."""

    def boilerPlate(self, lock, locked):
        """Check that ``with lock:`` holds *lock* only inside the block.

        *locked* is a zero-argument callable that reports whether *lock* is
        currently held.  Release is verified both for a normal exit and for
        an exit caused by a ZeroDivisionError raised in the block.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphore_locked(sem):
        """Return a probe that reports whether *sem* has no free slot.

        The probe tries a non-blocking acquire.  If that succeeds, it
        releases the slot again right away, so probing leaves the count
        unchanged.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # Relies on the private _is_owned() helper as the "held" probe.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        # Same private probe as the RLock test, exposed on the Condition.
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))
332
# Entry point for regrtest.py; without it these tests would not be run there.
def test_main():
    """Run this module's tests through test_support.run_unittest.

    warnings.catch_warnings() saves the warning filters and restores them on
    exit, so the blanket 'ignore' filter applies only while the tests run.
    """
    with warnings.catch_warnings():
        warnings.simplefilter(action='ignore')
        test_support.run_unittest(__name__)


if __name__ == "__main__":
    test_main()