blob: cd8895e873a60c3b0b2cf45704b5ee3a42efa480 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
3from __future__ import with_statement
4
5import os
6import decimal
7import tempfile
8import unittest
9import threading
10from contextlib import * # Tests __all__
11
12class ContextManagerTestCase(unittest.TestCase):
13
14 def test_contextmanager_plain(self):
15 state = []
16 @contextmanager
17 def woohoo():
18 state.append(1)
19 yield 42
20 state.append(999)
21 with woohoo() as x:
22 self.assertEqual(state, [1])
23 self.assertEqual(x, 42)
24 state.append(x)
25 self.assertEqual(state, [1, 42, 999])
26
27 def test_contextmanager_finally(self):
28 state = []
29 @contextmanager
30 def woohoo():
31 state.append(1)
32 try:
33 yield 42
34 finally:
35 state.append(999)
36 try:
37 with woohoo() as x:
38 self.assertEqual(state, [1])
39 self.assertEqual(x, 42)
40 state.append(x)
41 raise ZeroDivisionError()
42 except ZeroDivisionError:
43 pass
44 else:
45 self.fail("Expected ZeroDivisionError")
46 self.assertEqual(state, [1, 42, 999])
47
Phillip J. Eby6edd2582006-03-25 00:28:24 +000048 def test_contextmanager_no_reraise(self):
49 @contextmanager
50 def whee():
51 yield
52 ctx = whee().__context__()
53 ctx.__enter__()
54 # Calling __exit__ should not result in an exception
55 self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None))
56
57 def test_contextmanager_trap_yield_after_throw(self):
58 @contextmanager
59 def whoo():
60 try:
61 yield
62 except:
63 yield
64 ctx = whoo().__context__()
65 ctx.__enter__()
66 self.assertRaises(
67 RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
68 )
69
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000070 def test_contextmanager_except(self):
71 state = []
72 @contextmanager
73 def woohoo():
74 state.append(1)
75 try:
76 yield 42
77 except ZeroDivisionError, e:
78 state.append(e.args[0])
79 self.assertEqual(state, [1, 42, 999])
80 with woohoo() as x:
81 self.assertEqual(state, [1])
82 self.assertEqual(x, 42)
83 state.append(x)
84 raise ZeroDivisionError(999)
85 self.assertEqual(state, [1, 42, 999])
86
Phillip J. Eby35fd1422006-03-28 00:07:24 +000087 def test_contextmanager_attribs(self):
88 def attribs(**kw):
89 def decorate(func):
90 for k,v in kw.items():
91 setattr(func,k,v)
92 return func
93 return decorate
94 @contextmanager
95 @attribs(foo='bar')
96 def baz(spam):
97 """Whee!"""
98 self.assertEqual(baz.__name__,'baz')
99 self.assertEqual(baz.foo, 'bar')
100 self.assertEqual(baz.__doc__, "Whee!")
101
class NestedTestCase(unittest.TestCase):
    """Tests for contextlib.nested, which combines several managers."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's yielded value lands in the matching "as" target.
        def yielding(value):
            @contextmanager
            def mgr():
                yield value
            return mgr()
        with nested(yielding(1), yielding(2), yielding(3)) as (x, y, z):
            self.assertEqual(x, 1)
            self.assertEqual(y, 2)
            self.assertEqual(z, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and exited right to left, even
        # when the body raises; the exception still propagates.
        log = []
        def tracked(on_enter, value, on_exit):
            @contextmanager
            def mgr():
                log.append(on_enter)
                try:
                    yield value
                finally:
                    log.append(on_exit)
            return mgr()
        try:
            with nested(tracked(1, 2, 3), tracked(4, 5, 6)) as (x, y):
                log.append(x)
                log.append(y)
                1/0
        except ZeroDivisionError:
            self.assertEqual(log, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        # When the inner manager swallows the exception, the outer one and
        # the caller never see it.
        @contextmanager
        def a():
            yield
        @contextmanager
        def b():
            try:
                yield
            except:
                # Swallow the exception
                pass
        try:
            with nested(a(), b()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # break inside a nested() block leaves the enclosing loop.
        @contextmanager
        def a():
            yield
        count = 0
        while True:
            count += 1
            with nested(a(), a()):
                break
            count += 10
        self.assertEqual(count, 1)

    def test_nested_continue(self):
        # continue inside a nested() block skips the rest of the iteration.
        @contextmanager
        def a():
            yield
        count = 0
        while count < 3:
            count += 1
            with nested(a(), a()):
                continue
            count += 10
        self.assertEqual(count, 3)

    def test_nested_return(self):
        # return inside a nested() block returns from the enclosing function.
        @contextmanager
        def a():
            try:
                yield
            except:
                pass
        def func():
            with nested(a(), a()):
                return 1
            return 10
        self.assertEqual(func(), 1)
200
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing."""

    # XXX This needs more work

    def test_closing(self):
        # closing() yields the wrapped object and calls close() exactly once
        # when the block ends normally.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        obj = Closeable()
        self.assertEqual(calls, [])
        with closing(obj) as bound:
            self.assertEqual(obj, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() is also called when the block exits via an exception, and
        # the exception still reaches the caller.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        obj = Closeable()
        self.assertEqual(calls, [])
        try:
            with closing(obj) as bound:
                self.assertEqual(obj, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
231
class FileContextTestCase(unittest.TestCase):
    """Check that file objects close themselves at the end of a with block."""

    def testWithOpen(self):
        # mkstemp() atomically creates a fresh file. The deprecated mktemp()
        # only returns a name, which another process could claim first.
        fd, tfn = tempfile.mkstemp()
        # Only the path is needed; the test reopens it with open() below.
        os.close(fd)
        try:
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # The with statement must close the file on error exit too.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            try:
                os.remove(tfn)
            except os.error:
                pass
257
class LockContextTestCase(unittest.TestCase):
    """Check that threading synchronization primitives work in with blocks."""

    def boilerPlate(self, lock, locked):
        """Exercise *lock* in a with statement on normal and error exit.

        *locked* is a zero-argument callable reporting whether *lock* is
        held. It must be false before and after each with block and true
        inside it.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphore_probe(sem):
        """Return a callable that is true when *sem* cannot be acquired.

        The probe tries a non-blocking acquire. On success it releases at
        once, leaving the semaphore as it found it.
        """
        def locked():
            if not sem.acquire(False):
                return True
            sem.release()
            return False
        return locked

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        # _is_owned is a private threading helper, used here as the probe.
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, self._semaphore_probe(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, self._semaphore_probe(sem))
307
308class DecimalContextTestCase(unittest.TestCase):
309
310 # XXX Somebody should write more thorough tests for this
311
312 def testBasic(self):
313 ctx = decimal.getcontext()
314 ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
315 with decimal.ExtendedContext:
316 self.assertEqual(decimal.getcontext().prec,
317 decimal.ExtendedContext.prec)
318 self.assertEqual(decimal.getcontext().prec, save_prec)
319 try:
320 with decimal.ExtendedContext:
321 self.assertEqual(decimal.getcontext().prec,
322 decimal.ExtendedContext.prec)
323 1/0
324 except ZeroDivisionError:
325 self.assertEqual(decimal.getcontext().prec, save_prec)
326 else:
327 self.fail("Didn't raise ZeroDivisionError")
328
329
if __name__ == "__main__":
    # Hand control to unittest's command-line runner when run as a script.
    unittest.main()