blob: 36035e3c00550938a1e4b3b61da1d8c03ad0016a [file] [log] [blame]
Tim Peters400cbc32006-02-28 18:44:41 +00001#!/usr/bin/env python
2
3"""Unit tests for the with statement specified in PEP 343."""
4
Thomas Wouters34aa7ba2006-02-28 19:02:24 +00005from __future__ import with_statement
6
Tim Peters400cbc32006-02-28 18:44:41 +00007__author__ = "Mike Bland"
8__email__ = "mbland at acm dot org"
9
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000010import sys
Tim Peters400cbc32006-02-28 18:44:41 +000011import unittest
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000012from collections import deque
13from contextlib import GeneratorContextManager, contextmanager
Tim Peters400cbc32006-02-28 18:44:41 +000014from test.test_support import run_unittest
15
16
class MockContextManager(GeneratorContextManager):
    """GeneratorContextManager that records which protocol hooks ran.

    context_called, enter_called and exit_called start out False and are
    set to True when the matching hook is invoked; exit_args starts as
    None and receives the (type, value, traceback) tuple handed to
    __exit__.  Every hook delegates to GeneratorContextManager.
    """

    def __init__(self, gen):
        GeneratorContextManager.__init__(self, gen)
        # Nothing has been called yet.
        self.context_called = self.enter_called = self.exit_called = False
        self.exit_args = None

    def __context__(self):
        # Record the call, then defer to the base class.
        self.context_called = True
        return GeneratorContextManager.__context__(self)

    def __enter__(self):
        # Record the call, then defer to the base class.
        self.enter_called = True
        return GeneratorContextManager.__enter__(self)

    def __exit__(self, type, value, traceback):
        # Remember exactly what we were given before delegating.
        details = (type, value, traceback)
        self.exit_called = True
        self.exit_args = details
        return GeneratorContextManager.__exit__(self, *details)
37
38
def mock_contextmanager(func):
    """Decorator: wrap whatever func returns in a MockContextManager.

    The returned helper forwards all positional and keyword arguments to
    func and passes the result to MockContextManager.
    """
    def helper(*args, **kwds):
        generator = func(*args, **kwds)
        return MockContextManager(generator)
    return helper
43
44
class MockResource(object):
    """Plain record with two boolean flags, both initially False.

    yielded -- set by the generator that hands this object out
    stopped -- set by that generator's cleanup code
    """

    def __init__(self):
        self.yielded = self.stopped = False
49
50
@mock_contextmanager
def mock_contextmanager_generator():
    """Yield a fresh MockResource, flagging it as yielded and later stopped.

    The finally clause marks the resource as stopped whether the generator
    is resumed normally or an exception is raised at the yield.
    """
    resource = MockResource()
    try:
        resource.yielded = True
        yield resource
    finally:
        resource.stopped = True
59
60
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000061class Nested(object):
62
Tim Peters400cbc32006-02-28 18:44:41 +000063 def __init__(self, *contexts):
Guido van Rossum1a5e21e2006-02-28 21:57:43 +000064 self.contexts = contexts
65 self.entered = None
66
67 def __context__(self):
68 return self
69
70 def __enter__(self):
71 if self.entered is not None:
72 raise RuntimeError("Context is not reentrant")
73 self.entered = deque()
74 vars = []
75 try:
76 for context in self.contexts:
77 mgr = context.__context__()
78 vars.append(mgr.__enter__())
79 self.entered.appendleft(mgr)
80 except:
81 self.__exit__(*sys.exc_info())
82 raise
83 return vars
84
85 def __exit__(self, *exc_info):
86 # Behave like nested with statements
87 # first in, last out
88 # New exceptions override old ones
89 ex = exc_info
90 for mgr in self.entered:
91 try:
92 mgr.__exit__(*ex)
93 except:
94 ex = sys.exc_info()
95 self.entered = None
96 if ex is not exc_info:
97 raise ex[0], ex[1], ex[2]
98
99
class MockNested(Nested):
    """Nested variant that records which protocol hooks were invoked.

    context_called, enter_called and exit_called start out False and flip
    to True when the matching hook runs; exit_args starts as None and
    receives the argument tuple passed to __exit__.
    """

    def __init__(self, *contexts):
        Nested.__init__(self, *contexts)
        # Nothing has been called yet.
        self.context_called = self.enter_called = self.exit_called = False
        self.exit_args = None

    def __context__(self):
        # Record the call, then defer to Nested.
        self.context_called = True
        return Nested.__context__(self)

    def __enter__(self):
        # Record the call, then defer to Nested.
        self.enter_called = True
        return Nested.__enter__(self)

    def __exit__(self, *exc_info):
        # Keep the exact tuple we received so tests can inspect it.
        self.exit_args = exc_info
        self.exit_called = True
        return Nested.__exit__(self, *exc_info)
