blob: 53f23b27118801e7c57aba7c9da0d0ba1a044699 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
3from __future__ import with_statement
4
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +00005import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006import os
7import decimal
8import tempfile
9import unittest
10import threading
11from contextlib import * # Tests __all__
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +000012from test.test_support import run_suite
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000013
14class ContextManagerTestCase(unittest.TestCase):
15
Nick Coghlana7e820a2006-04-25 10:56:51 +000016 def test_contextfactory_plain(self):
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000017 state = []
Nick Coghlana7e820a2006-04-25 10:56:51 +000018 @contextfactory
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000019 def woohoo():
20 state.append(1)
21 yield 42
22 state.append(999)
23 with woohoo() as x:
24 self.assertEqual(state, [1])
25 self.assertEqual(x, 42)
26 state.append(x)
27 self.assertEqual(state, [1, 42, 999])
28
Nick Coghlana7e820a2006-04-25 10:56:51 +000029 def test_contextfactory_finally(self):
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000030 state = []
Nick Coghlana7e820a2006-04-25 10:56:51 +000031 @contextfactory
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000032 def woohoo():
33 state.append(1)
34 try:
35 yield 42
36 finally:
37 state.append(999)
38 try:
39 with woohoo() as x:
40 self.assertEqual(state, [1])
41 self.assertEqual(x, 42)
42 state.append(x)
43 raise ZeroDivisionError()
44 except ZeroDivisionError:
45 pass
46 else:
47 self.fail("Expected ZeroDivisionError")
48 self.assertEqual(state, [1, 42, 999])
49
Nick Coghlana7e820a2006-04-25 10:56:51 +000050 def test_contextfactory_no_reraise(self):
51 @contextfactory
Phillip J. Eby6edd2582006-03-25 00:28:24 +000052 def whee():
53 yield
54 ctx = whee().__context__()
55 ctx.__enter__()
56 # Calling __exit__ should not result in an exception
57 self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None))
58
Nick Coghlana7e820a2006-04-25 10:56:51 +000059 def test_contextfactory_trap_yield_after_throw(self):
60 @contextfactory
Phillip J. Eby6edd2582006-03-25 00:28:24 +000061 def whoo():
62 try:
63 yield
64 except:
65 yield
66 ctx = whoo().__context__()
67 ctx.__enter__()
68 self.assertRaises(
69 RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
70 )
71
Nick Coghlana7e820a2006-04-25 10:56:51 +000072 def test_contextfactory_except(self):
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000073 state = []
Nick Coghlana7e820a2006-04-25 10:56:51 +000074 @contextfactory
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000075 def woohoo():
76 state.append(1)
77 try:
78 yield 42
79 except ZeroDivisionError, e:
80 state.append(e.args[0])
81 self.assertEqual(state, [1, 42, 999])
82 with woohoo() as x:
83 self.assertEqual(state, [1])
84 self.assertEqual(x, 42)
85 state.append(x)
86 raise ZeroDivisionError(999)
87 self.assertEqual(state, [1, 42, 999])
88
Nick Coghlana7e820a2006-04-25 10:56:51 +000089 def test_contextfactory_attribs(self):
Phillip J. Eby35fd1422006-03-28 00:07:24 +000090 def attribs(**kw):
91 def decorate(func):
92 for k,v in kw.items():
93 setattr(func,k,v)
94 return func
95 return decorate
Nick Coghlana7e820a2006-04-25 10:56:51 +000096 @contextfactory
Phillip J. Eby35fd1422006-03-28 00:07:24 +000097 @attribs(foo='bar')
98 def baz(spam):
99 """Whee!"""
100 self.assertEqual(baz.__name__,'baz')
101 self.assertEqual(baz.foo, 'bar')
102 self.assertEqual(baz.__doc__, "Whee!")
103
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested, which combines several context managers."""

    # XXX This needs more work

    def test_nested(self):
        # The ``as`` target receives one __enter__ result per manager, in order.
        @contextfactory
        def one():
            yield 1
        @contextfactory
        def two():
            yield 2
        @contextfactory
        def three():
            yield 3
        with nested(one(), two(), three()) as (first, second, third):
            self.assertEqual(first, 1)
            self.assertEqual(second, 2)
            self.assertEqual(third, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, even
        # when the body raises.
        log = []
        @contextfactory
        def outer():
            log.append(1)
            try:
                yield 2
            finally:
                log.append(3)
        @contextfactory
        def inner():
            log.append(4)
            try:
                yield 5
            finally:
                log.append(6)
        try:
            with nested(outer(), inner()) as (got_outer, got_inner):
                log.append(got_outer)
                log.append(got_inner)
                1/0
        except ZeroDivisionError:
            self.assertEqual(log, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_right_exception(self):
        # An exception raised and handled inside one manager's __exit__ must
        # not replace the exception propagating out of the body.
        @contextfactory
        def first():
            yield 1
        class Second(object):
            def __context__(self):
                return self
            def __enter__(self):
                return 2
            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass
        try:
            with nested(first(), Second()) as (x, y):
                1/0
        except ZeroDivisionError:
            self.assertEqual((x, y), (1, 2))
        except Exception:
            self.fail("Reraised wrong exception")
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # If the inner manager swallows the exception, nothing escapes.
        @contextfactory
        def passthrough():
            yield
        @contextfactory
        def swallower():
            try:
                yield
            except:
                # Deliberately discard whatever was thrown in.
                pass
        try:
            with nested(passthrough(), swallower()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # ``break`` inside the with body leaves the enclosing loop at once.
        @contextfactory
        def noop():
            yield
        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # ``continue`` inside the with body skips the rest of the iteration.
        @contextfactory
        def noop():
            yield
        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # ``return`` inside the with body returns that value from the function.
        @contextfactory
        def tolerant():
            try:
                yield
            except:
                pass
        def run():
            with nested(tolerant(), tolerant()):
                return 1
            return 10
        self.assertEqual(run(), 1)
227
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    # XXX This needs more work

    def test_closing(self):
        # closing() yields the wrapped object and calls its close() on exit.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        obj = Closeable()
        self.assertEqual(calls, [])
        with closing(obj) as bound:
            self.assertEqual(obj, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() must also be called when the body raises, and the
        # exception must still propagate.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        obj = Closeable()
        self.assertEqual(calls, [])
        try:
            with closing(obj) as bound:
                self.assertEqual(obj, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
258
class FileContextTestCase(unittest.TestCase):
    """Tests that built-in file objects work as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically.  The deprecated mktemp() only
        # returns a name, which another process could claim before open().
        fd, tfn = tempfile.mkstemp()
        try:
            # Release mkstemp's descriptor; the test reopens the file by name.
            os.close(fd)
            # A normal exit from the with block must close the file.
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            # An exit caused by an exception must close the file as well.
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            try:
                os.remove(tfn)
            except os.error:
                pass
