blob: 389e7d658c59d97ed3f7021d6d0668328238ac87 [file] [log] [blame]
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00001"""Unit tests for contextlib.py, and other context managers."""
2
R. David Murray378c0cf2010-02-24 01:46:21 +00003import sys
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00004import tempfile
5import unittest
Guido van Rossum1a5e21e2006-02-28 21:57:43 +00006from contextlib import * # Tests __all__
Benjamin Petersonee8712c2008-05-20 21:35:26 +00007from test import support
Victor Stinner45df8202010-04-28 22:31:17 +00008try:
9 import threading
10except ImportError:
11 threading = None
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000012
Florent Xicluna41fe6152010-04-02 18:52:12 +000013
class ContextManagerTestCase(unittest.TestCase):
    """Behavioural tests for the ``contextlib.contextmanager`` decorator."""

    def test_contextmanager_plain(self):
        # Code before the yield runs on entry, the yielded value is bound by
        # ``as``, and code after the yield runs on a normal exit.
        events = []

        @contextmanager
        def tracked():
            events.append(1)
            yield 42
            events.append(999)

        with tracked() as value:
            self.assertEqual(events, [1])
            self.assertEqual(value, 42)
            events.append(value)
        self.assertEqual(events, [1, 42, 999])

    def test_contextmanager_finally(self):
        # A ``finally`` clause around the yield still runs when the with-body
        # raises, and the exception escapes the with statement.
        events = []

        @contextmanager
        def tracked():
            events.append(1)
            try:
                yield 42
            finally:
                events.append(999)

        def run_body():
            with tracked() as value:
                self.assertEqual(events, [1])
                self.assertEqual(value, 42)
                events.append(value)
                raise ZeroDivisionError()

        self.assertRaises(ZeroDivisionError, run_body)
        self.assertEqual(events, [1, 42, 999])

    def test_contextmanager_no_reraise(self):
        # Driving the protocol by hand: an exception the generator does not
        # handle must make __exit__ return a false value rather than raise.
        @contextmanager
        def whee():
            yield

        manager = whee()
        manager.__enter__()
        suppressed = manager.__exit__(TypeError, TypeError("foo"), None)
        self.assertFalse(suppressed)

    def test_contextmanager_trap_yield_after_throw(self):
        # A generator that swallows the thrown exception and then yields a
        # second time must cause __exit__ to raise RuntimeError.
        @contextmanager
        def whoo():
            try:
                yield
            except:  # deliberately bare: every thrown exception is trapped
                yield

        manager = whoo()
        manager.__enter__()
        with self.assertRaises(RuntimeError):
            manager.__exit__(TypeError, TypeError("foo"), None)

    def test_contextmanager_except(self):
        # An exception raised in the with-body is thrown into the generator.
        # Catching it there suppresses it, so the with statement completes.
        events = []

        @contextmanager
        def tracked():
            events.append(1)
            try:
                yield 42
            except ZeroDivisionError as exc:
                events.append(exc.args[0])
                self.assertEqual(events, [1, 42, 999])

        with tracked() as value:
            self.assertEqual(events, [1])
            self.assertEqual(value, 42)
            events.append(value)
            raise ZeroDivisionError(999)
        self.assertEqual(events, [1, 42, 999])

    def _create_contextmanager_attribs(self):
        # Build a @contextmanager-wrapped function that carries a custom
        # attribute (foo='bar') and a docstring. The tests below check that
        # the wrapper preserves them.
        def set_attrs(**attrs):
            def apply(func):
                for attr_name, attr_value in attrs.items():
                    setattr(func, attr_name, attr_value)
                return func
            return apply

        @contextmanager
        @set_attrs(foo='bar')
        def baz(spam):
            """Whee!"""

        return baz

    def test_contextmanager_attribs(self):
        wrapped = self._create_contextmanager_attribs()
        self.assertEqual(wrapped.__name__, 'baz')
        self.assertEqual(wrapped.foo, 'bar')

    @unittest.skipIf(sys.flags.optimize >= 2,
                     "Docstrings are omitted with -O2 and above")
    def test_contextmanager_doc_attrib(self):
        wrapped = self._create_contextmanager_attribs()
        self.assertEqual(wrapped.__doc__, "Whee!")