Tim Peters400cbc32006-02-28 18:44:41 +0000120
121
122class FailureTestCase(unittest.TestCase):
123 def testNameError(self):
124 def fooNotDeclared():
125 with foo: pass
126 self.assertRaises(NameError, fooNotDeclared)
127
128 def testContextAttributeError(self):
129 class LacksContext(object):
130 def __enter__(self):
131 pass
132
133 def __exit__(self, type, value, traceback):
134 pass
135
136 def fooLacksContext():
137 foo = LacksContext()
138 with foo: pass
139 self.assertRaises(AttributeError, fooLacksContext)
140
141 def testEnterAttributeError(self):
142 class LacksEnter(object):
143 def __context__(self):
144 pass
145
146 def __exit__(self, type, value, traceback):
147 pass
148
149 def fooLacksEnter():
150 foo = LacksEnter()
151 with foo: pass
152 self.assertRaises(AttributeError, fooLacksEnter)
153
154 def testExitAttributeError(self):
155 class LacksExit(object):
156 def __context__(self):
157 pass
158
159 def __enter__(self):
160 pass
161
162 def fooLacksExit():
163 foo = LacksExit()
164 with foo: pass
165 self.assertRaises(AttributeError, fooLacksExit)
166
167 def assertRaisesSyntaxError(self, codestr):
168 def shouldRaiseSyntaxError(s):
169 compile(s, '', 'single')
170 self.assertRaises(SyntaxError, shouldRaiseSyntaxError, codestr)
171
172 def testAssignmentToNoneError(self):
173 self.assertRaisesSyntaxError('with mock as None:\n pass')
174 self.assertRaisesSyntaxError(
175 'with mock as (None):\n'
176 ' pass')
177
178 def testAssignmentToEmptyTupleError(self):
179 self.assertRaisesSyntaxError(
180 'with mock as ():\n'
181 ' pass')
182
183 def testAssignmentToTupleOnlyContainingNoneError(self):
184 self.assertRaisesSyntaxError('with mock as None,:\n pass')
185 self.assertRaisesSyntaxError(
186 'with mock as (None,):\n'
187 ' pass')
188
189 def testAssignmentToTupleContainingNoneError(self):
190 self.assertRaisesSyntaxError(
191 'with mock as (foo, None, bar):\n'
192 ' pass')
193
194 def testContextThrows(self):
195 class ContextThrows(object):
196 def __context__(self):
197 raise RuntimeError("Context threw")
198
199 def shouldThrow():
200 ct = ContextThrows()
201 self.foo = None
202 with ct as self.foo:
203 pass
204 self.assertRaises(RuntimeError, shouldThrow)
205 self.assertEqual(self.foo, None)
206
207 def testEnterThrows(self):
208 class EnterThrows(object):
209 def __context__(self):
210 return self
211
212 def __enter__(self):
213 raise RuntimeError("Context threw")
214
215 def __exit__(self, *args):
216 pass
217
218 def shouldThrow():
219 ct = EnterThrows()
220 self.foo = None
221 with ct as self.foo:
222 pass
223 self.assertRaises(RuntimeError, shouldThrow)
224 self.assertEqual(self.foo, None)
225
226 def testExitThrows(self):
227 class ExitThrows(object):
228 def __context__(self):
229 return self
230 def __enter__(self):
231 return
232 def __exit__(self, *args):
233 raise RuntimeError(42)
234 def shouldThrow():
235 with ExitThrows():
236 pass
237 self.assertRaises(RuntimeError, shouldThrow)
238
class ContextmanagerAssertionMixin(object):
    """Assertion helpers for the mock managers and resources in this file.

    Meant to be mixed into a unittest.TestCase, whose assertTrue,
    assertFalse and assertEqual methods it calls.
    """

    TEST_EXCEPTION = RuntimeError("test exception")

    def assertInWithManagerInvariants(self, mock_manager):
        # Inside the with body: set up, but not yet exited.
        for attr in ("context_called", "enter_called"):
            self.assertTrue(getattr(mock_manager, attr))
        self.assertFalse(mock_manager.exit_called)
        self.assertEqual(mock_manager.exit_args, None)

    def assertAfterWithManagerInvariants(self, mock_manager, exit_args):
        # After the with body: every hook ran, __exit__ got exit_args.
        for attr in ("context_called", "enter_called", "exit_called"):
            self.assertTrue(getattr(mock_manager, attr))
        self.assertEqual(mock_manager.exit_args, exit_args)

    def assertAfterWithManagerInvariantsNoError(self, mock_manager):
        no_error = (None, None, None)
        self.assertAfterWithManagerInvariants(mock_manager, no_error)

    def assertInWithGeneratorInvariants(self, mock_generator):
        # Resource was handed out and has not been cleaned up yet.
        self.assertTrue(mock_generator.yielded)
        self.assertFalse(mock_generator.stopped)

    def assertAfterWithGeneratorInvariantsNoError(self, mock_generator):
        for attr in ("yielded", "stopped"):
            self.assertTrue(getattr(mock_generator, attr))

    def raiseTestException(self):
        raise self.TEST_EXCEPTION

    def assertAfterWithManagerInvariantsWithError(self, mock_manager):
        # Every hook ran and __exit__ was handed the shared test exception.
        for attr in ("context_called", "enter_called", "exit_called"):
            self.assertTrue(getattr(mock_manager, attr))
        self.assertEqual(mock_manager.exit_args[0], RuntimeError)
        self.assertEqual(mock_manager.exit_args[1], self.TEST_EXCEPTION)

    def assertAfterWithGeneratorInvariantsWithError(self, mock_generator):
        for attr in ("yielded", "stopped"):
            self.assertTrue(getattr(mock_generator, attr))
