More compiler work: extends, macros, top-level assignments, tuples, subscripts/slices, filters, tests and calls

--HG--
branch : trunk
diff --git a/jinja2/compiler.py b/jinja2/compiler.py
index 330db56..41577c8 100644
--- a/jinja2/compiler.py
+++ b/jinja2/compiler.py
@@ -8,6 +8,7 @@
     :copyright: Copyright 2008 by Armin Ronacher.
     :license: GNU GPL.
 """
+from copy import copy  # used by Frame.copy()
 from random import randrange
 from operator import xor
 from cStringIO import StringIO
@@ -74,7 +75,9 @@
 
     def __init__(self, parent=None):
         self.identifiers = Identifiers()
+        self.toplevel = False  # visit_Template sets this for the root frame
         self.parent = parent
+        self.block = parent and parent.block or None  # enclosing block name
         if parent is not None:
             self.identifiers.declared.update(
                 parent.identifiers.declared |
@@ -83,6 +86,16 @@
                 parent.identifiers.declared_parameter
             )
 
+    def copy(self):
+        """Create a copy of the current frame.
+
+        Both the frame and its `identifiers` object are copied with
+        `copy.copy`, i.e. shallowly.
+        """
+        rv = copy(self)
+        rv.identifiers = copy(self.identifiers)
+        return rv
+
     def inspect(self, nodes):
         """Walk the node and check for identifiers."""
         visitor = FrameIdentifierVisitor(self.identifiers)
@@ -113,7 +122,7 @@
         self.identifiers.declared_locally.add(node.name)
 
     # stop traversing at instructions that have their own scope.
-    visit_Block = visit_Call = visit_FilterBlock = \
+    visit_Block = visit_CallBlock = visit_FilterBlock = \
         visit_For = lambda s, n: None
 
 
@@ -139,8 +148,8 @@
     def indent(self):
         self.indentation += 1
 
-    def outdent(self):
-        self.indentation -= 1
+    def outdent(self, step=1):
+        self.indentation -= step  # step > 1 closes several levels at once
 
     def blockvisit(self, nodes, frame, force_generator=False):
         self.indent()
@@ -170,6 +179,35 @@
             self.new_lines = 1
             self._last_line = node.lineno
 
+    def signature(self, node, frame, have_comma=True):
+        """Write the arguments of a call-like node: positional arguments,
+        keyword arguments, then ``*dyn_args`` and ``**dyn_kwargs``, all
+        separated by commas.  If `have_comma` is true the first argument is
+        preceded by a comma as well (used after an already written value).
+        """
+        # a list serves as a mutable flag the nested function can update.
+        have_comma = have_comma and [True] or []
+        def touch_comma():
+            if have_comma:
+                self.write(', ')
+            else:
+                have_comma.append(True)
+
+        for arg in node.args:
+            touch_comma()
+            self.visit(arg, frame)
+        for kwarg in node.kwargs:
+            touch_comma()
+            self.visit(kwarg, frame)
+        if node.dyn_args:
+            touch_comma()
+            self.write('*')
+            self.visit(node.dyn_args, frame)
+        if node.dyn_kwargs:
+            touch_comma()
+            self.write('**')
+            self.visit(node.dyn_kwargs, frame)
+
     def pull_locals(self, frame, no_indent=False):
         if not no_indent:
             self.indent()
@@ -184,27 +214,35 @@
         assert frame is None, 'no root frame allowed'
         self.writeline('from jinja2.runtime import *')
         self.writeline('filename = %r' % self.filename)
-        self.writeline('context = TemplateContext(global_context, '
+        self.writeline('template_context = TemplateContext(global_context, '
                        'make_undefined, filename)')
 
-        # generate the body render function.
-        self.writeline('def body(context=context):', extra=1)
+        # generate the root render function.
+        self.writeline('def root(context=template_context):', extra=1)
+        self.indent()
+        self.writeline('parent_root = None')  # rebound by visit_Extends
+        self.outdent()
         frame = Frame()
         frame.inspect(node.body)
+        frame.toplevel = True  # checked by extends, assignments and macros
         self.pull_locals(frame)
         self.blockvisit(node.body, frame, True)
 
-        # top level changes to locals are pushed back to the
-        # context of *this* template for include.
+        # if the template extended another one, render the parent's root last.
         self.indent()
-        self.writeline('context.from_locals(locals())')
-        self.outdent()
+        self.writeline('if parent_root is not None:')
+        self.indent()
+        self.writeline('for event in parent_root(context):')
+        self.indent()
+        self.writeline('yield event')
+        self.outdent(3)  # undo the three indent() calls above
 
         # at this point we now have the blocks collected and can visit them too.
         for name, block in self.blocks.iteritems():
             block_frame = Frame()
             block_frame.inspect(block.body)
-            self.writeline('def block_%s(context=context):' % name, block, 1)
+            block_frame.block = name  # inherited by frames made with it as parent
+            self.writeline('def block_%s(context):' % name, block, 1)
             self.pull_locals(block_frame)
             self.blockvisit(block.body, block_frame, True)
 
@@ -215,23 +253,32 @@
                                          node.name, node.lineno,
                                          self.filename)
         self.blocks[node.name] = node
-        self.writeline('for event in block_%s():' % node.name)
+        self.writeline('for event in block_%s(context):' % node.name)  # context passed explicitly
         self.indent()
         self.writeline('yield event')
         self.outdent()
 
     def visit_Extends(self, node, frame):
         """Calls the extender."""
-        self.writeline('extends(', node, 1)
-        self.visit(node.template)
+        if not frame.toplevel:  # only valid in the template's root frame
+            raise TemplateAssertionError('cannot use extend from a non '
+                                         'top-level scope', node.lineno,
+                                         self.filename)
+        self.writeline('if parent_root is not None:')  # a second extends
+        self.indent()
+        self.writeline('raise TemplateRuntimeError(%r)' %
+                       'extended multiple times')
+        self.outdent()
+        self.writeline('parent_root = extends(', node, 1)  # used at the end of root()
+        self.visit(node.template, frame)
         self.write(', globals())')
 
     def visit_For(self, node, frame):
         loop_frame = frame.inner()
         loop_frame.inspect(node.iter_child_nodes())
-        loop_frame.identifiers.add_special('loop')
         extended_loop = bool(node.else_) or \
                         'loop' in loop_frame.identifiers.undeclared
+        loop_frame.identifiers.add_special('loop')  # only after the check above
 
         # make sure we "backup" overridden, local identifiers
         # TODO: we should probably optimize this and check if the
@@ -257,9 +304,20 @@
             self.writeline('if l_loop is None:')
             self.blockvisit(node.else_, loop_frame)
 
-        # reset the aliases and clean them up
+        # restore the backed up (aliased) names, then delete the remaining
+        # loop-local names and the backup aliases.
+        delete = set('l_' + x for x in loop_frame.identifiers.declared_locally
+                     | loop_frame.identifiers.declared_parameter)
+        if extended_loop:
+            delete.add('l_loop')
         for name, alias in aliases.iteritems():
-            self.writeline('l_%s = %s; del %s' % (name, alias, alias))
+            self.writeline('l_%s = %s' % (name, alias))
+            delete.add(alias)
+            delete.discard('l_' + name)
+        # sorted for deterministic output; an empty "del" statement would
+        # be a syntax error in the generated code, so skip it entirely.
+        if delete:
+            self.writeline('del %s' % ', '.join(sorted(delete)))
 
     def visit_If(self, node, frame):
         self.writeline('if ', node)
@@ -270,6 +324,40 @@
             self.writeline('else:')
             self.blockvisit(node.else_, frame)
 
+    def visit_Macro(self, node, frame):
+        """Compile a macro into a generator function named `macro`, wrap it
+        in a `Macro` object and bind that to ``l_<name>``.  At the template
+        top level the macro is stored in the context as well.
+        """
+        macro_frame = frame.inner()
+        macro_frame.inspect(node.body)
+        args = ['l_' + x.name for x in node.args]
+        # if the body refers to `arguments` the macro takes it as extra param
+        if 'arguments' in macro_frame.identifiers.undeclared:
+            accesses_arguments = True
+            args.append('l_arguments')
+        else:
+            accesses_arguments = False
+        self.writeline('def macro(%s):' % ', '.join(args), node)
+        self.indent()
+        # make `macro` a generator function even if the body yields nothing
+        self.writeline('if 0: yield None')
+        self.outdent()
+        # compile the body in the macro's own frame, not the caller's, so
+        # assignments and extends inside it are not treated as top-level.
+        self.blockvisit(node.body, macro_frame)
+        self.newline()
+        if frame.toplevel:
+            self.write('context[%r] = ' % node.name)
+        arg_tuple = ', '.join(repr(x.name) for x in node.args)
+        # a one-element tuple literal needs a trailing comma
+        if len(node.args) == 1:
+            arg_tuple += ','
+        self.write('l_%s = Macro(macro, %r, (%s), %s)' % (
+            node.name, node.name,
+            arg_tuple, accesses_arguments
+        ))
+
     def visit_ExprStmt(self, node, frame):
         self.newline(node)
         self.visit(node, frame)
@@ -320,9 +399,27 @@
                 self.visit(argument, frame)
             self.write(idx == 0 and ',)' or ')')
 
+    def visit_Assign(self, node, frame):
+        self.newline(node)
+        # top-level assignments go into the local namespace *and* into the
+        # current template's context.  The target is visited with a copy of
+        # the frame carrying an `assigned_names` set that visit_Name fills
+        # with every name stored to.
+        if frame.toplevel:
+            assignment_frame = frame.copy()
+            assignment_frame.assigned_names = set()
+        else:
+            assignment_frame = frame
+        self.visit(node.target, assignment_frame)
+        self.write(' = ')
+        self.visit(node.node, frame)
+        if frame.toplevel:
+            for name in assignment_frame.assigned_names:
+                self.writeline('context[%r] = l_%s' % (name, name))
+
     def visit_Name(self, node, frame):
-        # at this point we should only have locals left as the
-        # blocks, macros and template body ensure that they are set.
+        if frame.toplevel and node.ctx == 'store':  # needs visit_Assign's frame copy
+            frame.assigned_names.add(node.name)
         self.write('l_' + node.name)
     def visit_Const(self, node, frame):
@@ -333,6 +430,15 @@
         else:
             self.write(repr(val))
 
+    def visit_Tuple(self, node, frame):
+        self.write('(')
+        idx = -1
+        for idx, item in enumerate(node.items):
+            if idx:
+                self.write(', ')
+            self.visit(item, frame)
+        self.write(idx == 0 and ',)' or ')')
+
     def binop(operator):
         def visitor(self, node, frame):
             self.write('(')
@@ -373,8 +479,73 @@
         self.visit(node.expr, frame)
 
     def visit_Subscript(self, node, frame):
-        self.write('subscript(')
+        # slices map directly onto Python's slice syntax
+        if isinstance(node.arg, nodes.Slice):
+            self.visit(node.node, frame)
+            self.write('[')
+            self.visit(node.arg, frame)
+            self.write(']')
+            return
+        try:
+            const = node.arg.as_const()
+            have_const = True
+        except nodes.Impossible:
+            have_const = False
+        # constant numeric subscripts become a plain item access
+        if have_const:
+            if isinstance(const, (int, long, float)):
+                self.visit(node.node, frame)
+                self.write('[%s]' % const)
+                return
+        # everything else is resolved at runtime by `subscribe`, which also
+        # receives `make_undefined`
+        self.write('subscribe(')
         self.visit(node.node, frame)
         self.write(', ')
-        self.visit(node.arg, frame)
+        if have_const:
+            self.write(repr(const))
+        else:
+            self.visit(node.arg, frame)
+        self.write(', make_undefined)')
+
+    def visit_Slice(self, node, frame):
+        """Write ``start:stop[:step]``, leaving out the unset parts."""
+        if node.start is not None:
+            self.visit(node.start, frame)
+        self.write(':')
+        if node.stop is not None:
+            self.visit(node.stop, frame)
+        if node.step is not None:
+            self.write(':')
+            self.visit(node.step, frame)
+
+    def visit_Filter(self, node, frame):
+        # open one call per filter, write the filtered expression, then close
+        # the calls again, each with its filter's own arguments.
+        # NOTE(review): the call opened first is closed last, so the first
+        # entry of `node.filters` ends up outermost (applied last) -- confirm
+        # this matches the order in which the parser stores the filters.
+        for filter in node.filters:
+            self.write('context.filters[%r](' % filter.name)
+        self.visit(node.node, frame)
+        for filter in reversed(node.filters):
+            self.signature(filter, frame)
+            self.write(')')
+
+    def visit_Test(self, node, frame):
+        # the test name has to be interpolated into the generated lookup
+        self.write('context.tests[%r](' % node.name)
+        self.visit(node.node, frame)
+        self.signature(node, frame)
         self.write(')')
+
+    def visit_Call(self, node, frame):
+        self.visit(node.node, frame)
+        self.write('(')
+        self.signature(node, frame, False)
+        self.write(')')
+
+    def visit_Keyword(self, node, frame):
+        self.visit(node.key, frame)
+        self.write('=')
+        self.visit(node.value, frame)