blob: f28c95eadbc1968b50c2fa1ca79d8547b583b2c5 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
R. David Murrayf28fd242010-02-23 00:24:49 +00003import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00004import tempfile
5import unittest
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006from contextlib import * # Tests __all__
Collin Winterc2898c52007-04-25 17:29:52 +00007from test import test_support
Victor Stinner6a102812010-04-27 23:55:59 +00008try:
9 import threading
10except ImportError:
11 threading = None
Florent Xicluna6257a7b2010-03-31 22:01:03 +000012
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000013
class ContextManagerTestCase(unittest.TestCase):
    """Tests for the @contextmanager decorator."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # 'as', and code after the yield runs when the block exits normally.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A finally clause around the yield still runs when the body raises,
        # and the exception propagates out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # When the generator does not handle the exception, __exit__ must
        # return a false value (leaving the with statement to re-raise it)
        # rather than raising the exception itself.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after an exception has been
        # thrown into it is invalid; __exit__ must report RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # A generator that catches the exception suppresses it: the handler
        # receives the exception object and the with statement completes
        # normally (no exception reaches this test).
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            # 'except X as e' is valid on Python 2.6+ and 3; the former
            # 'except X, e' spelling is a SyntaxError on Python 3.
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager-wrapped function with extra metadata.

        The wrapped function carries a custom attribute (foo='bar') and a
        docstring, so the tests below can check that the decorator
        preserves __name__, custom attributes and __doc__.
        """
        def attribs(**kw):
            # Decorator factory: set every keyword argument as an attribute
            # on the decorated function.
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
108
class NestedTestCase(unittest.TestCase):
    """Tests for nested(), which combines several context managers into one."""

    # XXX This needs more work

    def test_nested(self):
        # Each manager's entered value lands in the matching slot of the
        # tuple bound by 'as'.
        @contextmanager
        def one():
            yield 1

        @contextmanager
        def two():
            yield 2

        @contextmanager
        def three():
            yield 3

        with nested(one(), two(), three()) as (first, second, third):
            self.assertEqual(first, 1)
            self.assertEqual(second, 2)
            self.assertEqual(third, 3)

    def test_nested_cleanup(self):
        # Managers are entered left-to-right and exited right-to-left, and
        # every exit still runs when the body raises.
        log = []

        @contextmanager
        def outer():
            log.append(1)
            try:
                yield 2
            finally:
                log.append(3)

        @contextmanager
        def inner():
            log.append(4)
            try:
                yield 5
            finally:
                log.append(6)

        with self.assertRaises(ZeroDivisionError):
            with nested(outer(), inner()) as (x, y):
                log.extend((x, y))
                1 // 0
        self.assertEqual(log, [1, 4, 2, 5, 6, 3])

    def test_nested_right_exception(self):
        # An exception raised and handled inside one manager's __exit__
        # must not replace the ZeroDivisionError coming from the body.
        @contextmanager
        def first():
            yield 1

        class Second(object):
            def __enter__(self):
                return 2

            def __exit__(self, *exc_info):
                try:
                    raise Exception()
                except:
                    pass

        with self.assertRaises(ZeroDivisionError):
            with nested(first(), Second()) as (x, y):
                1 // 0
        self.assertEqual((x, y), (1, 2))

    def test_nested_b_swallows(self):
        # When the inner manager suppresses the exception, nothing escapes
        # the with statement.
        @contextmanager
        def passthrough():
            yield

        @contextmanager
        def swallower():
            try:
                yield
            except:
                pass  # deliberately suppress whatever the body raised

        try:
            with nested(passthrough(), swallower()):
                1 // 0
        except ZeroDivisionError:
            self.fail("Didn't swallow ZeroDivisionError")

    def test_nested_break(self):
        # 'break' inside the body leaves the enclosing loop immediately.
        @contextmanager
        def noop():
            yield

        count = 0
        while True:
            count += 1
            with nested(noop(), noop()):
                break
            count += 10
        self.assertEqual(count, 1)

    def test_nested_continue(self):
        # 'continue' inside the body skips the rest of the loop iteration.
        @contextmanager
        def noop():
            yield

        count = 0
        while count < 3:
            count += 1
            with nested(noop(), noop()):
                continue
            count += 10
        self.assertEqual(count, 3)

    def test_nested_return(self):
        # 'return' inside the body returns from the enclosing function.
        @contextmanager
        def tolerant():
            try:
                yield
            except:
                pass

        def run():
            with nested(tolerant(), tolerant()):
                return 1
            return 10

        self.assertEqual(run(), 1)
221
class ClosingTestCase(unittest.TestCase):
    """Tests for closing(), which calls close() on its argument at exit."""

    # XXX This needs more work

    def test_closing(self):
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with closing(x) as y:
            # closing() hands back the very object it wraps...
            self.assertIs(x, y)
            # ...and must not close it before the block is left.
            self.assertEqual(state, [])
        self.assertEqual(state, [1])

    def test_closing_error(self):
        # close() still runs exactly once when the body raises, and the
        # exception propagates.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(x) as y:
                self.assertIs(x, y)
                self.assertEqual(state, [])
                1 // 0
        self.assertEqual(state, [1])
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000249
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        # tempfile.mktemp() only returns a name, which another process could
        # claim before open(); it is deprecated. mkstemp() creates the file
        # atomically, so use it and just close the raw descriptor.
        import os
        fd, tfn = tempfile.mkstemp()
        try:
            os.close(fd)
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            # The file must also be closed when the body raises.
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 // 0
            self.assertTrue(f.closed)
        finally:
            test_support.unlink(tfn)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000269
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests that threading primitives are held inside, and only inside, a with block."""

    def boilerPlate(self, lock, locked):
        # Shared driver: locked() must be False before and after the with
        # block and True within it, both on a normal exit and when the
        # body raises.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 // 0
        self.assertFalse(locked())

    @staticmethod
    def _acquire_probe(sem):
        """Return a callable reporting whether *sem* is currently held.

        The probe attempts a non-blocking acquire. Failure means the
        semaphore is held; success means it was free, so the probe gives
        the slot back before reporting False.
        """
        def locked():
            if not sem.acquire(False):
                return True
            sem.release()
            return False
        return locked

    def testWithLock(self):
        mutex = threading.Lock()
        self.boilerPlate(mutex, mutex.locked)

    def testWithRLock(self):
        rmutex = threading.RLock()
        self.boilerPlate(rmutex, rmutex._is_owned)

    def testWithCondition(self):
        cond = threading.Condition()
        self.boilerPlate(cond, lambda: cond._is_owned())

    def testWithSemaphore(self):
        sem = threading.Semaphore()
        self.boilerPlate(sem, self._acquire_probe(sem))

    def testWithBoundedSemaphore(self):
        sem = threading.BoundedSemaphore()
        self.boilerPlate(sem, self._acquire_probe(sem))
317
# regrtest.py runs this module through its module-level test_main(); without
# it the test cases above would not actually run under regrtest.py.
def test_main():
    """Run every test case in this module via test_support.run_unittest.

    The run is wrapped in test_support.check_warnings, filtered on the
    DeprecationWarning whose message starts "With-statements now directly
    support multiple context managers" (presumably the one nested() emits
    for NestedTestCase -- confirm against contextlib).
    """
    nested_warning = ("With-statements now directly support "
                      "multiple context managers",
                      DeprecationWarning)
    with test_support.check_warnings(nested_warning):
        test_support.run_unittest(__name__)


if __name__ == "__main__":
    test_main()