blob: 8423ee1d290eff858559e51b3599a4dccfc7b81c [file] [log] [blame]
Tim Peters400cbc32006-02-28 18:44:41 +00001#!/usr/bin/env python
2
3"""Unit tests for the with statement specified in PEP 343."""
4
5__author__ = "Mike Bland"
6__email__ = "mbland at acm dot org"
7
8import unittest
9from test.contextmanager import GeneratorContextManager
10from test.nested import nested
11from test.test_support import run_unittest
12
13
class MockContextManager(GeneratorContextManager):
    """GeneratorContextManager that records which protocol hooks ran.

    Every hook sets its own flag and then defers to the corresponding
    GeneratorContextManager method, so the wrapped behavior is unchanged.
    """

    def __init__(self, gen):
        GeneratorContextManager.__init__(self, gen)
        # A freshly built manager has had none of its hooks invoked yet.
        self.exit_args = None
        self.exit_called = False
        self.enter_called = False
        self.context_called = False

    def __context__(self):
        self.context_called = True
        base_context = GeneratorContextManager.__context__(self)
        return base_context

    def __enter__(self):
        self.enter_called = True
        entered = GeneratorContextManager.__enter__(self)
        return entered

    def __exit__(self, type, value, traceback):
        # Remember exactly what the with statement handed us before
        # delegating to the base class.
        self.exit_args = (type, value, traceback)
        self.exit_called = True
        return GeneratorContextManager.__exit__(self, type, value, traceback)
34
35
def mock_contextmanager(func):
    """Decorator: wrap whatever *func* returns in a MockContextManager.

    The returned callable forwards all positional and keyword arguments
    to *func* and passes the result to MockContextManager.
    """
    def helper(*args, **kwds):
        generator = func(*args, **kwds)
        return MockContextManager(generator)
    return helper
40
41
class MockResource(object):
    """Record with two boolean flags, ``yielded`` and ``stopped``.

    Both flags start out False.
    """

    def __init__(self):
        self.stopped = False
        self.yielded = False
46
47
@mock_contextmanager
def mock_contextmanager_generator():
    """Yield a MockResource, marking it yielded first and stopped at the end."""
    resource = MockResource()
    try:
        resource.yielded = True
        yield resource
    finally:
        # Runs whether the generator is resumed normally or with an error.
        resource.stopped = True
56
57
class MockNested(nested):
    """``nested`` subclass that records which protocol hooks ran.

    Every hook sets its own flag and then defers to the corresponding
    ``nested`` method.
    """

    def __init__(self, *contexts):
        nested.__init__(self, *contexts)
        # No hook has been invoked on a freshly built instance.
        self.exit_args = None
        self.exit_called = False
        self.enter_called = False
        self.context_called = False

    def __context__(self):
        self.context_called = True
        base_context = nested.__context__(self)
        return base_context

    def __enter__(self):
        self.enter_called = True
        entered = nested.__enter__(self)
        return entered

    def __exit__(self, *exc_info):
        # Keep the raw argument tuple so tests can compare it directly.
        self.exit_args = exc_info
        self.exit_called = True
        return nested.__exit__(self, *exc_info)
