blob: ed072c9573d7171cda3fae09733f803fc4f38576 [file] [log] [blame]
Tim Peters400cbc32006-02-28 18:44:41 +00001#!/usr/bin/env python
2
3"""Unit tests for the with statement specified in PEP 343."""
4
Thomas Wouters34aa7ba2006-02-28 19:02:24 +00005from __future__ import with_statement
6
Tim Peters400cbc32006-02-28 18:44:41 +00007__author__ = "Mike Bland"
8__email__ = "mbland at acm dot org"
9
10import unittest
11from test.contextmanager import GeneratorContextManager
12from test.nested import nested
13from test.test_support import run_unittest
14
15
class MockContextManager(GeneratorContextManager):
    """GeneratorContextManager that records which protocol hooks ran.

    Every hook sets its flag first and then delegates to the base class,
    so tests can tell how far the with-statement protocol progressed.
    """

    def __init__(self, gen):
        GeneratorContextManager.__init__(self, gen)
        # Nothing has been invoked yet.
        self.context_called = self.enter_called = self.exit_called = False
        # Filled with the (type, value, traceback) triple passed to __exit__.
        self.exit_args = None

    def __context__(self):
        self.context_called = True
        base_context = GeneratorContextManager.__context__
        return base_context(self)

    def __enter__(self):
        self.enter_called = True
        base_enter = GeneratorContextManager.__enter__
        return base_enter(self)

    def __exit__(self, type, value, traceback):
        exc_info = (type, value, traceback)
        self.exit_called = True
        self.exit_args = exc_info
        return GeneratorContextManager.__exit__(self, *exc_info)
36
37
def mock_contextmanager(func):
    """Wrap generator function *func* so each call returns a MockContextManager.

    The returned callable forwards all positional and keyword arguments to
    *func* and wraps the resulting generator in a MockContextManager.
    """
    # Local import keeps this decorator self-contained; functools.wraps copies
    # __name__, __doc__ and __module__ from func so decorated generators keep
    # their identity in tracebacks and introspection.
    import functools

    @functools.wraps(func)
    def helper(*args, **kwds):
        return MockContextManager(func(*args, **kwds))
    return helper
42
43
class MockResource(object):
    """Plain state holder produced by mock_contextmanager_generator.

    ``yielded`` is set once the generator has handed the object out, and
    ``stopped`` once the generator's ``finally`` clause has run.
    """

    def __init__(self):
        self.yielded = self.stopped = False
48
49
@mock_contextmanager
def mock_contextmanager_generator():
    """Yield a fresh MockResource, marking it yielded and later stopped.

    The ``finally`` clause sets ``stopped`` however the generator is
    finished, whether it runs to completion or is terminated by an exception.
    """
    resource = MockResource()
    try:
        resource.yielded = True
        yield resource
    finally:
        resource.stopped = True
58
59
class MockNested(nested):
    """``nested`` subclass that records which protocol hooks ran.

    Each hook sets its flag before delegating to ``nested``.
    """

    def __init__(self, *contexts):
        nested.__init__(self, *contexts)
        # Nothing has been invoked yet.
        self.context_called = self.enter_called = self.exit_called = False
        # Filled with the exception triple passed to __exit__.
        self.exit_args = None

    def __context__(self):
        self.context_called = True
        return nested.__context__(self)

    def __enter__(self):
        self.enter_called = True
        return nested.__enter__(self)

    def __exit__(self, *exc_info):
        self.exit_called, self.exit_args = True, exc_info
        return nested.__exit__(self, *exc_info)
