blob: 8c8d887a55366755bf7d29d4b8e6b3d8dae96bf1 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
3from __future__ import with_statement
4
5import os
6import decimal
7import tempfile
8import unittest
9import threading
10from contextlib import * # Tests __all__
11
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator functions decorated with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is what
        # "as" binds, and code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # An exception raised in the with-body is thrown into the generator
        # at the yield: its finally clause runs, and because the generator
        # does not catch the exception it still reaches the caller.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        except ZeroDivisionError:
            pass
        else:
            self.fail("Expected ZeroDivisionError")
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_except(self):
        # A generator that catches the exception thrown in at the yield sees
        # the original exception object, and by not re-raising it suppresses
        # it: the with statement then completes normally.
        # sys.exc_info() is used instead of "except E, e", which is a
        # SyntaxError on Python 3 (PEP 3110); this form works everywhere.
        import sys
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError:
                e = sys.exc_info()[1]
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        try:
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError(999)
        except ZeroDivisionError:
            # Report a missing suppression as a test failure, not an error.
            self.fail("ZeroDivisionError was not suppressed by the generator")
        self.assertEqual(state, [1, 42, 999])
64
65class NestedTestCase(unittest.TestCase):
66
67 # XXX This needs more work
68
69 def test_nested(self):
70 @contextmanager
71 def a():
72 yield 1
73 @contextmanager
74 def b():
75 yield 2
76 @contextmanager
77 def c():
78 yield 3
79 with nested(a(), b(), c()) as (x, y, z):
80 self.assertEqual(x, 1)
81 self.assertEqual(y, 2)
82 self.assertEqual(z, 3)
83
84 def test_nested_cleanup(self):
85 state = []
86 @contextmanager
87 def a():
88 state.append(1)
89 try:
90 yield 2
91 finally:
92 state.append(3)
93 @contextmanager
94 def b():
95 state.append(4)
96 try:
97 yield 5
98 finally:
99 state.append(6)
100 try:
101 with nested(a(), b()) as (x, y):
102 state.append(x)
103 state.append(y)
104 1/0
105 except ZeroDivisionError:
106 self.assertEqual(state, [1, 4, 2, 5, 6, 3])
107 else:
108 self.fail("Didn't raise ZeroDivisionError")
109
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # closing(obj) binds obj itself and calls obj.close() on exit.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        thing = Closeable()
        self.assertEqual(calls, [])
        with closing(thing) as bound:
            self.assertEqual(thing, bound)
        self.assertEqual(calls, [1])

    def test_closing_error(self):
        # close() must also be called when the with-body raises, and the
        # exception must still reach the caller.
        calls = []
        class Closeable:
            def close(self):
                calls.append(1)
        thing = Closeable()
        self.assertEqual(calls, [])
        try:
            with closing(thing) as bound:
                self.assertEqual(thing, bound)
                1/0
        except ZeroDivisionError:
            self.assertEqual(calls, [1])
        else:
            self.fail("Didn't raise ZeroDivisionError")
140
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        # mkstemp() atomically creates a uniquely named file. The deprecated
        # mktemp() only returns a name, leaving a window in which another
        # process can create that path first.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)  # only the path is needed; open() below reopens it
        try:
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            # Reset so the check below cannot see the first, already-closed
            # file object if the second open() were to fail.
            f = None
            try:
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            except ZeroDivisionError:
                # The file must be closed even when the body raises.
                self.assertTrue(f.closed)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            try:
                os.remove(tfn)
            except OSError:
                pass
166
class LockContextTestCase(unittest.TestCase):
    """Tests that threading's lock-like objects work as context managers."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside a with-block and released after.

        *locked* is a no-argument callable reporting whether *lock* is
        currently held. Release is checked both after a normal exit and
        after the with-body raises.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        try:
            with lock:
                self.assertTrue(locked())
                1/0
        except ZeroDivisionError:
            self.assertFalse(locked())
        else:
            self.fail("Didn't raise ZeroDivisionError")

    def _semaphoreProbe(self, sem):
        """Return a callable reporting whether *sem* cannot be acquired now.

        Semaphores have no "is held" query, so the probe tries a
        non-blocking acquire. On success the semaphore was free and is
        released again immediately; on failure it is held.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # _is_owned is private, but RLock has no public ownership query.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphoreProbe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphoreProbe(lock))
216
class DecimalContextTestCase(unittest.TestCase):
    """Tests for temporarily switching the decimal context with a with-block."""

    # XXX Somebody should write more thorough tests for this

    def testBasic(self):
        # The test changes the thread's current decimal context, so snapshot
        # it and restore it afterwards; otherwise the modified precision
        # would leak into every test that runs later in this thread.
        ctx = decimal.getcontext()
        orig_context = ctx.copy()
        try:
            ctx.prec = save_prec = decimal.ExtendedContext.prec + 5
            # Context objects are not context managers themselves;
            # localcontext() installs a copy for the block and restores
            # the previous context on exit.
            with decimal.localcontext(decimal.ExtendedContext):
                self.assertEqual(decimal.getcontext().prec,
                                 decimal.ExtendedContext.prec)
            self.assertEqual(decimal.getcontext().prec, save_prec)
            try:
                with decimal.localcontext(decimal.ExtendedContext):
                    self.assertEqual(decimal.getcontext().prec,
                                     decimal.ExtendedContext.prec)
                    1/0
            except ZeroDivisionError:
                # The previous context must be restored even on an exception.
                self.assertEqual(decimal.getcontext().prec, save_prec)
            else:
                self.fail("Didn't raise ZeroDivisionError")
        finally:
            decimal.setcontext(orig_context)
237
238
if __name__ == "__main__":
    # Running this file directly executes every TestCase defined in it.
    unittest.main(module="__main__")