blob: 7d7f8d281c34908955ca2f01e1a044274f053db1 [file] [log] [blame]
"""Unit tests for contextlib.py, and other context managers."""
2
3from __future__ import with_statement
4
5import os
6import decimal
7import tempfile
8import unittest
9import threading
10from contextlib import * # Tests __all__
11
12class ContextManagerTestCase(unittest.TestCase):
13
14 def test_contextmanager_plain(self):
15 state = []
16 @contextmanager
17 def woohoo():
18 state.append(1)
19 yield 42
20 state.append(999)
21 with woohoo() as x:
22 self.assertEqual(state, [1])
23 self.assertEqual(x, 42)
24 state.append(x)
25 self.assertEqual(state, [1, 42, 999])
26
27 def test_contextmanager_finally(self):
28 state = []
29 @contextmanager
30 def woohoo():
31 state.append(1)
32 try:
33 yield 42
34 finally:
35 state.append(999)
36 try:
37 with woohoo() as x:
38 self.assertEqual(state, [1])
39 self.assertEqual(x, 42)
40 state.append(x)
41 raise ZeroDivisionError()
42 except ZeroDivisionError:
43 pass
44 else:
45 self.fail("Expected ZeroDivisionError")
46 self.assertEqual(state, [1, 42, 999])
47
Phillip J. Eby6edd2582006-03-25 00:28:24 +000048 def test_contextmanager_no_reraise(self):
49 @contextmanager
50 def whee():
51 yield
52 ctx = whee().__context__()
53 ctx.__enter__()
54 # Calling __exit__ should not result in an exception
55 self.failIf(ctx.__exit__(TypeError, TypeError("foo"), None))
56
57 def test_contextmanager_trap_yield_after_throw(self):
58 @contextmanager
59 def whoo():
60 try:
61 yield
62 except:
63 yield
64 ctx = whoo().__context__()
65 ctx.__enter__()
66 self.assertRaises(
67 RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
68 )
69
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000070 def test_contextmanager_except(self):
71 state = []
72 @contextmanager
73 def woohoo():
74 state.append(1)
75 try:
76 yield 42
77 except ZeroDivisionError, e:
78 state.append(e.args[0])
79 self.assertEqual(state, [1, 42, 999])
80 with woohoo() as x:
81 self.assertEqual(state, [1])
82 self.assertEqual(x, 42)
83 state.append(x)
84 raise ZeroDivisionError(999)
85 self.assertEqual(state, [1, 42, 999])
86
class NestedTestCase(unittest.TestCase):
    """Exercise nested(), which combines several managers into one."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's yielded value shows up, in order, in the tuple
        # bound by the with statement.
        @contextmanager
        def one():
            yield 1

        @contextmanager
        def two():
            yield 2

        @contextmanager
        def three():
            yield 3

        with nested(one(), two(), three()) as (first, second, third):
            self.assertEqual(first, 1)
            self.assertEqual(second, 2)
            self.assertEqual(third, 3)

    def test_nested_cleanup(self):
        # Managers are entered left to right and cleaned up right to left,
        # even when the body raises.
        trace = []

        @contextmanager
        def outer():
            trace.append(1)
            try:
                yield 2
            finally:
                trace.append(3)

        @contextmanager
        def inner():
            trace.append(4)
            try:
                yield 5
            finally:
                trace.append(6)

        try:
            with nested(outer(), inner()) as (first, second):
                trace.append(first)
                trace.append(second)
                1/0
        except ZeroDivisionError:
            self.assertEqual(trace, [1, 4, 2, 5, 6, 3])
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def test_nested_b_swallows(self):
        @contextmanager
        def outer():
            yield

        @contextmanager
        def inner():
            try:
                yield
            except:
                # Swallow the exception
                pass

        try:
            with nested(outer(), inner()):
                1/0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # 'break' inside the with-block must leave the loop at once, so the
        # "+= 10" statement is never reached.
        @contextmanager
        def noop():
            yield

        counter = 0
        while True:
            counter += 1
            with nested(noop(), noop()):
                break
            counter += 10
        self.assertEqual(counter, 1)

    def test_nested_continue(self):
        # 'continue' inside the with-block must skip the rest of the loop
        # body on every iteration.
        @contextmanager
        def noop():
            yield

        counter = 0
        while counter < 3:
            counter += 1
            with nested(noop(), noop()):
                continue
            counter += 10
        self.assertEqual(counter, 3)

    def test_nested_return(self):
        # 'return' inside the with-block must deliver its value even though
        # the managers would swallow any exception.
        @contextmanager
        def swallower():
            try:
                yield
            except:
                pass

        def foo():
            with nested(swallower(), swallower()):
                return 1
            return 10

        self.assertEqual(foo(), 1)
185
class ClosingTestCase(unittest.TestCase):
    """Exercise closing(), which calls close() when the block exits."""

    # XXX This needs more work

    def test_closing(self):
        # close() is called exactly once after a normal exit, and the
        # wrapped object itself is what the with statement binds.
        calls = []

        class Closeable:
            def close(self):
                calls.append(1)

        target = Closeable()
        self.assertEqual(calls, [])
        with closing(target) as bound:
            self.assertEqual(target, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() is still called when the block raises, and the exception
        # propagates to the caller.
        calls = []

        class Closeable:
            def close(self):
                calls.append(1)

        target = Closeable()
        self.assertEqual(calls, [])
        try:
            with closing(target) as bound:
                self.assertEqual(target, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
216
class FileContextTestCase(unittest.TestCase):
    """Check that built-in file objects work as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically.  The previously used
        # mktemp() only returned a name, leaving a window in which another
        # process could create the file first.  Only the path is needed
        # here, so the descriptor is closed straight away.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            # Leaving the block normally must close the file.
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # Leaving the block via an exception must close it too.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Best-effort cleanup: a failure to remove the scratch file
            # must not mask the outcome of the test itself.
            try:
                os.remove(tfn)
            except os.error:
                pass
