Allow multiple context managers in one with statement, as proposed
in http://codereview.appspot.com/53094 and accepted by Guido.
The construct is transformed into nested With AST nodes (the first
item outermost), so it has exactly the semantics of the equivalent
nested with statements.
diff --git a/Lib/compiler/transformer.py b/Lib/compiler/transformer.py
index f5fe582..2a156d3 100644
--- a/Lib/compiler/transformer.py
+++ b/Lib/compiler/transformer.py
@@ -965,18 +965,23 @@
         return try_except
 
     def com_with(self, nodelist):
-        # with_stmt: 'with' expr [with_var] ':' suite
-        expr = self.com_node(nodelist[1])
+        # with_stmt: 'with' with_item (',' with_item)* ':' suite
         body = self.com_node(nodelist[-1])
-        if nodelist[2][0] == token.COLON:
-            var = None
-        else:
-            var = self.com_assign(nodelist[2][2], OP_ASSIGN)
-        return With(expr, var, body, lineno=nodelist[0][2])
+        # Wrap body once per with_item, last item first, so that the first
+        # item becomes the outermost With node (same as nested withs).
+        for i in range(len(nodelist) - 3, 0, -2):
+            body = self.com_with_item(nodelist[i], body, nodelist[0][2])
+        return body
 
-    def com_with_var(self, nodelist):
-        # with_var: 'as' expr
-        return self.com_node(nodelist[1])
+    def com_with_item(self, nodelist, body, lineno):
+        # with_item: test ['as' expr]
+        # Return a single With node wrapping body; var is None without 'as'.
+        if len(nodelist) == 4:
+            var = self.com_assign(nodelist[3], OP_ASSIGN)
+        else:
+            var = None
+        expr = self.com_node(nodelist[1])
+        return With(expr, var, body, lineno=lineno)
 
     def com_augassign_op(self, node):
         assert node[0] == symbol.augassign
diff --git a/Lib/test/test_compiler.py b/Lib/test/test_compiler.py
index 052e07e..f1fef74 100644
--- a/Lib/test/test_compiler.py
+++ b/Lib/test/test_compiler.py
@@ -165,6 +165,28 @@
         exec c in dct
         self.assertEquals(dct.get('result'), 1)
 
+    def testWithMult(self):
+        events = []
+        class Ctx:
+            def __init__(self, n):
+                self.n = n
+            def __enter__(self):
+                events.append(self.n)
+            def __exit__(self, *args):
+                events.append(-self.n)
+        c = compiler.compile('from __future__ import with_statement\n'
+                             'def f():\n'
+                             '    with Ctx(1) as tc, Ctx(2) as tc2:\n'
+                             '        return 1\n'
+                             'result = f()',
+                             '<string>',
+                             'exec')
+        dct = {'Ctx': Ctx}
+        exec c in dct
+        self.assertEquals(dct.get('result'), 1)
+        # Items are entered left to right and exited in reverse order.
+        self.assertEquals(events, [1, 2, -2, -1])
+
     def testGlobal(self):
         code = compiler.compile('global x\nx=1', '<string>', 'exec')
         d1 = {'__builtins__': {}}
diff --git a/Lib/test/test_parser.py b/Lib/test/test_parser.py
index 23f418e..7d059c2 100644
--- a/Lib/test/test_parser.py
+++ b/Lib/test/test_parser.py
@@ -199,6 +199,8 @@
     def test_with(self):
         self.check_suite("with open('x'): pass\n")
         self.check_suite("with open('x') as f: pass\n")
+        self.check_suite("with open('x') as f, open('y') as g: pass\n")
+        self.check_suite("with open('x'), open('y') as g: pass\n")
 
     def test_try_stmt(self):
         self.check_suite("try: pass\nexcept: pass\n")
diff --git a/Lib/test/test_with.py b/Lib/test/test_with.py
index bfeb06b..68ae890 100644
--- a/Lib/test/test_with.py
+++ b/Lib/test/test_with.py
@@ -654,12 +654,92 @@
             self.fail("ZeroDivisionError should have been raised")
 
 
+class NestedWith(unittest.TestCase):
+    """Multiple context managers in a single with statement."""
+
+    class Dummy(object):
+        def __init__(self, value=None, gobble=False):
+            if value is None:
+                value = self
+            self.value = value
+            self.gobble = gobble
+            self.enter_called = False
+            self.exit_called = False
+
+        def __enter__(self):
+            self.enter_called = True
+            return self.value
+
+        def __exit__(self, *exc_info):
+            self.exit_called = True
+            self.exc_info = exc_info
+            if self.gobble:
+                return True
+
+    class CtorRaises(object):
+        def __init__(self): raise RuntimeError()
+
+    class EnterRaises(object):
+        def __enter__(self): raise RuntimeError()
+        def __exit__(self, *exc_info): pass
+
+    class ExitRaises(object):
+        def __enter__(self): pass
+        def __exit__(self, *exc_info): raise RuntimeError()
+
+    def testNoExceptions(self):
+        with self.Dummy() as a, self.Dummy() as b:
+            self.assertTrue(a.enter_called)
+            self.assertTrue(b.enter_called)
+        self.assertTrue(a.exit_called)
+        self.assertTrue(b.exit_called)
+
+    def testExceptionInExprList(self):
+        try:
+            with self.Dummy() as a, self.CtorRaises():
+                self.fail('body of bad with executed')
+        except RuntimeError:
+            pass
+        else:
+            self.fail('RuntimeError not reraised')
+        self.assertTrue(a.enter_called)
+        self.assertTrue(a.exit_called)
+
+    def testExceptionInEnter(self):
+        try:
+            with self.Dummy() as a, self.EnterRaises():
+                self.fail('body of bad with executed')
+        except RuntimeError:
+            pass
+        else:
+            self.fail('RuntimeError not reraised')
+        self.assertTrue(a.enter_called)
+        self.assertTrue(a.exit_called)
+
+    def testExceptionInExit(self):
+        body_executed = False
+        with self.Dummy(gobble=True) as a, self.ExitRaises():
+            body_executed = True
+        self.assertTrue(a.enter_called)
+        self.assertTrue(a.exit_called)
+        self.assertTrue(body_executed)
+        self.assertNotEqual(a.exc_info[0], None)
+
+    def testEnterReturnsTuple(self):
+        with self.Dummy(value=(1,2)) as (a1, a2), \
+             self.Dummy(value=(10, 20)) as (b1, b2):
+            self.assertEquals(1, a1)
+            self.assertEquals(2, a2)
+            self.assertEquals(10, b1)
+            self.assertEquals(20, b2)
+
 def test_main():
     run_unittest(FailureTestCase, NonexceptionalTestCase,
                  NestedNonexceptionalTestCase, ExceptionalTestCase,
                  NonLocalFlowControlTestCase,
                  AssignmentTargetTestCase,
-                 ExitSwallowsExceptionTestCase)
+                 ExitSwallowsExceptionTestCase,
+                 NestedWith)
 
 
 if __name__ == '__main__':