80
81
82class FailureTestCase(unittest.TestCase):
83 def testNameError(self):
84 def fooNotDeclared():
85 with foo: pass
86 self.assertRaises(NameError, fooNotDeclared)
87
88 def testContextAttributeError(self):
89 class LacksContext(object):
90 def __enter__(self):
91 pass
92
93 def __exit__(self, type, value, traceback):
94 pass
95
96 def fooLacksContext():
97 foo = LacksContext()
98 with foo: pass
99 self.assertRaises(AttributeError, fooLacksContext)
100
101 def testEnterAttributeError(self):
102 class LacksEnter(object):
103 def __context__(self):
104 pass
105
106 def __exit__(self, type, value, traceback):
107 pass
108
109 def fooLacksEnter():
110 foo = LacksEnter()
111 with foo: pass
112 self.assertRaises(AttributeError, fooLacksEnter)
113
114 def testExitAttributeError(self):
115 class LacksExit(object):
116 def __context__(self):
117 pass
118
119 def __enter__(self):
120 pass
121
122 def fooLacksExit():
123 foo = LacksExit()
124 with foo: pass
125 self.assertRaises(AttributeError, fooLacksExit)
126
127 def assertRaisesSyntaxError(self, codestr):
128 def shouldRaiseSyntaxError(s):
129 compile(s, '', 'single')
130 self.assertRaises(SyntaxError, shouldRaiseSyntaxError, codestr)
131
132 def testAssignmentToNoneError(self):
133 self.assertRaisesSyntaxError('with mock as None:\n pass')
134 self.assertRaisesSyntaxError(
135 'with mock as (None):\n'
136 ' pass')
137
138 def testAssignmentToEmptyTupleError(self):
139 self.assertRaisesSyntaxError(
140 'with mock as ():\n'
141 ' pass')
142
143 def testAssignmentToTupleOnlyContainingNoneError(self):
144 self.assertRaisesSyntaxError('with mock as None,:\n pass')
145 self.assertRaisesSyntaxError(
146 'with mock as (None,):\n'
147 ' pass')
148
149 def testAssignmentToTupleContainingNoneError(self):
150 self.assertRaisesSyntaxError(
151 'with mock as (foo, None, bar):\n'
152 ' pass')
153
154 def testContextThrows(self):
155 class ContextThrows(object):
156 def __context__(self):
157 raise RuntimeError("Context threw")
158
159 def shouldThrow():
160 ct = ContextThrows()
161 self.foo = None
162 with ct as self.foo:
163 pass
164 self.assertRaises(RuntimeError, shouldThrow)
165 self.assertEqual(self.foo, None)
166
167 def testEnterThrows(self):
168 class EnterThrows(object):
169 def __context__(self):
170 return self
171
172 def __enter__(self):
173 raise RuntimeError("Context threw")
174
175 def __exit__(self, *args):
176 pass
177
178 def shouldThrow():
179 ct = EnterThrows()
180 self.foo = None
181 with ct as self.foo:
182 pass
183 self.assertRaises(RuntimeError, shouldThrow)
184 self.assertEqual(self.foo, None)
185
186 def testExitThrows(self):
187 class ExitThrows(object):
188 def __context__(self):
189 return self
190 def __enter__(self):
191 return
192 def __exit__(self, *args):
193 raise RuntimeError(42)
194 def shouldThrow():
195 with ExitThrows():
196 pass
197 self.assertRaises(RuntimeError, shouldThrow)
198
class ContextmanagerAssertionMixin(object):
    """Shared checks on the flags recorded by the mock managers/resources.

    Meant to be mixed into unittest.TestCase subclasses, which supply
    assertTrue, assertFalse and assertEqual.
    """

    # One shared instance, so the with-error checks can compare the value
    # passed to __exit__ against exactly what raiseTestException raised.
    TEST_EXCEPTION = RuntimeError("test exception")

    def _assertHooksRan(self, mock_manager, exited):
        # __context__ and __enter__ must have run; __exit__ ran iff *exited*.
        self.assertTrue(mock_manager.context_called)
        self.assertTrue(mock_manager.enter_called)
        if exited:
            self.assertTrue(mock_manager.exit_called)
        else:
            self.assertFalse(mock_manager.exit_called)

    def assertInWithManagerInvariants(self, mock_manager):
        self._assertHooksRan(mock_manager, exited=False)
        self.assertEqual(mock_manager.exit_args, None)

    def assertAfterWithManagerInvariants(self, mock_manager, exit_args):
        self._assertHooksRan(mock_manager, exited=True)
        self.assertEqual(mock_manager.exit_args, exit_args)

    def assertAfterWithManagerInvariantsNoError(self, mock_manager):
        self.assertAfterWithManagerInvariants(mock_manager, (None,) * 3)

    def assertInWithGeneratorInvariants(self, mock_generator):
        self.assertTrue(mock_generator.yielded)
        self.assertFalse(mock_generator.stopped)

    def assertAfterWithGeneratorInvariantsNoError(self, mock_generator):
        self.assertTrue(mock_generator.yielded)
        self.assertTrue(mock_generator.stopped)

    def raiseTestException(self):
        raise self.TEST_EXCEPTION

    def assertAfterWithManagerInvariantsWithError(self, mock_manager):
        self._assertHooksRan(mock_manager, exited=True)
        recorded = mock_manager.exit_args
        self.assertEqual(recorded[0], RuntimeError)
        self.assertEqual(recorded[1], self.TEST_EXCEPTION)

    def assertAfterWithGeneratorInvariantsWithError(self, mock_generator):
        # The generator is finalized the same way whether or not it failed.
        self.assertAfterWithGeneratorInvariantsNoError(mock_generator)