242
class LockContextTestCase(unittest.TestCase):
    """Check that the threading synchronisation objects work with 'with'."""

    def boilerPlate(self, lock, locked):
        # Shared scenario: 'lock' is the object under test and 'locked' is
        # a zero-argument callable reporting whether it is currently held.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            # The lock must have been released on the way out.
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rlock = threading.RLock()
        self.boilerPlate(rlock, rlock._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()

        def is_held():
            return cond._is_owned()

        self.boilerPlate(cond, is_held)

    def testWithSemaphore(self):
        sem = threading.Semaphore()

        def is_held():
            # A successful non-blocking acquire means the semaphore was
            # free; give it straight back.
            if sem.acquire(False):
                sem.release()
                return False
            return True

        self.boilerPlate(sem, is_held)

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()

        def is_held():
            # Same probe as for the plain semaphore.
            if sem.acquire(False):
                sem.release()
                return False
            return True

        self.boilerPlate(sem, is_held)
292
class DecimalContextTestCase(unittest.TestCase):
    """Check that decimal contexts can be used as context managers."""

    # XXX Somebody should write more thorough tests for this

    def testBasic(self):
        ctx = decimal.getcontext()
        # Remember the precision so the finally clause can put it back;
        # otherwise the change below would leak into every later test
        # running in this thread.
        orig_prec = ctx.prec
        try:
            # Pick a precision that differs from ExtendedContext's so a
            # context switch is observable.
            ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
            with decimal.ExtendedContext:
                self.assertEqual(decimal.getcontext().prec,
                                 decimal.ExtendedContext.prec)
            self.assertEqual(decimal.getcontext().prec, save_prec)
            try:
                with decimal.ExtendedContext:
                    self.assertEqual(decimal.getcontext().prec,
                                     decimal.ExtendedContext.prec)
                    1/0
            except ZeroDivisionError:
                # The previous context must be back after the exception.
                self.assertEqual(decimal.getcontext().prec, save_prec)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            # Undo the precision change and reinstall the original context
            # object, in case a failed with-block left another one active.
            ctx.prec = orig_prec
            decimal.setcontext(ctx)
313
314
if __name__ == "__main__":
    # Run every TestCase in this module when executed as a script.
    unittest.main()