blob: a3e9b071b2ec7ce7f1d7ce7b1b4374c241354bab [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
R. David Murray378c0cf2010-02-24 01:46:21 +00003import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00004import tempfile
5import unittest
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006from contextlib import * # Tests __all__
Benjamin Petersonee8712c2008-05-20 21:35:26 +00007from test import support
Victor Stinner45df8202010-04-28 22:31:17 +00008try:
9 import threading
10except ImportError:
11 threading = None
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000012
Florent Xicluna41fe6152010-04-02 18:52:12 +000013
class ContextManagerTestCase(unittest.TestCase):
    """Behavioural tests for generators wrapped with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry; code after it on normal exit.
        log = []

        @contextmanager
        def woohoo():
            log.append(1)
            yield 42
            log.append(999)

        with woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
        self.assertEqual(log, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A finally clause around the yield still runs when the body raises,
        # and the exception keeps propagating to the caller.
        log = []

        @contextmanager
        def woohoo():
            log.append(1)
            try:
                yield 42
            finally:
                log.append(999)

        with self.assertRaises(ZeroDivisionError), woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
            raise ZeroDivisionError()
        self.assertEqual(log, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        @contextmanager
        def whee():
            yield

        cm = whee()
        cm.__enter__()
        # The generator does not handle the thrown exception; __exit__ should
        # report that with a false value instead of raising.
        handled = cm.__exit__(TypeError, TypeError("foo"), None)
        self.assertFalse(handled)

    def test_contextmanager_trap_yield_after_throw(self):
        @contextmanager
        def whoo():
            try:
                yield
            except BaseException:
                # Yielding a second time after an exception was thrown in is
                # a protocol violation that __exit__ must turn into an error.
                yield

        cm = whoo()
        cm.__enter__()
        with self.assertRaises(RuntimeError):
            cm.__exit__(TypeError, TypeError("foo"), None)

    def test_contextmanager_except(self):
        # An exception caught inside the generator is suppressed.
        log = []

        @contextmanager
        def woohoo():
            log.append(1)
            try:
                yield 42
            except ZeroDivisionError as err:
                log.append(err.args[0])
                self.assertEqual(log, [1, 42, 999])

        with woohoo() as value:
            self.assertEqual(log, [1])
            self.assertEqual(value, 42)
            log.append(value)
            raise ZeroDivisionError(999)
        self.assertEqual(log, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager function with a docstring and a ``foo`` attribute."""
        def attribs(**attrs):
            def decorate(func):
                for attr_name, attr_value in attrs.items():
                    setattr(func, attr_name, attr_value)
                return func
            return decorate

        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        # The wrapper must keep the wrapped function's name and attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
108
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    # XXX This needs more work (e.g. a close() that itself raises is untested).

    @staticmethod
    def _make_closeable(log):
        """Return a fresh object whose close() appends 1 to *log*."""
        class Closeable:
            def close(self):
                log.append(1)
        return Closeable()

    def test_closing(self):
        calls = []
        obj = self._make_closeable(calls)
        self.assertEqual(calls, [])
        with closing(obj) as bound:
            self.assertEqual(obj, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() must still be called when the body raises.
        calls = []
        obj = self._make_closeable(calls)
        self.assertEqual(calls, [])
        with self.assertRaises(ZeroDivisionError), closing(obj) as bound:
            self.assertEqual(obj, bound)
            1 / 0
        self.assertEqual(calls, [1])
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000136
class FileContextTestCase(unittest.TestCase):
    """Built-in file objects must be closed when a ``with`` block exits."""

    def testWithOpen(self):
        import os
        # tempfile.mktemp() is deprecated: it only returns a name, leaving a
        # window in which another process can create that path first.
        # mkstemp() creates the file atomically; its descriptor is closed
        # right away because the test reopens the path with open().
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            # The file must also be closed when the body raises.
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 / 0
            self.assertTrue(f.closed)
        finally:
            # mkstemp() guarantees the file exists, so a plain unlink suffices.
            os.unlink(tfn)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000156
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """threading synchronisation primitives used as ``with`` managers."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside ``with`` and released afterwards.

        *locked* is a zero-argument callable reporting whether *lock* is
        currently held. Release is checked both after a normal exit and
        after a ZeroDivisionError escapes the block.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError), lock:
            self.assertTrue(locked())
            1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _exhausted(sem):
        """Return a probe reporting whether *sem* cannot be acquired right now.

        The probe tries a non-blocking acquire and releases again on success,
        so checking leaves the semaphore as it found it.
        """
        def probe():
            if not sem.acquire(False):
                return True
            sem.release()
            return False
        return probe

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, cond._is_owned)

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, self._exhausted(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, self._exhausted(sem))
204
Michael Foordb3a89842010-06-30 12:17:50 +0000205
class mycontext(ContextDecorator):
    """Recording ContextDecorator used by the decorator tests.

    ``started`` becomes True on entry and ``exc`` receives the exception
    triple handed to ``__exit__``. The value of ``catch`` is returned from
    ``__exit__``, so setting it to True suppresses the exception.
    """

    # Class-level defaults; instances shadow them when they assign.
    catch = False
    exc = None
    started = False

    def __enter__(self):
        self.started = True
        return self

    def __exit__(self, *exc_info):
        self.exc = exc_info
        return self.catch
218
219
class TestContextDecorator(unittest.TestCase):
    """Tests for ContextDecorator and for @contextmanager used as a decorator.

    Uses the module-level ``mycontext`` helper, which records whether it was
    entered and the exception triple passed to its ``__exit__``.
    """

    def test_contextdecorator(self):
        context = mycontext()
        with context as result:
            self.assertIs(result, context)
            self.assertTrue(context.started)

        self.assertEqual(context.exc, (None, None, None))

    def test_contextdecorator_with_exception(self):
        context = mycontext()

        # assertRaisesRegex replaces the assertRaisesRegexp alias, which was
        # deprecated in Python 3.2 and removed in 3.12.
        with self.assertRaisesRegex(NameError, 'foo'):
            with context:
                raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

        # With catch=True, __exit__ returns True and the exception is
        # suppressed, so this with statement must not raise.
        context = mycontext()
        context.catch = True
        with context:
            raise NameError('foo')
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

    def test_decorator(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))

    def test_decorator_with_exception(self):
        context = mycontext()

        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
            raise NameError('foo')

        with self.assertRaisesRegex(NameError, 'foo'):
            test()
        self.assertIsNotNone(context.exc)
        self.assertIs(context.exc[0], NameError)

    def test_decorating_method(self):
        context = mycontext()

        class Test(object):

            @context
            def method(self, a, b, c=None):
                self.a = a
                self.b = b
                self.c = c

        # these tests are for argument passing when used as a decorator
        test = Test()
        test.method(1, 2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        self.assertIsNone(test.c)

        test = Test()
        test.method('a', 'b', 'c')
        self.assertEqual(test.a, 'a')
        self.assertEqual(test.b, 'b')
        self.assertEqual(test.c, 'c')

        test = Test()
        test.method(a=1, b=2)
        self.assertEqual(test.a, 1)
        self.assertEqual(test.b, 2)
        # The omitted argument must still receive its default.
        self.assertIsNone(test.c)

    def test_typo_enter(self):
        class mycontext(ContextDecorator):
            def __unter__(self):
                pass
            def __exit__(self, *exc):
                pass

        # Python 3.11 changed the with statement to raise TypeError instead
        # of AttributeError when __enter__/__exit__ is missing.
        with self.assertRaises((AttributeError, TypeError)):
            with mycontext():
                pass

    def test_typo_exit(self):
        class mycontext(ContextDecorator):
            def __enter__(self):
                pass
            def __uxit__(self, *exc):
                pass

        # See test_typo_enter: TypeError on 3.11+, AttributeError before.
        with self.assertRaises((AttributeError, TypeError)):
            with mycontext():
                pass

    def test_contextdecorator_as_mixin(self):
        class somecontext(object):
            started = False
            exc = None

            def __enter__(self):
                self.started = True
                return self

            def __exit__(self, *exc):
                self.exc = exc

        class mycontext(somecontext, ContextDecorator):
            pass

        context = mycontext()
        @context
        def test():
            self.assertIsNone(context.exc)
            self.assertTrue(context.started)
        test()
        self.assertEqual(context.exc, (None, None, None))

    def test_contextmanager_as_decorator(self):
        state = []
        @contextmanager
        def woohoo(y):
            state.append(y)
            yield
            state.append(999)

        @woohoo(1)
        def test(x):
            self.assertEqual(state, [1])
            state.append(x)
        test('something')
        self.assertEqual(state, [1, 'something', 999])
366
367
# regrtest.py runs this file through test_main(); without it the tests here
# would not actually be run under that harness.
def test_main():
    """Pass this module's name to support.run_unittest to run its tests."""
    this_module = __name__
    support.run_unittest(this_module)


if __name__ == "__main__":
    test_main()