108
class ClosingTestCase(unittest.TestCase):
    """Tests for ``contextlib.closing``."""

    # XXX: coverage here is still thin; this needs more work.

    @staticmethod
    def _make_closeable(log):
        # Return a minimal object whose close() appends 1 to *log*, so a test
        # can observe whether (and how often) closing() called it.
        class C:
            def close(self):
                log.append(1)
        return C()

    def test_closing(self):
        log = []
        resource = self._make_closeable(log)
        self.assertEqual(log, [])
        with closing(resource) as bound:
            # closing() hands back the very object it wraps.
            self.assertEqual(resource, bound)
        self.assertEqual(log, [1])

    def test_closing_error(self):
        # close() must still run when the with-body raises.
        log = []
        resource = self._make_closeable(log)
        self.assertEqual(log, [])
        with self.assertRaises(ZeroDivisionError):
            with closing(resource) as bound:
                self.assertEqual(resource, bound)
                1 / 0
        self.assertEqual(log, [1])
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000136
class FileContextTestCase(unittest.TestCase):
    """Tests that objects returned by open() work as context managers."""

    def testWithOpen(self):
        # tempfile.mktemp() is deprecated and racy: another process could
        # create the returned name before we open it. NamedTemporaryFile
        # creates a unique file, and delete=False keeps it on disk after
        # closing so the test can reopen it by name.
        with tempfile.NamedTemporaryFile(delete=False) as placeholder:
            tfn = placeholder.name
        try:
            with open(tfn, "w") as f:
                self.assertFalse(f.closed)
                f.write("Booh\n")
            # A normal exit from the with-block must close the file.
            self.assertTrue(f.closed)
            with self.assertRaises(ZeroDivisionError):
                with open(tfn, "r") as f:
                    self.assertFalse(f.closed)
                    self.assertEqual(f.read(), "Booh\n")
                    1 / 0
            # An exit caused by an exception must close it too.
            self.assertTrue(f.closed)
        finally:
            support.unlink(tfn)
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000156
@unittest.skipUnless(threading, 'Threading required for this test.')
class LockContextTestCase(unittest.TestCase):
    """Tests that threading synchronisation primitives support ``with``."""

    def boilerPlate(self, lock, locked):
        """Check that *lock* is held inside a with-block and released after.

        *locked* is a zero-argument callable that reports whether *lock* is
        currently held. Release is checked both after a normal exit and
        after an exit caused by an exception.
        """
        self.assertFalse(locked())
        with lock:
            self.assertTrue(locked())
        self.assertFalse(locked())
        with self.assertRaises(ZeroDivisionError):
            with lock:
                self.assertTrue(locked())
                1 / 0
        self.assertFalse(locked())

    @staticmethod
    def _semaphore_probe(sem):
        """Return a callable that reports whether *sem* cannot be acquired.

        The probe does a non-blocking acquire. If that succeeds it releases
        immediately, so the semaphore's count is unchanged afterwards.
        Shared by the Semaphore and BoundedSemaphore tests, which previously
        each carried an identical copy of this logic.
        """
        def locked():
            if sem.acquire(False):
                sem.release()
                return False
            return True
        return locked

    def testWithLock(self):
        lock = threading.Lock()
        self.boilerPlate(lock, lock.locked)

    def testWithRLock(self):
        lock = threading.RLock()
        # _is_owned is a private (non-public) ownership check.
        self.boilerPlate(lock, lock._is_owned)

    def testWithCondition(self):
        lock = threading.Condition()
        # Same private ownership check, exposed on Condition.
        self.boilerPlate(lock, lock._is_owned)

    def testWithSemaphore(self):
        lock = threading.Semaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))

    def testWithBoundedSemaphore(self):
        lock = threading.BoundedSemaphore()
        self.boilerPlate(lock, self._semaphore_probe(lock))
204
def test_main():
    """Hand this module's tests to ``support.run_unittest``.

    A module-level test_main() is needed for these tests to actually run
    under regrtest.py.
    """
    module_name = __name__
    support.run_unittest(module_name)


if __name__ == "__main__":
    test_main()