78
79
80class FailureTestCase(unittest.TestCase):
81 def testNameError(self):
82 def fooNotDeclared():
83 with foo: pass
84 self.assertRaises(NameError, fooNotDeclared)
85
86 def testContextAttributeError(self):
87 class LacksContext(object):
88 def __enter__(self):
89 pass
90
91 def __exit__(self, type, value, traceback):
92 pass
93
94 def fooLacksContext():
95 foo = LacksContext()
96 with foo: pass
97 self.assertRaises(AttributeError, fooLacksContext)
98
99 def testEnterAttributeError(self):
100 class LacksEnter(object):
101 def __context__(self):
102 pass
103
104 def __exit__(self, type, value, traceback):
105 pass
106
107 def fooLacksEnter():
108 foo = LacksEnter()
109 with foo: pass
110 self.assertRaises(AttributeError, fooLacksEnter)
111
112 def testExitAttributeError(self):
113 class LacksExit(object):
114 def __context__(self):
115 pass
116
117 def __enter__(self):
118 pass
119
120 def fooLacksExit():
121 foo = LacksExit()
122 with foo: pass
123 self.assertRaises(AttributeError, fooLacksExit)
124
125 def assertRaisesSyntaxError(self, codestr):
126 def shouldRaiseSyntaxError(s):
127 compile(s, '', 'single')
128 self.assertRaises(SyntaxError, shouldRaiseSyntaxError, codestr)
129
130 def testAssignmentToNoneError(self):
131 self.assertRaisesSyntaxError('with mock as None:\n pass')
132 self.assertRaisesSyntaxError(
133 'with mock as (None):\n'
134 ' pass')
135
136 def testAssignmentToEmptyTupleError(self):
137 self.assertRaisesSyntaxError(
138 'with mock as ():\n'
139 ' pass')
140
141 def testAssignmentToTupleOnlyContainingNoneError(self):
142 self.assertRaisesSyntaxError('with mock as None,:\n pass')
143 self.assertRaisesSyntaxError(
144 'with mock as (None,):\n'
145 ' pass')
146
147 def testAssignmentToTupleContainingNoneError(self):
148 self.assertRaisesSyntaxError(
149 'with mock as (foo, None, bar):\n'
150 ' pass')
151
152 def testContextThrows(self):
153 class ContextThrows(object):
154 def __context__(self):
155 raise RuntimeError("Context threw")
156
157 def shouldThrow():
158 ct = ContextThrows()
159 self.foo = None
160 with ct as self.foo:
161 pass
162 self.assertRaises(RuntimeError, shouldThrow)
163 self.assertEqual(self.foo, None)
164
165 def testEnterThrows(self):
166 class EnterThrows(object):
167 def __context__(self):
168 return self
169
170 def __enter__(self):
171 raise RuntimeError("Context threw")
172
173 def __exit__(self, *args):
174 pass
175
176 def shouldThrow():
177 ct = EnterThrows()
178 self.foo = None
179 with ct as self.foo:
180 pass
181 self.assertRaises(RuntimeError, shouldThrow)
182 self.assertEqual(self.foo, None)
183
184 def testExitThrows(self):
185 class ExitThrows(object):
186 def __context__(self):
187 return self
188 def __enter__(self):
189 return
190 def __exit__(self, *args):
191 raise RuntimeError(42)
192 def shouldThrow():
193 with ExitThrows():
194 pass
195 self.assertRaises(RuntimeError, shouldThrow)
196
class ContextmanagerAssertionMixin(object):
    """Shared assertions about mock managers and the resources they yield.

    Meant to be mixed into a unittest.TestCase, whose assert* methods
    these helpers call.
    """

    # Single shared instance, so tests can compare the raised object itself.
    TEST_EXCEPTION = RuntimeError("test exception")

    def assertInWithManagerInvariants(self, mock_manager):
        # Inside the block: entered, but not yet exited.
        self.assertTrue(mock_manager.context_called)
        self.assertTrue(mock_manager.enter_called)
        self.assertFalse(mock_manager.exit_called)
        self.assertEqual(mock_manager.exit_args, None)

    def assertAfterWithManagerInvariants(self, mock_manager, exit_args):
        # After the block: all hooks ran and __exit__ saw *exit_args*.
        self.assertTrue(mock_manager.context_called)
        self.assertTrue(mock_manager.enter_called)
        self.assertTrue(mock_manager.exit_called)
        self.assertEqual(mock_manager.exit_args, exit_args)

    def assertAfterWithManagerInvariantsNoError(self, mock_manager):
        no_error = (None, None, None)
        self.assertAfterWithManagerInvariants(mock_manager, no_error)

    def assertInWithGeneratorInvariants(self, mock_generator):
        self.assertTrue(mock_generator.yielded)
        self.assertFalse(mock_generator.stopped)

    def assertAfterWithGeneratorInvariantsNoError(self, mock_generator):
        self.assertTrue(mock_generator.yielded)
        self.assertTrue(mock_generator.stopped)

    def raiseTestException(self):
        raise self.TEST_EXCEPTION

    def assertAfterWithManagerInvariantsWithError(self, mock_manager):
        self.assertTrue(mock_manager.context_called)
        self.assertTrue(mock_manager.enter_called)
        self.assertTrue(mock_manager.exit_called)
        # Only type and value are checked; the traceback slot is ignored.
        exc_type = mock_manager.exit_args[0]
        self.assertEqual(exc_type, RuntimeError)
        exc_value = mock_manager.exit_args[1]
        self.assertEqual(exc_value, self.TEST_EXCEPTION)

    def assertAfterWithGeneratorInvariantsWithError(self, mock_generator):
        self.assertTrue(mock_generator.yielded)
        self.assertTrue(mock_generator.stopped)
