blob: ae18085ca3881cc9a4ffc6d260954b10639b56ed [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00003
Thomas Wouters49fd7fa2006-04-21 10:40:58 +00004import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00005import os
6import decimal
R. David Murray378c0cf2010-02-24 01:46:21 +00007import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00008import tempfile
9import unittest
10import threading
11from contextlib import * # Tests __all__
Benjamin Petersonee8712c2008-05-20 21:35:26 +000012from test import support
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000013
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a normal exit.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A ``finally`` around the yield still runs when the with body
        # raises, and the exception propagates out of the with statement.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # whee() does not catch the exception thrown into it; __exit__ must
        # then return a false value instead of raising it itself.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that catches the thrown exception and then yields a
        # second time is misbehaving; __exit__ is expected to raise
        # RuntimeError.  Only the TypeError passed to __exit__ below is
        # thrown in, so that is the only exception the generator needs to
        # catch.
        @contextmanager
        def whoo():
            try:
                yield
            except TypeError:
                yield
        ctx = whoo()
        ctx.__enter__()
        self.assertRaises(
            RuntimeError, ctx.__exit__, TypeError, TypeError("foo"), None
        )

    def test_contextmanager_except(self):
        # An exception caught inside the generator is suppressed: the with
        # statement completes normally after the body raises.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager-wrapped function whose undecorated version
        # carries a custom attribute (foo='bar') and a docstring, so the
        # tests below can check both survive the decoration.
        def attribs(**kw):
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")

class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    @staticmethod
    def _make_closeable(state):
        # Return an object whose close() appends 1 to *state*, so a test can
        # observe exactly when (and how often) close() was called.
        class C:
            def close(self):
                state.append(1)
        return C()

    def test_closing(self):
        state = []
        x = self._make_closeable(state)
        self.assertEqual(state, [])
        with closing(x) as y:
            # closing() hands back the wrapped object itself and must not
            # call close() before the body finishes.
            self.assertIs(x, y)
            self.assertEqual(state, [])
        self.assertEqual(state, [1])

    def test_closing_error(self):
        state = []
        x = self._make_closeable(state)
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(x) as y:
                self.assertIs(x, y)
                self.assertEqual(state, [])
                1 / 0
        # close() must still have run even though the body raised.
        self.assertEqual(state, [1])

class FileContextTestCase(unittest.TestCase):
    """Tests that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        # tempfile.mktemp() is deprecated: another process can claim the
        # name before we open it.  mkstemp() creates the file atomically;
        # we only need the path, so close the returned descriptor at once.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            with open(tfn, "w", encoding="utf-8") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r", encoding="utf-8") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 / 0
            # The file is closed even though the with body raised.
            self.assertTrue(f.closed)
        finally:
            # Best-effort cleanup of the temporary file.
            try:
                os.remove(tfn)
            except OSError:
                pass

class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives work as context managers."""

    def boilerPlate(self, lock, locked):
        # Check that *lock* is held inside a with block and released on exit,
        # both on a normal exit and when the body raises.  *locked* is a
        # zero-argument callable reporting whether *lock* is currently held.
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_locked(sem):
        # Build a "locked" probe for a semaphore: report True when a
        # non-blocking acquire fails.  On success the permit is released
        # immediately, so probing leaves the semaphore's count unchanged.
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # NOTE: _is_owned() is a private threading helper; the test relies on
        # it to tell whether the current thread holds the RLock.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            # Private helper, as in testWithRLock.
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_locked(lock))

# This is needed to make the test actually run under regrtest.py!
def test_main():
    """Run this module's tests through test.support.run_unittest."""
    # run_unittest is given the module's name (a string), not the module
    # object itself.
    this_module_name = __name__
    support.run_unittest(this_module_name)


if __name__ == "__main__":
    test_main()