239
240
class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """Single with statements whose bodies complete without raising."""

    def testInlineGeneratorSyntax(self):
        # Entering and leaving an unbound inline manager must not raise.
        with mock_contextmanager_generator():
            pass

    def testUnboundGenerator(self):
        manager = mock_contextmanager_generator()
        with manager:
            pass
        self.assertAfterWithManagerInvariantsNoError(manager)

    def testInlineGeneratorBoundSyntax(self):
        with mock_contextmanager_generator() as resource:
            self.assertInWithGeneratorInvariants(resource)
        # FIXME: In the future, we'll try to keep the bound names from leaking
        self.assertAfterWithGeneratorInvariantsNoError(resource)

    def testInlineGeneratorBoundToExistingVariable(self):
        # The target name already has a binding before the with statement.
        resource = None
        with mock_contextmanager_generator() as resource:
            self.assertInWithGeneratorInvariants(resource)
        self.assertAfterWithGeneratorInvariantsNoError(resource)

    def testInlineGeneratorBoundToDottedVariable(self):
        # The target may be an attribute reference.
        with mock_contextmanager_generator() as self.foo:
            self.assertInWithGeneratorInvariants(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.foo)

    def testBoundGenerator(self):
        manager = mock_contextmanager_generator()
        with manager as resource:
            self.assertInWithGeneratorInvariants(resource)
            self.assertInWithManagerInvariants(manager)
        self.assertAfterWithGeneratorInvariantsNoError(resource)
        self.assertAfterWithManagerInvariantsNoError(manager)

    def testNestedSingleStatements(self):
        outer = mock_contextmanager_generator()
        with outer as outer_resource:
            inner = mock_contextmanager_generator()
            with inner as inner_resource:
                self.assertInWithManagerInvariants(outer)
                self.assertInWithManagerInvariants(inner)
                self.assertInWithGeneratorInvariants(outer_resource)
                self.assertInWithGeneratorInvariants(inner_resource)
            # Leaving the inner block finalizes only the inner manager.
            self.assertAfterWithManagerInvariantsNoError(inner)
            self.assertAfterWithGeneratorInvariantsNoError(inner_resource)
            self.assertInWithManagerInvariants(outer)
            self.assertInWithGeneratorInvariants(outer_resource)
        self.assertAfterWithManagerInvariantsNoError(outer)
        self.assertAfterWithGeneratorInvariantsNoError(outer_resource)