237
238
class NonexceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """with statements whose bodies finish without raising."""

    def testInlineGeneratorSyntax(self):
        # Only checks that the statement runs; nothing is bound or asserted.
        with mock_contextmanager_generator():
            pass

    def testUnboundGenerator(self):
        manager = mock_contextmanager_generator()
        with manager:
            pass
        self.assertAfterWithManagerInvariantsNoError(manager)

    def testInlineGeneratorBoundSyntax(self):
        with mock_contextmanager_generator() as foo:
            self.assertInWithGeneratorInvariants(foo)
        # FIXME: In the future, we'll try to keep the bound names from leaking
        self.assertAfterWithGeneratorInvariantsNoError(foo)

    def testInlineGeneratorBoundToExistingVariable(self):
        # The name already exists and is rebound by the with statement.
        foo = None
        with mock_contextmanager_generator() as foo:
            self.assertInWithGeneratorInvariants(foo)
        self.assertAfterWithGeneratorInvariantsNoError(foo)

    def testInlineGeneratorBoundToDottedVariable(self):
        # The target is an attribute reference, not a plain name.
        with mock_contextmanager_generator() as self.foo:
            self.assertInWithGeneratorInvariants(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.foo)

    def testBoundGenerator(self):
        manager = mock_contextmanager_generator()
        with manager as resource:
            self.assertInWithGeneratorInvariants(resource)
            self.assertInWithManagerInvariants(manager)
        self.assertAfterWithGeneratorInvariantsNoError(resource)
        self.assertAfterWithManagerInvariantsNoError(manager)

    def testNestedSingleStatements(self):
        outer = mock_contextmanager_generator()
        with outer as outer_resource:
            inner = mock_contextmanager_generator()
            with inner as inner_resource:
                self.assertInWithManagerInvariants(outer)
                self.assertInWithManagerInvariants(inner)
                self.assertInWithGeneratorInvariants(outer_resource)
                self.assertInWithGeneratorInvariants(inner_resource)
            # Inner statement is finished; the outer one is still active.
            self.assertAfterWithManagerInvariantsNoError(inner)
            self.assertAfterWithGeneratorInvariantsNoError(inner_resource)
            self.assertInWithManagerInvariants(outer)
            self.assertInWithGeneratorInvariants(outer_resource)
        self.assertAfterWithManagerInvariantsNoError(outer)
        self.assertAfterWithGeneratorInvariantsNoError(outer_resource)
