blob: 97470c78fbe49b1ad45ed63ae7c3e82d8e3b2a96 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
3from __future__ import with_statement
4
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +00005import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006import os
7import decimal
8import tempfile
9import unittest
10import threading
11from contextlib import * # Tests __all__
Phillip J. Ebybd0c10f2006-04-10 18:33:17 +000012from test.test_support import run_suite
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000013
14class ContextManagerTestCase(unittest.TestCase):
15
16 def test_contextmanager_plain(self):
17 state = []
18 @contextmanager
19 def woohoo():
20 state.append(1)
21 yield 42
22 state.append(999)
23 with woohoo() as x:
24 self.assertEqual(state, [1])
25 self.assertEqual(x, 42)
26 state.append(x)
27 self.assertEqual(state, [1, 42, 999])
28
29 def test_contextmanager_finally(self):
30 state = []
31 @contextmanager
32 def woohoo():
33 state.append(1)
34 try:
35 yield 42
36 finally:
37 state.append(999)
38 try:
39 with woohoo() as x:
40 self.assertEqual(state, [1])
41 self.assertEqual(x, 42)
42 state.append(x)
43 raise ZeroDivisionError()
44 except ZeroDivisionError:
45 pass
46 else:
47 self.fail("Expected ZeroDivisionError")
48 self.assertEqual(state, [1, 42, 999])
49
Phillip J. Eby6edd2582006-03-25 00:28:24 +000050 def test_contextmanager_no_reraise(self):
51 @contextmanager
52 def whee():
53 yield
54 ctx = whee().__context__()
55 ctx.__enter__()
56 # Calling __exit__ should not result in an exception
57 self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None))
58
59 def test_contextmanager_trap_yield_after_throw(self):
60 @contextmanager
61 def whoo():
62 try:
63 yield
64 except:
65 yield
66 ctx = whoo().__context__()
67 ctx.__enter__()
68 self.assertRaises(
69 RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
70 )
71
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000072 def test_contextmanager_except(self):
73 state = []
74 @contextmanager
75 def woohoo():
76 state.append(1)
77 try:
78 yield 42
79 except ZeroDivisionError, e:
80 state.append(e.args[0])
81 self.assertEqual(state, [1, 42, 999])
82 with woohoo() as x:
83 self.assertEqual(state, [1])
84 self.assertEqual(x, 42)
85 state.append(x)
86 raise ZeroDivisionError(999)
87 self.assertEqual(state, [1, 42, 999])
88
Phillip J. Eby35fd1422006-03-28 00:07:24 +000089 def test_contextmanager_attribs(self):
90 def attribs(**kw):
91 def decorate(func):
92 for k,v in kw.items():
93 setattr(func,k,v)
94 return func
95 return decorate
96 @contextmanager
97 @attribs(foo='bar')
98 def baz(spam):
99 """Whee!"""
100 self.assertEqual(baz.__name__,'baz')
101 self.assertEqual(baz.foo, 'bar')
102 self.assertEqual(baz.__doc__, "Whee!")
103
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000104class NestedTestCase(unittest.TestCase):
105
106 # XXX This needs more work
107
108 def test_nested(self):
109 @contextmanager
110 def a():
111 yield 1
112 @contextmanager
113 def b():
114 yield 2
115 @contextmanager
116 def c():
117 yield 3
118 with nested(a(), b(), c()) as (x, y, z):
119 self.assertEqual(x, 1)
120 self.assertEqual(y, 2)
121 self.assertEqual(z, 3)
122
123 def test_nested_cleanup(self):
124 state = []
125 @contextmanager
126 def a():
127 state.append(1)
128 try:
129 yield 2
130 finally:
131 state.append(3)
132 @contextmanager
133 def b():
134 state.append(4)
135 try:
136 yield 5
137 finally:
138 state.append(6)
139 try:
140 with nested(a(), b()) as (x, y):
141 state.append(x)
142 state.append(y)
143 1/0
144 except ZeroDivisionError:
145 self.assertEqual(state, [1, 4, 2, 5, 6, 3])
146 else:
147 self.fail("Didn't raise ZeroDivisionError")
148
Guido van Rossuma9f06872006-03-01 17:10:01 +0000149 def test_nested_b_swallows(self):
150 @contextmanager
151 def a():
152 yield
153 @contextmanager
154 def b():
155 try:
156 yield
157 except:
158 # Swallow the exception
159 pass
160 try:
161 with nested(a(), b()):
162 1/0
163 except ZeroDivisionError:
164 self.fail("Didn't swallow ZeroDivisionError")
165
166 def test_nested_break(self):
167 @contextmanager
168 def a():
169 yield
170 state = 0
171 while True:
172 state += 1
173 with nested(a(), a()):
174 break
175 state += 10
176 self.assertEqual(state, 1)
177
178 def test_nested_continue(self):
179 @contextmanager
180 def a():
181 yield
182 state = 0
183 while state < 3:
184 state += 1
185 with nested(a(), a()):
186 continue
187 state += 10
188 self.assertEqual(state, 3)
189
190 def test_nested_return(self):
191 @contextmanager
192 def a():
193 try:
194 yield
195 except:
196 pass
197 def foo():
198 with nested(a(), a()):
199 return 1
200 return 10
201 self.assertEqual(foo(), 1)
202
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # close() is called once, when the with block ends normally, and the
        # "as" target is the wrapped object itself.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        obj = Closeable()
        self.assertEqual(calls, [])
        with closing(obj) as bound:
            self.assertEqual(obj, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() is also called when the with block is left by an exception,
        # and the exception still propagates.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        obj = Closeable()
        self.assertEqual(calls, [])
        try:
            with closing(obj) as bound:
                self.assertEqual(obj, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
233
class FileContextTestCase(unittest.TestCase):
    """Check that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically under a unique name, closing
        # the window between picking a name and creating the file that
        # mktemp() leaves open.  Only the name is needed here, so close the
        # descriptor at once rather than leak it.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            # Leaving the block normally closes the file.
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # Leaving the block via an exception closes it too.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Best-effort cleanup: ignore a failure to remove the file.
            try:
                os.remove(tfn)
            except os.error:
                pass
259
class LockContextTestCase(unittest.TestCase):
    """Check that the threading synchronization primitives are released
    when a with block exits, whether normally or via an exception."""

    def boilerPlate(self, lock, locked):
        """Exercise *lock* as a context manager.

        lock -- the object used in the with statement.
        locked -- zero-argument callable returning true while *lock* is held.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def _semaphoreLocked(self, sem):
        """Return a probe that reports whether *sem* is currently held.

        The probe tries a non-blocking acquire; on success it immediately
        releases again and reports "not held".
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphoreLocked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphoreLocked(lock))
309
310class DecimalContextTestCase(unittest.TestCase):
311
312 # XXX Somebody should write more thorough tests for this
313
314 def testBasic(self):
315 ctx = decimal.getcontext()
Tim Petersa19dc0b2006-04-10 20:25:47 +0000316 orig_context = ctx.copy()
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000317 try:
Tim Petersa19dc0b2006-04-10 20:25:47 +0000318 ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000319 with decimal.ExtendedContext:
320 self.assertEqual(decimal.getcontext().prec,
321 decimal.ExtendedContext.prec)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000322 self.assertEqual(decimal.getcontext().prec, save_prec)
Tim Petersa19dc0b2006-04-10 20:25:47 +0000323 try:
324 with decimal.ExtendedContext:
325 self.assertEqual(decimal.getcontext().prec,
326 decimal.ExtendedContext.prec)
327 1/0
328 except ZeroDivisionError:
329 self.assertEqual(decimal.getcontext().prec, save_prec)
330 else:
331 self.fail("Didn't raise ZeroDivisionError")
332 finally:
333 decimal.setcontext(orig_context)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000334
335
# regrtest.py needs this entry point to actually run the tests in this module.
def test_main():
    """Load every TestCase defined in this module and hand it to run_suite()."""
    this_module = sys.modules[__name__]
    suite = unittest.defaultTestLoader.loadTestsFromModule(this_module)
    run_suite(suite)

if __name__ == "__main__":
    test_main()