292
293
class NestedNonexceptionalTestCase(unittest.TestCase,
                                   ContextmanagerAssertionMixin):
    """nested() over one or more managers whose with-bodies finish normally."""

    def testSingleArgInlineGeneratorSyntax(self):
        with nested(mock_contextmanager_generator()):
            pass

    def testSingleArgUnbound(self):
        # Both the wrapped manager and the MockNested wrapper go through the
        # full hook sequence.
        inner = mock_contextmanager_generator()
        mock_nested = MockNested(inner)
        with mock_nested:
            self.assertInWithManagerInvariants(inner)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(inner)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testSingleArgBoundToNonTuple(self):
        m = mock_contextmanager_generator()
        # This will bind all the arguments to nested() into a single list
        # assigned to foo.
        with nested(m) as foo:
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToSingleElementParenthesizedList(self):
        m = mock_contextmanager_generator()
        # This will bind all the arguments to nested() into a single list
        # assigned to foo.
        # FIXME: what should this do: with nested(m) as (foo,):
        with nested(m) as (foo):
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToMultipleElementTupleError(self):
        # Unpacking the result of a one-argument nested() into two names
        # must raise ValueError.
        def shouldThrowValueError():
            with nested(mock_contextmanager_generator()) as (foo, bar):
                pass
        self.assertRaises(ValueError, shouldThrowValueError)

    def testMultipleArgUnbound(self):
        m = mock_contextmanager_generator()
        n = mock_contextmanager_generator()
        o = mock_contextmanager_generator()
        mock_nested = MockNested(m, n, o)
        with mock_nested:
            self.assertInWithManagerInvariants(m)
            self.assertInWithManagerInvariants(n)
            self.assertInWithManagerInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(m)
        self.assertAfterWithManagerInvariantsNoError(n)
        self.assertAfterWithManagerInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testMultipleArgBound(self):
        mock_nested = MockNested(mock_contextmanager_generator(),
                                 mock_contextmanager_generator(),
                                 mock_contextmanager_generator())
        with mock_nested as (m, n, o):
            self.assertInWithGeneratorInvariants(m)
            self.assertInWithGeneratorInvariants(n)
            self.assertInWithGeneratorInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithGeneratorInvariantsNoError(m)
        self.assertAfterWithGeneratorInvariantsNoError(n)
        self.assertAfterWithGeneratorInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)
368
369
class ExceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """with bodies that raise TEST_EXCEPTION.

    Every manager still active when the exception is raised must see it in
    __exit__, and its generator must be finalized.
    """

    def testSingleResource(self):
        manager = mock_contextmanager_generator()
        def shouldThrow():
            with manager as self.resource:
                self.assertInWithManagerInvariants(manager)
                self.assertInWithGeneratorInvariants(self.resource)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(manager)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource)

    def testNestedSingleStatements(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()
        def shouldThrow():
            with outer as self.foo:
                with inner as self.bar:
                    self.assertInWithManagerInvariants(outer)
                    self.assertInWithManagerInvariants(inner)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                    self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithManagerInvariantsWithError(inner)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsWithError(self.bar)

    def testMultipleResourcesInSingleStatement(self):
        first = mock_contextmanager_generator()
        second = mock_contextmanager_generator()
        mock_nested = MockNested(first, second)
        def shouldThrow():
            with mock_nested as (self.resource_a, self.resource_b):
                self.assertInWithManagerInvariants(first)
                self.assertInWithManagerInvariants(second)
                self.assertInWithManagerInvariants(mock_nested)
                self.assertInWithGeneratorInvariants(self.resource_a)
                self.assertInWithGeneratorInvariants(self.resource_b)
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(first)
        self.assertAfterWithManagerInvariantsWithError(second)
        self.assertAfterWithManagerInvariantsWithError(mock_nested)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_b)

    def testNestedExceptionBeforeInnerStatement(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()
        self.bar = None
        def shouldThrow():
            with outer as self.foo:
                self.assertInWithManagerInvariants(outer)
                self.assertInWithGeneratorInvariants(self.foo)
                self.raiseTestException()
                with inner as self.bar:
                    pass
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)

        # The inner statement stuff should never have been touched
        self.assertEqual(self.bar, None)
        self.assertFalse(inner.context_called)
        self.assertFalse(inner.enter_called)
        self.assertFalse(inner.exit_called)
        self.assertEqual(inner.exit_args, None)

    def testNestedExceptionAfterInnerStatement(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()
        def shouldThrow():
            with outer as self.foo:
                with inner as self.bar:
                    self.assertInWithManagerInvariants(outer)
                    self.assertInWithManagerInvariants(inner)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                # Raised after the inner block has already exited cleanly.
                self.raiseTestException()
        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithManagerInvariantsNoError(inner)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.bar)