284
class LockContextTestCase(unittest.TestCase):
    """Tests that threading's lock-like objects work as context managers."""

    def boilerPlate(self, lock, locked):
        """Check that ``with lock`` acquires on entry and releases on exit.

        *locked* is a zero-argument callable that reports whether *lock* is
        currently held.  The release is checked both after a normal exit and
        after an exit caused by an exception raised in the body.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def _semaphoreProbe(self, sem):
        """Return a callable that reports whether *sem* has no free slot.

        The callable tries a non-blocking acquire.  If that succeeds, it
        releases the slot again and returns False; otherwise it returns True.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        # Note: relies on the private _is_owned method.
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            # Note: relies on the private _is_owned method.
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphoreProbe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphoreProbe(lock))
334
335class DecimalContextTestCase(unittest.TestCase):
336
337 # XXX Somebody should write more thorough tests for this
338
339 def testBasic(self):
340 ctx = decimal.getcontext()
Tim Petersa19dc0b2006-04-10 20:25:47 +0000341 orig_context = ctx.copy()
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000342 try:
Tim Petersa19dc0b2006-04-10 20:25:47 +0000343 ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000344 with decimal.ExtendedContext:
345 self.assertEqual(decimal.getcontext().prec,
346 decimal.ExtendedContext.prec)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000347 self.assertEqual(decimal.getcontext().prec, save_prec)
Tim Petersa19dc0b2006-04-10 20:25:47 +0000348 try:
349 with decimal.ExtendedContext:
350 self.assertEqual(decimal.getcontext().prec,
351 decimal.ExtendedContext.prec)
352 1/0
353 except ZeroDivisionError:
354 self.assertEqual(decimal.getcontext().prec, save_prec)
355 else:
356 self.fail("Didn't raise ZeroDivisionError")
357 finally:
358 decimal.setcontext(orig_context)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000359
360
# regrtest.py runs this module through test_main(); without it the tests here
# would not be executed by the regression suite.
def test_main():
    """Load every test in this module and run it via run_suite."""
    this_module = sys.modules[__name__]
    suite = unittest.defaultTestLoader.loadTestsFromModule(this_module)
    run_suite(suite)

if __name__ == "__main__":
    test_main()