279
280
class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """with statements whose body finishes without raising."""

    def testInlineGeneratorSyntax(self):
        # Manager created inline, nothing bound.
        with mock_contextmanager_generator():
            pass

    def testUnboundGenerator(self):
        manager = mock_contextmanager_generator()
        with manager:
            pass
        self.assertAfterWithManagerInvariantsNoError(manager)

    def testInlineGeneratorBoundSyntax(self):
        with mock_contextmanager_generator() as resource:
            self.assertInWithGeneratorInvariants(resource)
        # FIXME: In the future, we'll try to keep the bound names from leaking
        self.assertAfterWithGeneratorInvariantsNoError(resource)

    def testInlineGeneratorBoundToExistingVariable(self):
        # The target name already exists before the with statement.
        resource = None
        with mock_contextmanager_generator() as resource:
            self.assertInWithGeneratorInvariants(resource)
        self.assertAfterWithGeneratorInvariantsNoError(resource)

    def testInlineGeneratorBoundToDottedVariable(self):
        # The target is an attribute reference rather than a plain name.
        with mock_contextmanager_generator() as self.foo:
            self.assertInWithGeneratorInvariants(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.foo)

    def testBoundGenerator(self):
        manager = mock_contextmanager_generator()
        with manager as resource:
            self.assertInWithGeneratorInvariants(resource)
            self.assertInWithManagerInvariants(manager)
        self.assertAfterWithGeneratorInvariantsNoError(resource)
        self.assertAfterWithManagerInvariantsNoError(manager)

    def testNestedSingleStatements(self):
        outer = mock_contextmanager_generator()
        with outer as outer_resource:
            inner = mock_contextmanager_generator()
            with inner as inner_resource:
                self.assertInWithManagerInvariants(outer)
                self.assertInWithManagerInvariants(inner)
                self.assertInWithGeneratorInvariants(outer_resource)
                self.assertInWithGeneratorInvariants(inner_resource)
            # Inner block is finished; the outer one is still active.
            self.assertAfterWithManagerInvariantsNoError(inner)
            self.assertAfterWithGeneratorInvariantsNoError(inner_resource)
            self.assertInWithManagerInvariants(outer)
            self.assertInWithGeneratorInvariants(outer_resource)
        self.assertAfterWithManagerInvariantsNoError(outer)
        self.assertAfterWithGeneratorInvariantsNoError(outer_resource)
