blob: f8db88cc58a78f4e503f7139a763e87c7ad97fd8 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
3from __future__ import with_statement
4
5import os
6import decimal
7import tempfile
8import unittest
9import threading
10from contextlib import * # Tests __all__
11
12class ContextManagerTestCase(unittest.TestCase):
13
14 def test_contextmanager_plain(self):
15 state = []
16 @contextmanager
17 def woohoo():
18 state.append(1)
19 yield 42
20 state.append(999)
21 with woohoo() as x:
22 self.assertEqual(state, [1])
23 self.assertEqual(x, 42)
24 state.append(x)
25 self.assertEqual(state, [1, 42, 999])
26
27 def test_contextmanager_finally(self):
28 state = []
29 @contextmanager
30 def woohoo():
31 state.append(1)
32 try:
33 yield 42
34 finally:
35 state.append(999)
36 try:
37 with woohoo() as x:
38 self.assertEqual(state, [1])
39 self.assertEqual(x, 42)
40 state.append(x)
41 raise ZeroDivisionError()
42 except ZeroDivisionError:
43 pass
44 else:
45 self.fail("Expected ZeroDivisionError")
46 self.assertEqual(state, [1, 42, 999])
47
48 def test_contextmanager_except(self):
49 state = []
50 @contextmanager
51 def woohoo():
52 state.append(1)
53 try:
54 yield 42
55 except ZeroDivisionError, e:
56 state.append(e.args[0])
57 self.assertEqual(state, [1, 42, 999])
58 with woohoo() as x:
59 self.assertEqual(state, [1])
60 self.assertEqual(x, 42)
61 state.append(x)
62 raise ZeroDivisionError(999)
63 self.assertEqual(state, [1, 42, 999])
64
65class NestedTestCase(unittest.TestCase):
66
67 # XXX This needs more work
68
69 def test_nested(self):
70 @contextmanager
71 def a():
72 yield 1
73 @contextmanager
74 def b():
75 yield 2
76 @contextmanager
77 def c():
78 yield 3
79 with nested(a(), b(), c()) as (x, y, z):
80 self.assertEqual(x, 1)
81 self.assertEqual(y, 2)
82 self.assertEqual(z, 3)
83
84 def test_nested_cleanup(self):
85 state = []
86 @contextmanager
87 def a():
88 state.append(1)
89 try:
90 yield 2
91 finally:
92 state.append(3)
93 @contextmanager
94 def b():
95 state.append(4)
96 try:
97 yield 5
98 finally:
99 state.append(6)
100 try:
101 with nested(a(), b()) as (x, y):
102 state.append(x)
103 state.append(y)
104 1/0
105 except ZeroDivisionError:
106 self.assertEqual(state, [1, 4, 2, 5, 6, 3])
107 else:
108 self.fail("Didn't raise ZeroDivisionError")
109
Guido van Rossuma9f06872006-03-01 17:10:01 +0000110 def test_nested_b_swallows(self):
111 @contextmanager
112 def a():
113 yield
114 @contextmanager
115 def b():
116 try:
117 yield
118 except:
119 # Swallow the exception
120 pass
121 try:
122 with nested(a(), b()):
123 1/0
124 except ZeroDivisionError:
125 self.fail("Didn't swallow ZeroDivisionError")
126
127 def test_nested_break(self):
128 @contextmanager
129 def a():
130 yield
131 state = 0
132 while True:
133 state += 1
134 with nested(a(), a()):
135 break
136 state += 10
137 self.assertEqual(state, 1)
138
139 def test_nested_continue(self):
140 @contextmanager
141 def a():
142 yield
143 state = 0
144 while state < 3:
145 state += 1
146 with nested(a(), a()):
147 continue
148 state += 10
149 self.assertEqual(state, 3)
150
151 def test_nested_return(self):
152 @contextmanager
153 def a():
154 try:
155 yield
156 except:
157 pass
158 def foo():
159 with nested(a(), a()):
160 return 1
161 return 10
162 self.assertEqual(foo(), 1)
163
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # close() must run exactly once, and only when the block exits.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with closing(x) as y:
            # closing() hands back the wrapped object itself ...
            self.assertEqual(x, y)
            # ... and must not close it on entry.
            self.assertEqual(state, [])
        self.assertEqual(state, [1])

    def test_closing_error(self):
        # close() still runs when the block raises, and the exception
        # propagates to the caller.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        try:
            with closing(x) as y:
                self.assertEqual(x, y)
                # Not closed yet: close() belongs to exit, not entry.
                self.assertEqual(state, [])
                1/0
        except ZeroDivisionError:
            self.assertEqual(state, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
194
class FileContextTestCase(unittest.TestCase):
    """Tests that built-in file objects work as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically. mktemp() only returns a
        # name, leaving a window in which another process can create it.
        fd, tfn = tempfile.mkstemp()
        try:
            # Only the path is needed; the file is reopened by name below.
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # The file is closed even though the block raised.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            try:
                os.remove(tfn)
            except os.error:
                pass
220
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives work with 'with'."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside a with block and released on
        exit, both on normal exit and when the block raises.

        *locked* is a zero-argument callable that returns whether *lock*
        is currently held.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    @staticmethod
    def _semaphore_held(sem):
        """Return True if *sem* cannot be acquired without blocking.

        A successful probe acquire is released immediately, so the check
        leaves the semaphore's counter unchanged.
        """
        if sem.acquire(False):
            sem.release()
            return False
        return True

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # RLock has no public "is held" query; _is_owned() is private API.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, lambda: self._semaphore_held(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, lambda: self._semaphore_held(lock))
270
class DecimalContextTestCase(unittest.TestCase):
    """Tests that decimal contexts can be scoped with a 'with' statement."""

    # XXX Somebody should write more thorough tests for this

    def testBasic(self):
        ctx = decimal.getcontext()
        # This test changes the current thread's decimal context. Snapshot
        # it so that later tests do not inherit the altered precision.
        orig_context = ctx.copy()
        try:
            ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
            # localcontext() installs a copy of the given context for the
            # duration of the block and reinstates the previous one on exit.
            with decimal.localcontext(decimal.ExtendedContext):
                self.assertEqual(decimal.getcontext().prec,
                                 decimal.ExtendedContext.prec)
            self.assertEqual(decimal.getcontext().prec, save_prec)
            try:
                with decimal.localcontext(decimal.ExtendedContext):
                    self.assertEqual(decimal.getcontext().prec,
                                     decimal.ExtendedContext.prec)
                    1/0
            except ZeroDivisionError:
                self.assertEqual(decimal.getcontext().prec, save_prec)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            decimal.setcontext(orig_context)
291
292
if __name__ == "__main__":
    # Run directly rather than imported: execute every TestCase in the module.
    unittest.main()