290
291
class NestedNonexceptionalTestCase(unittest.TestCase,
                                   ContextmanagerAssertionMixin):
    """Non-raising with statements that go through nested() / MockNested."""

    def testSingleArgInlineGeneratorSyntax(self):
        # Only checks that the statement runs; nothing is bound or asserted.
        with nested(mock_contextmanager_generator()):
            pass

    # NOTE: this method used to be defined twice with identical bodies; the
    # second definition silently replaced the first, so only one is kept.
    def testSingleArgUnbound(self):
        mock_contextmanager = mock_contextmanager_generator()
        mock_nested = MockNested(mock_contextmanager)
        with mock_nested:
            self.assertInWithManagerInvariants(mock_contextmanager)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(mock_contextmanager)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testSingleArgBoundToNonTuple(self):
        m = mock_contextmanager_generator()
        # This will bind all the arguments to nested() into a single list
        # assigned to foo.
        with nested(m) as foo:
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToSingleElementParenthesizedList(self):
        m = mock_contextmanager_generator()
        # This will bind all the arguments to nested() into a single list
        # assigned to foo.
        # FIXME: what should this do: with nested(m) as (foo,):
        with nested(m) as (foo):
            self.assertInWithManagerInvariants(m)
        self.assertAfterWithManagerInvariantsNoError(m)

    def testSingleArgBoundToMultipleElementTupleError(self):
        # One manager cannot be unpacked into two targets.
        def shouldThrowValueError():
            with nested(mock_contextmanager_generator()) as (foo, bar):
                pass
        self.assertRaises(ValueError, shouldThrowValueError)

    def testMultipleArgUnbound(self):
        m = mock_contextmanager_generator()
        n = mock_contextmanager_generator()
        o = mock_contextmanager_generator()
        mock_nested = MockNested(m, n, o)
        with mock_nested:
            self.assertInWithManagerInvariants(m)
            self.assertInWithManagerInvariants(n)
            self.assertInWithManagerInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithManagerInvariantsNoError(m)
        self.assertAfterWithManagerInvariantsNoError(n)
        self.assertAfterWithManagerInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)

    def testMultipleArgBound(self):
        mock_nested = MockNested(mock_contextmanager_generator(),
            mock_contextmanager_generator(), mock_contextmanager_generator())
        # m, n and o are the values bound by the with statement (checked
        # below with the generator-resource assertions), not the managers.
        with mock_nested as (m, n, o):
            self.assertInWithGeneratorInvariants(m)
            self.assertInWithGeneratorInvariants(n)
            self.assertInWithGeneratorInvariants(o)
            self.assertInWithManagerInvariants(mock_nested)
        self.assertAfterWithGeneratorInvariantsNoError(m)
        self.assertAfterWithGeneratorInvariantsNoError(n)
        self.assertAfterWithGeneratorInvariantsNoError(o)
        self.assertAfterWithManagerInvariantsNoError(mock_nested)
366
367
class ExceptionalTestCase(unittest.TestCase, ContextmanagerAssertionMixin):
    """with statements whose bodies raise the mixin's TEST_EXCEPTION."""

    def testSingleResource(self):
        manager = mock_contextmanager_generator()

        def shouldThrow():
            with manager as self.resource:
                self.assertInWithManagerInvariants(manager)
                self.assertInWithGeneratorInvariants(self.resource)
                self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(manager)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource)

    def testNestedSingleStatements(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()

        def shouldThrow():
            with outer as self.foo:
                with inner as self.bar:
                    self.assertInWithManagerInvariants(outer)
                    self.assertInWithManagerInvariants(inner)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                    self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        # Raised inside the innermost block, so both managers see the error.
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithManagerInvariantsWithError(inner)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsWithError(self.bar)

    def testMultipleResourcesInSingleStatement(self):
        first = mock_contextmanager_generator()
        second = mock_contextmanager_generator()
        mock_nested = MockNested(first, second)

        def shouldThrow():
            with mock_nested as (self.resource_a, self.resource_b):
                self.assertInWithManagerInvariants(first)
                self.assertInWithManagerInvariants(second)
                self.assertInWithManagerInvariants(mock_nested)
                self.assertInWithGeneratorInvariants(self.resource_a)
                self.assertInWithGeneratorInvariants(self.resource_b)
                self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(first)
        self.assertAfterWithManagerInvariantsWithError(second)
        self.assertAfterWithManagerInvariantsWithError(mock_nested)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_a)
        self.assertAfterWithGeneratorInvariantsWithError(self.resource_b)

    def testNestedExceptionBeforeInnerStatement(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()
        self.bar = None

        def shouldThrow():
            with outer as self.foo:
                self.assertInWithManagerInvariants(outer)
                self.assertInWithGeneratorInvariants(self.foo)
                self.raiseTestException()
                with inner as self.bar:
                    pass

        self.assertRaises(RuntimeError, shouldThrow)
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)

        # The inner statement stuff should never have been touched
        self.assertEqual(self.bar, None)
        self.assertFalse(inner.context_called)
        self.assertFalse(inner.enter_called)
        self.assertFalse(inner.exit_called)
        self.assertEqual(inner.exit_args, None)

    def testNestedExceptionAfterInnerStatement(self):
        outer = mock_contextmanager_generator()
        inner = mock_contextmanager_generator()

        def shouldThrow():
            with outer as self.foo:
                with inner as self.bar:
                    self.assertInWithManagerInvariants(outer)
                    self.assertInWithManagerInvariants(inner)
                    self.assertInWithGeneratorInvariants(self.foo)
                    self.assertInWithGeneratorInvariants(self.bar)
                self.raiseTestException()

        self.assertRaises(RuntimeError, shouldThrow)
        # The inner block had already exited cleanly when the error was
        # raised, so only the outer manager sees it.
        self.assertAfterWithManagerInvariantsWithError(outer)
        self.assertAfterWithManagerInvariantsNoError(inner)
        self.assertAfterWithGeneratorInvariantsWithError(self.foo)
        self.assertAfterWithGeneratorInvariantsNoError(self.bar)