332
333
class NestedNonexceptionalTestCase(unittest.TestCase,
                                   ContextmanagerAssertionMixin):
    """Nested/MockNested used in with statements whose body does not raise."""

    def testSingleArgInlineGeneratorSyntax(self):
        with Nested(mock_contextmanager_generator()):
            pass

    # This method used to be defined twice in this class with identical
    # bodies; the later definition silently replaced the earlier one, so only
    # a single copy is kept.
    def testSingleArgUnbound(self):
        mock_contextmanager = mock_contextmanager_generator()
        mock_nested = MockNested(mock_contextmanager)
        with mock_nested:
            self.assertInWithManagerInvariants(mock_contextmanager)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(mock_contextmanager)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testSingleArgBoundToNonTuple(self):
        m = mock_contextmanager_generator()
        # This will bind all the arguments to nested() into a single list
        # assigned to foo.
        with Nested(m) as foo:
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToSingleElementParenthesizedList(self):
        m = mock_contextmanager_generator()
        # This will bind all the arguments to nested() into a single list
        # assigned to foo.
        with Nested(m) as (foo):
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToMultipleElementTupleError(self):
        # Nested.__enter__ returns a one-element list here, which cannot be
        # unpacked into two targets.
        def shouldThrowValueError():
            with Nested(mock_contextmanager_generator()) as (foo, bar):
                pass
        self.assertRaises(ValueError, shouldThrowValueError)

    def testMultipleArgUnbound(self):
        m = mock_contextmanager_generator()
        n = mock_contextmanager_generator()
        o = mock_contextmanager_generator()
        mock_nested = MockNested(m, n, o)
        with mock_nested:
            self.assertInWithManagerInvariants(m)
            self.assertInWithManagerInvariants(n)
            self.assertInWithManagerInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(m)
        self.assertAfterWithManagerInvariantsNoError(n)
        self.assertAfterWithManagerInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testMultipleArgBound(self):
        mock_nested = MockNested(mock_contextmanager_generator(),
            mock_contextmanager_generator(), mock_contextmanager_generator())
        # Here m, n and o are the resources produced by each __enter__,
        # not the managers themselves.
        with mock_nested as (m, n, o):
            self.assertInWithGeneratorInvariants(m)
            self.assertInWithGeneratorInvariants(n)
            self.assertInWithGeneratorInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithGeneratorInvariantsNoError(m)
        self.assertAfterWithGeneratorInvariantsNoError(n)
        self.assertAfterWithGeneratorInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)
407
408
class ExceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """with statements whose body raises the mixin's test exception."""

    def testSingleResource(self):
        manager = mock_contextmanager_generator()

        def shouldThrow():
            with manager as self.resource:
                self.assertInWithManagerInvariants(manager)
                self.assertInWithGeneratorInvariants(self.resource)
                self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(manager)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource)

    def testNestedSingleStatements(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()

        def shouldThrow():
            with outer as self.foo:
                with inner as self.bar:
                    self.assertInWithManagerInvariants(outer)
                    self.assertInWithManagerInvariants(inner)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                    self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithManagerInvariantsWithError(inner)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsWithError(self.bar)

    def testMultipleResourcesInSingleStatement(self):
        first = mock_contextmanager_generator()
        second = mock_contextmanager_generator()
        combined = MockNested(first, second)

        def shouldThrow():
            with combined as (self.resource_a, self.resource_b):
                self.assertInWithManagerInvariants(first)
                self.assertInWithManagerInvariants(second)
                self.assertInWithManagerInvariants(combined)
                self.assertInWithGeneratorInvariants(self.resource_a)
                self.assertInWithGeneratorInvariants(self.resource_b)
                self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(first)
        self.assertAfterWithManagerInvariantsWithError(second)
        self.assertAfterWithManagerInvariantsWithError(combined)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_b)

    def testNestedExceptionBeforeInnerStatement(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()
        self.bar = None

        def shouldThrow():
            with outer as self.foo:
                self.assertInWithManagerInvariants(outer)
                self.assertInWithGeneratorInvariants(self.foo)
                self.raiseTestException()
                with inner as self.bar:
                    pass

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)

        # The inner statement stuff should never have been touched
        self.assertEqual(self.bar, None)
        for attr in ("context_called", "enter_called", "exit_called"):
            self.assertFalse(getattr(inner, attr))
        self.assertEqual(inner.exit_args, None)

    def testNestedExceptionAfterInnerStatement(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()

        def shouldThrow():
            with outer as self.foo:
                with inner as self.bar:
                    self.assertInWithManagerInvariants(outer)
                    self.assertInWithManagerInvariants(inner)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                # Raised after the inner block has already exited cleanly.
                self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithManagerInvariantsNoError(inner)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.bar)
