blob: 4d233dad34a05e0d52579827e27b7a421cc2b2bc [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00003
4import os
R. David Murrayf28fd242010-02-23 00:24:49 +00005import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006import tempfile
7import unittest
8import threading
9from contextlib import * # Tests __all__
Collin Winterc2898c52007-04-25 17:29:52 +000010from test import test_support
Raymond Hettinger822b87f2009-05-29 01:46:48 +000011import warnings
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000012
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with ``contextmanager``."""

    def test_contextmanager_plain(self):
        """Code before/after the ``yield`` runs around the ``with`` body."""
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        """A ``finally`` clause in the generator runs when the body raises."""
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        """``__exit__`` returns a false value when the generator does not
        handle the exception passed to it."""
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        """A generator that yields again after an exception was thrown into
        it makes ``__exit__`` raise RuntimeError."""
        @contextmanager
        def whoo():
            try:
                yield
            except:
                # Deliberately broad: the point is to yield a second time
                # no matter what was thrown in.
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        """An exception handled inside the generator is suppressed."""
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            # "as" form instead of the Python-2-only comma form.
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        # Reached only because the generator swallowed the exception.
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        """Return ``baz``: a documented function given a ``foo`` attribute
        by a local decorator and then wrapped with ``contextmanager``."""
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        """The wrapper keeps the wrapped function's name and attributes."""
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        """The wrapper keeps the wrapped function's docstring."""
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
111
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000112class NestedTestCase(unittest.TestCase):
113
114 # XXX This needs more work
115
116 def test_nested(self):
Nick Coghlanafd5e632006-05-03 13:02:47 +0000117 @contextmanager
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000118 def a():
119 yield 1
Nick Coghlanafd5e632006-05-03 13:02:47 +0000120 @contextmanager
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000121 def b():
122 yield 2
Nick Coghlanafd5e632006-05-03 13:02:47 +0000123 @contextmanager
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000124 def c():
125 yield 3
126 with nested(a(), b(), c()) as (x, y, z):
127 self.assertEqual(x, 1)
128 self.assertEqual(y, 2)
129 self.assertEqual(z, 3)
130
131 def test_nested_cleanup(self):
132 state = []
Nick Coghlanafd5e632006-05-03 13:02:47 +0000133 @contextmanager
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000134 def a():
135 state.append(1)
136 try:
137 yield 2
138 finally:
139 state.append(3)
Nick Coghlanafd5e632006-05-03 13:02:47 +0000140 @contextmanager
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000141 def b():
142 state.append(4)
143 try:
144 yield 5
145 finally:
146 state.append(6)
147 try:
148 with nested(a(), b()) as (x, y):
149 state.append(x)
150 state.append(y)
151 1/0
152 except ZeroDivisionError:
153 self.assertEqual(state, [1, 4, 2, 5, 6, 3])
154 else:
155 self.fail("Didn't raise ZeroDivisionError")
156
Nick Coghlanda2268f2006-04-24 04:37:15 +0000157 def test_nested_right_exception(self):
Nick Coghlanafd5e632006-05-03 13:02:47 +0000158 @contextmanager
Nick Coghlanda2268f2006-04-24 04:37:15 +0000159 def a():
160 yield 1
161 class b(object):
162 def __enter__(self):
163 return 2
164 def __exit__(self, *exc_info):
165 try:
166 raise Exception()
167 except:
168 pass
169 try:
170 with nested(a(), b()) as (x, y):
171 1/0
172 except ZeroDivisionError:
173 self.assertEqual((x, y), (1, 2))
174 except Exception:
175 self.fail("Reraised wrong exception")
176 else:
177 self.fail("Didn't raise ZeroDivisionError")
178
Guido van Rossuma9f06872006-03-01 17:10:01 +0000179 def test_nested_b_swallows(self):
Nick Coghlanafd5e632006-05-03 13:02:47 +0000180 @contextmanager
Guido van Rossuma9f06872006-03-01 17:10:01 +0000181 def a():
182 yield
Nick Coghlanafd5e632006-05-03 13:02:47 +0000183 @contextmanager
Guido van Rossuma9f06872006-03-01 17:10:01 +0000184 def b():
185 try:
186 yield
187 except:
188 # Swallow the exception
189 pass
190 try:
191 with nested(a(), b()):
192 1/0
193 except ZeroDivisionError:
194 self.fail("Didn't swallow ZeroDivisionError")
195
196 def test_nested_break(self):
Nick Coghlanafd5e632006-05-03 13:02:47 +0000197 @contextmanager
Guido van Rossuma9f06872006-03-01 17:10:01 +0000198 def a():
199 yield
200 state = 0
201 while True:
202 state += 1
203 with nested(a(), a()):
204 break
205 state += 10
206 self.assertEqual(state, 1)
207
208 def test_nested_continue(self):
Nick Coghlanafd5e632006-05-03 13:02:47 +0000209 @contextmanager
Guido van Rossuma9f06872006-03-01 17:10:01 +0000210 def a():
211 yield
212 state = 0
213 while state < 3:
214 state += 1
215 with nested(a(), a()):
216 continue
217 state += 10
218 self.assertEqual(state, 3)
219
220 def test_nested_return(self):
Nick Coghlanafd5e632006-05-03 13:02:47 +0000221 @contextmanager
Guido van Rossuma9f06872006-03-01 17:10:01 +0000222 def a():
223 try:
224 yield
225 except:
226 pass
227 def foo():
228 with nested(a(), a()):
229 return 1
230 return 10
231 self.assertEqual(foo(), 1)
232
class ClosingTestCase(unittest.TestCase):
    """Tests for ``closing``, which calls ``close()`` when the block ends."""

    # XXX This needs more work

    def test_closing(self):
        """``close()`` is called once the ``with`` block finishes."""
        calls = []
        class C:
            def close(self):
                calls.append(1)
        target = C()
        self.assertEqual(calls, [])
        with closing(target) as bound:
            # closing() hands back the object it wraps.
            self.assertEqual(target, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        """``close()`` is still called when the ``with`` body raises."""
        calls = []
        class C:
            def close(self):
                calls.append(1)
        target = C()
        self.assertEqual(calls, [])
        try:
            with closing(target) as bound:
                self.assertEqual(target, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
263
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects work as context managers."""

    def testWithOpen(self):
        """A file is closed when its ``with`` block exits, whether the body
        completes normally or raises."""
        # mkstemp() creates the file itself, avoiding the race inherent in
        # the deprecated mktemp().  Only the path is needed here, so the
        # descriptor is closed straight away.
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            # Reset so a failing open() below cannot leave the stale,
            # already-closed handle to satisfy the assertion.
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Best-effort cleanup: a failure to delete the temporary file
            # must not mask the test outcome.
            try:
                os.remove(tfn)
            except OSError:
                pass
289
class LockContextTestCase(unittest.TestCase):
    """Tests that ``threading`` synchronization objects work in ``with``."""

    def boilerPlate(self, lock, locked):
        """Run *lock* through a normal and a failing ``with`` block.

        *locked* is a zero-argument callable whose result says whether
        *lock* is currently held; it must be false outside the blocks and
        true inside them.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            # The lock must have been released despite the exception.
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def _semaphoreProbe(self, sem):
        """Return a callable that is true while *sem* cannot be acquired."""
        def held():
            # Non-blocking attempt; give the slot back if it succeeded.
            acquired = sem.acquire(False)
            if acquired:
                sem.release()
            return not acquired
        return held

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, self._semaphoreProbe(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, self._semaphoreProbe(sem))
339
# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Run every test case in this module with all warnings silenced."""
    with warnings.catch_warnings():
        # The filter change is undone when the block exits.
        warnings.simplefilter(action="ignore")
        test_support.run_unittest(__name__)
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +0000345
# Allow the module to be executed directly as a script.
if __name__ == "__main__":
    test_main()