456
457
class NonLocalFlowControlTestCase(unittest.TestCase):
    """break, continue, return, raise and yield inside a with block."""

    def testWithBreak(self):
        count = 0
        while True:
            count += 1
            with mock_contextmanager_generator():
                count += 10
                break
            count += 100 # Not reached
        self.assertEqual(count, 11)

    def testWithContinue(self):
        count = 0
        while True:
            count += 1
            if count > 2:
                break
            with mock_contextmanager_generator():
                count += 10
                continue
            count += 100 # Not reached
        self.assertEqual(count, 12)

    def testWithReturn(self):
        def run():
            count = 0
            while True:
                count += 1
                with mock_contextmanager_generator():
                    count += 10
                    return count
                count += 100 # Not reached
        self.assertEqual(run(), 11)

    def testWithYield(self):
        def produce():
            with mock_contextmanager_generator():
                yield 12
                yield 13
        self.assertEqual(list(produce()), [12, 13])

    def testWithRaise(self):
        count = 0
        try:
            count += 1
            with mock_contextmanager_generator():
                count += 10
                raise RuntimeError
            count += 100 # Not reached
        except RuntimeError:
            self.assertEqual(count, 11)
        else:
            self.fail("Didn't raise RuntimeError")
513
514
515class AssignmentTargetTestCase(unittest.TestCase):
516
517 def testSingleComplexTarget(self):
518 targets = {1: [0, 1, 2]}
519 with mock_contextmanager_generator() as targets[1][0]:
520 self.assertEqual(targets.keys(), [1])
521 self.assertEqual(targets[1][0].__class__, MockResource)
522 with mock_contextmanager_generator() as targets.values()[0][1]:
523 self.assertEqual(targets.keys(), [1])
524 self.assertEqual(targets[1][1].__class__, MockResource)
525 with mock_contextmanager_generator() as targets[2]:
526 keys = targets.keys()
527 keys.sort()
528 self.assertEqual(keys, [1, 2])
529 class C: pass
530 blah = C()
531 with mock_contextmanager_generator() as blah.foo:
532 self.assertEqual(hasattr(blah, "foo"), True)
533
534 def testMultipleComplexTargets(self):
535 class C:
536 def __context__(self): return self
537 def __enter__(self): return 1, 2, 3
538 def __exit__(self, *a): pass
539 targets = {1: [0, 1, 2]}
540 with C() as (targets[1][0], targets[1][1], targets[1][2]):
541 self.assertEqual(targets, {1: [1, 2, 3]})
542 with C() as (targets.values()[0][2], targets.values()[0][1], targets.values()[0][0]):
543 self.assertEqual(targets, {1: [3, 2, 1]})
544 with C() as (targets[1], targets[2], targets[3]):
545 self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
546 class B: pass
547 blah = B()
548 with C() as (blah.one, blah.two, blah.three):
549 self.assertEqual(blah.one, 1)
550 self.assertEqual(blah.two, 2)
551 self.assertEqual(blah.three, 3)
552
553
def test_main():
    """Run this module's test case classes through run_unittest."""
    cases = (FailureTestCase,
             NonexceptionalTestCase,
             NestedNonexceptionalTestCase,
             ExceptionalTestCase,
             NonLocalFlowControlTestCase,
             AssignmentTargetTestCase)
    run_unittest(*cases)
559
560
if __name__ == '__main__':
    # Allow running this file directly as a script.
    test_main()