453 self.assertAfterWithGeneratorInvariantsNoError(self.bar)
454
455
class NonLocalFlowControlTestCase(unittest.TestCase):
    """break / continue / return / yield / raise from inside a with block."""

    def testWithBreak(self):
        total = 0
        while True:
            total += 1
            with mock_contextmanager_generator():
                total += 10
                break
            total += 100 # Not reached
        self.assertEqual(total, 11)

    def testWithContinue(self):
        total = 0
        while True:
            total += 1
            if total > 2:
                break
            with mock_contextmanager_generator():
                total += 10
                continue
            total += 100 # Not reached
        # One pass adds 1 + 10; the next adds 1 and then leaves the loop.
        self.assertEqual(total, 12)

    def testWithReturn(self):
        def foo():
            total = 0
            while True:
                total += 1
                with mock_contextmanager_generator():
                    total += 10
                    return total
                total += 100 # Not reached
        self.assertEqual(foo(), 11)

    def testWithYield(self):
        def gen():
            with mock_contextmanager_generator():
                yield 12
                yield 13
        produced = list(gen())
        self.assertEqual(produced, [12, 13])

    def testWithRaise(self):
        total = 0
        try:
            total += 1
            with mock_contextmanager_generator():
                total += 10
                raise RuntimeError
            total += 100 # Not reached
        except RuntimeError:
            self.assertEqual(total, 11)
        else:
            self.fail("Didn't raise RuntimeError")
511
512
513class AssignmentTargetTestCase(unittest.TestCase):
514
515 def testSingleComplexTarget(self):
516 targets = {1: [0, 1, 2]}
517 with mock_contextmanager_generator() as targets[1][0]:
518 self.assertEqual(targets.keys(), [1])
519 self.assertEqual(targets[1][0].__class__, MockResource)
520 with mock_contextmanager_generator() as targets.values()[0][1]:
521 self.assertEqual(targets.keys(), [1])
522 self.assertEqual(targets[1][1].__class__, MockResource)
523 with mock_contextmanager_generator() as targets[2]:
524 keys = targets.keys()
525 keys.sort()
526 self.assertEqual(keys, [1, 2])
527 class C: pass
528 blah = C()
529 with mock_contextmanager_generator() as blah.foo:
530 self.assertEqual(hasattr(blah, "foo"), True)
531
532 def testMultipleComplexTargets(self):
533 class C:
534 def __context__(self): return self
535 def __enter__(self): return 1, 2, 3
536 def __exit__(self, *a): pass
537 targets = {1: [0, 1, 2]}
538 with C() as (targets[1][0], targets[1][1], targets[1][2]):
539 self.assertEqual(targets, {1: [1, 2, 3]})
540 with C() as (targets.values()[0][2], targets.values()[0][1], targets.values()[0][0]):
541 self.assertEqual(targets, {1: [3, 2, 1]})
542 with C() as (targets[1], targets[2], targets[3]):
543 self.assertEqual(targets, {1: 1, 2: 2, 3: 3})
544 class B: pass
545 blah = B()
546 with C() as (blah.one, blah.two, blah.three):
547 self.assertEqual(blah.one, 1)
548 self.assertEqual(blah.two, 2)
549 self.assertEqual(blah.three, 3)
550
551
def test_main():
    """Hand every test case class in this module to run_unittest."""
    test_classes = (
        FailureTestCase,
        NonexceptionalTestCase,
        NestedNonexceptionalTestCase,
        ExceptionalTestCase,
        NonLocalFlowControlTestCase,
        AssignmentTargetTestCase,
    )
    run_unittest(*test_classes)
557
558
# Run the whole suite when this file is executed as a script.
if __name__ == '__main__':
    test_main()