blob: 1d3bd90c58cb4307e0674309c1c38eecea22171a [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00003
4import os
R. David Murray378c0cf2010-02-24 01:46:21 +00005import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006import tempfile
7import unittest
8import threading
9from contextlib import * # Tests __all__
Benjamin Petersonee8712c2008-05-20 21:35:26 +000010from test import support
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000011
class ContextManagerTestCase(unittest.TestCase):
    """Tests for generator-based context managers built with @contextmanager."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # 'as', and code after the yield runs when the block exits normally.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            yield 42
            state.append(999)
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_finally(self):
        # An exception raised in the body propagates out of the with
        # statement, but the generator's finally clause still runs first.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            finally:
                state.append(999)
        with self.assertRaises(ZeroDivisionError):
            with woohoo() as x:
                self.assertEqual(state, [1])
                self.assertEqual(x, 42)
                state.append(x)
                raise ZeroDivisionError()
        self.assertEqual(state, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # The generator does not handle the exception thrown into it; the
        # test asserts that __exit__ returns a false value rather than
        # raising.
        @contextmanager
        def whee():
            yield
        ctx = whee()
        ctx.__enter__()
        # Calling __exit__ should not result in an exception
        self.assertFalse(ctx.__exit__(TypeError, TypeError("foo"), None))

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that yields a second time after an exception has been
        # thrown into it is misbehaving; __exit__ must report RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except TypeError:
                # Deliberately yield again instead of finishing.
                yield
        ctx = whoo()
        ctx.__enter__()
        with self.assertRaises(RuntimeError):
            ctx.__exit__(TypeError, TypeError("foo"), None)

    def test_contextmanager_except(self):
        # A generator that catches the exception suppresses it: the with
        # statement completes normally and execution continues after it.
        state = []
        @contextmanager
        def woohoo():
            state.append(1)
            try:
                yield 42
            except ZeroDivisionError as e:
                state.append(e.args[0])
                self.assertEqual(state, [1, 42, 999])
        with woohoo() as x:
            self.assertEqual(state, [1])
            self.assertEqual(x, 42)
            state.append(x)
            raise ZeroDivisionError(999)
        self.assertEqual(state, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        """Return a @contextmanager-decorated function carrying custom metadata.

        The wrapped function has the docstring "Whee!" and an extra attribute
        foo='bar'; the tests check that @contextmanager carries these over.
        """
        def attribs(**kw):
            # Decorator factory: set each keyword as an attribute on func.
            def decorate(func):
                for k, v in kw.items():
                    setattr(func, k, v)
                return func
            return decorate
        @contextmanager
        @attribs(foo='bar')
        def baz(spam):
            """Whee!"""
        return baz

    def test_contextmanager_attribs(self):
        # The decorated function keeps its name and custom attributes.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__name__, 'baz')
        self.assertEqual(baz.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        # The decorated function keeps the original docstring.
        baz = self._create_contextmanager_attribs()
        self.assertEqual(baz.__doc__, "Whee!")
110
class ClosingTestCase(unittest.TestCase):
    """Tests for contextlib.closing()."""

    # XXX This needs more work

    def test_closing(self):
        # closing() binds the wrapped object itself and calls its close()
        # exactly once, and only after the with block has finished.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with closing(x) as y:
            self.assertIs(x, y)
            # close() must not have been called yet.
            self.assertEqual(state, [])
        self.assertEqual(state, [1])

    def test_closing_error(self):
        # close() still runs when the body raises, and the exception is
        # not suppressed.
        state = []
        class C:
            def close(self):
                state.append(1)
        x = C()
        self.assertEqual(state, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(x) as y:
                self.assertIs(x, y)
                1/0
        self.assertEqual(state, [1])
141
class FileContextTestCase(unittest.TestCase):
    """Tests that file objects returned by open() work as context managers."""

    def testWithOpen(self):
        # mkstemp() creates the file atomically, avoiding the name-reuse race
        # that makes the deprecated tempfile.mktemp() unsafe.  Only the path
        # is needed here, so the OS-level descriptor is closed right away.
        fd, tfn = tempfile.mkstemp()
        os.close(fd)
        try:
            f = None
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            self.assertTrue(f.closed)
            f = None
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1/0
            # The file must be closed even though the block raised.
            self.assertTrue(f.closed)
        finally:
            # Best-effort cleanup; a missing file is not a test failure.
            try:
                os.remove(tfn)
            except OSError:
                pass
167
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronization primitives support 'with'."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside a with block and released after.

        *locked* is a zero-argument callable reporting whether *lock* is
        currently held.  Release is checked both after a normal exit and
        after an exit caused by an exception.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1/0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_locked(sem):
        """Return True if *sem* cannot currently be acquired without blocking.

        Probes with a non-blocking acquire and releases immediately on
        success, so the semaphore's count is unchanged on return.
        """
        if sem.acquire(False):
            sem.release()
            return False
        return True

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # _is_owned is a private helper; RLock has no public "locked" query.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        def locked():
            return lock._is_owned()
        self.boilerPlate(lock, locked)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, lambda: self._semaphore_locked(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, lambda: self._semaphore_locked(lock))
217
# regrtest.py invokes test_main(); without this hook the test cases in this
# module would not actually run under regrtest.py.
def test_main():
    """Run this module's tests by passing its name to support.run_unittest."""
    this_module = __name__
    support.run_unittest(this_module)


if __name__ == "__main__":
    test_main()