495
496
class NonLocalFlowControlTestCase(unittest.TestCase):
    """break, continue, return, yield and raise issued inside a with body."""

    def testWithBreak(self):
        total = 0
        while True:
            total += 1
            with mock_contextmanager_generator():
                total += 10
                break
            total += 100 # Not reached
        self.assertEqual(total, 11)

    def testWithContinue(self):
        total = 0
        while True:
            total += 1
            if total > 2:
                break
            with mock_contextmanager_generator():
                total += 10
                continue
            total += 100 # Not reached
        self.assertEqual(total, 12)

    def testWithReturn(self):
        def run():
            total = 0
            while True:
                total += 1
                with mock_contextmanager_generator():
                    total += 10
                    return total
                total += 100 # Not reached
        self.assertEqual(run(), 11)

    def testWithYield(self):
        def gen():
            with mock_contextmanager_generator():
                yield 12
                yield 13
        produced = list(gen())
        self.assertEqual(produced, [12, 13])

    def testWithRaise(self):
        total = 0
        try:
            total += 1
            with mock_contextmanager_generator():
                total += 10
                raise RuntimeError
            total += 100 # Not reached
        except RuntimeError:
            self.assertEqual(total, 11)
        else:
            self.fail("Didn't raise RuntimeError")
552
553
554class AssignmentTargetTestCase(unittest.TestCase):
555
556 def testSingleComplexTarget(self):
557 targets = {1: [0, 1, 2]}
558 with mock_contextmanager_generator() as targets[1][0]:
559 self.assertEqual(targets.keys(), [1])
560 self.assertEqual(targets[1][0].__class__, MockResource)
561 with mock_contextmanager_generator() as targets.values()[0][1]:
562 self.assertEqual(targets.keys(), [1])
563 self.assertEqual(targets[1][1].__class__, MockResource)
564 with mock_contextmanager_generator() as targets[2]:
565 keys = targets.keys()
566 keys.sort()
567 self.assertEqual(keys, [1, 2])
568 class C: pass
569 blah = C()
570 with mock_contextmanager_generator() as blah.foo:
571 self.assertEqual(hasattr(blah, "foo"), True)
572
573 def testMultipleComplexTargets(self):
574 class C:
575 def __context__(self): return self
576 def __enter__(self): return 1, 2, 3
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000577 def __exit__(self, t, v, tb):
578 if t is not None:
579 raise t, v, tb
Tim Peters400cbc32006-02-28 18:44:41 +0000580 targets = {1: [0, 1, 2]}
581 with C() as (targets[1][0], targets[1][1], targets[1][2]):
582 self.assertEqual(targets, {1: [1, 2, 3]})
583 with C() as (targets.values()[0][2], targets.values()[0][1], targets.values()[0][0]):
584 self.assertEqual(targets, {1: [3, 2, 1]})
585 with C() as (targets[1], targets[2], targets[3]):
586 self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
587 class B: pass
588 blah = B()
589 with C() as (blah.one, blah.two, blah.three):
590 self.assertEqual(blah.one, 1)
591 self.assertEqual(blah.two, 2)
592 self.assertEqual(blah.three, 3)
593
594
Guido van Rossum1a5e21e2006-02-28 21:57:43 +0000595class ExitSwallowsExceptionTestCase(unittest.TestCase):
596
597 def testExitSwallowsException(self):
598 class AfricanOrEuropean:
599 def __context__(self): return self
600 def __enter__(self): pass
601 def __exit__(self, t, v, tb): pass
602 try:
603 with AfricanOrEuropean():
604 1/0
605 except ZeroDivisionError:
606 self.fail("ZeroDivisionError should have been swallowed")
607
608
def test_main():
    """Hand every test case class defined in this file to run_unittest."""
    cases = (
        FailureTestCase,
        NonexceptionalTestCase,
        NestedNonexceptionalTestCase,
        ExceptionalTestCase,
        NonLocalFlowControlTestCase,
        AssignmentTargetTestCase,
        ExitSwallowsExceptionTestCase,
    )
    run_unittest(*cases)
Tim Peters400cbc32006-02-28 18:44:41 +0000615
616
# Run the full suite when this file is executed directly as a script.
if __name__ == '__main__':
    test_main()