refactored compiler and improved identifier handling for for-loops
authorArmin Ronacher <armin.ronacher@active-4.com>
Fri, 23 May 2008 14:12:47 +0000 (16:12 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Fri, 23 May 2008 14:12:47 +0000 (16:12 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/nodes.py
tests/test_forloop.py

index 8fcb30f2d336af5f4f1c99aa4b700c4df4a164f0..9bf00445b94290a13097d5f58eb897ab22af9849 100644 (file)
@@ -533,6 +533,11 @@ class CodeGenerator(NodeVisitor):
             self.writeline('%s = l_%s' % (ident, name))
         return aliases
 
+    def restore_shadowed(self, aliases):
+        """Restore all aliases."""
+        for name, alias in aliases.iteritems():
+            self.writeline('l_%s = %s' % (name, alias))
+
     def function_scoping(self, node, frame, children=None,
                          find_special=True):
         """In Jinja a few statements require the help of anonymous
@@ -870,8 +875,7 @@ class CodeGenerator(NodeVisitor):
     def visit_For(self, node, frame):
         # when calculating the nodes for the inner frame we have to exclude
         # the iterator contents from it
-        children = list(node.iter_child_nodes(exclude=('iter',)))
-
+        children = node.iter_child_nodes(exclude=('iter',))
         if node.recursive:
             loop_frame = self.function_scoping(node, frame, children,
                                                find_special=False)
@@ -879,10 +883,20 @@ class CodeGenerator(NodeVisitor):
             loop_frame = frame.inner()
             loop_frame.inspect(children)
 
-        undeclared = find_undeclared(children, ('loop',))
-        extended_loop = node.recursive or node.else_ or 'loop' in undeclared
-        if extended_loop:
-            loop_frame.identifiers.add_special('loop')
+        # try to figure out if we have an extended loop.  An extended loop
+        # is necessary if the loop is in recursive mode if the special loop
+        # variable is accessed in the body.
+        extended_loop = node.recursive or 'loop' in \
+                        find_undeclared(node.iter_child_nodes(
+                            only=('body',)), ('loop',))
+
+        # make sure the loop variable is a special one and raise a template
+        # assertion error if a loop tries to write to loop
+        loop_frame.identifiers.add_special('loop')
+        for name in node.find_all(nodes.Name):
+            if name.ctx == 'store' and name.name == 'loop':
+                self.fail('Can\'t assign to special loop variable '
+                          'in for-loop target', name.lineno)
 
         # if we don't have an recursive loop we have to find the shadowed
         # variables at that point
@@ -898,22 +912,24 @@ class CodeGenerator(NodeVisitor):
 
         self.pull_locals(loop_frame)
         if node.else_:
-            self.writeline('l_loop = None')
-
-        self.newline(node)
-        self.writeline('for ')
+            iteration_indicator = self.temporary_identifier()
+            self.writeline('%s = 1' % iteration_indicator)
+
+        # Create a fake parent loop if the else or test section of a
+        # loop is accessing the special loop variable and no parent loop
+        # exists.
+        if 'loop' not in aliases and 'loop' in find_undeclared(
+           node.iter_child_nodes(only=('else_', 'test')), ('loop',)):
+            self.writeline("l_loop = environment.undefined(%r, name='loop')" %
+                "'loop' is undefined. the filter section of a loop as well " \
+                "as the else block doesn't have access to the special 'loop' "
+                "variable of the current loop.  Because there is no parent "
+                "loop it's undefined.")
+
+        self.writeline('for ', node)
         self.visit(node.target, loop_frame)
         self.write(extended_loop and ', l_loop in LoopContext(' or ' in ')
 
-        # the expression pointing to the parent loop.  We make the
-        # undefined a bit more debug friendly at the same time.
-        parent_loop = 'loop' in aliases and aliases['loop'] \
-                      or "environment.undefined(%r, name='loop')" % "'loop' " \
-                         'is undefined. "the filter section of a loop as well ' \
-                         'as the else block doesn\'t have access to the ' \
-                         "special 'loop' variable of the current loop.  " \
-                         "Because there is no parent loop it's undefined."
-
         # if we have an extened loop and a node test, we filter in the
         # "outer frame".
         if extended_loop and node.test is not None:
@@ -928,7 +944,6 @@ class CodeGenerator(NodeVisitor):
                 self.visit(node.iter, loop_frame)
             self.write(' if (')
             test_frame = loop_frame.copy()
-            self.writeline('l_loop = ' + parent_loop)
             self.visit(node.test, test_frame)
             self.write('))')
 
@@ -954,18 +969,18 @@ class CodeGenerator(NodeVisitor):
 
         self.indent()
         self.blockvisit(node.body, loop_frame, force_generator=True)
+        if node.else_:
+            self.writeline('%s = 0' % iteration_indicator)
         self.outdent()
 
         if node.else_:
-            self.writeline('if l_loop is None:')
+            self.writeline('if %s:' % iteration_indicator)
             self.indent()
-            self.writeline('l_loop = ' + parent_loop)
             self.blockvisit(node.else_, loop_frame, force_generator=False)
             self.outdent()
 
         # reset the aliases if there are any.
-        for name, alias in aliases.iteritems():
-            self.writeline('l_%s = %s' % (name, alias))
+        self.restore_shadowed(aliases)
 
         # if the node was recursive we have to return the buffer contents
         # and start the iteration code
@@ -1008,25 +1023,20 @@ class CodeGenerator(NodeVisitor):
         self.writeline('caller = ')
         self.macro_def(node, call_frame)
         self.start_write(frame, node)
-        self.visit_Call(node.call, call_frame,
-                        extra_kwargs={'caller': 'caller'})
+        self.visit_Call(node.call, call_frame, forward_caller=True)
         self.end_write(frame)
 
     def visit_FilterBlock(self, node, frame):
         filter_frame = frame.inner()
         filter_frame.inspect(node.iter_child_nodes())
-
         aliases = self.collect_shadowed(filter_frame)
         self.pull_locals(filter_frame)
         self.buffer(filter_frame)
-
-        for child in node.body:
-            self.visit(child, filter_frame)
-
+        self.blockvisit(node.body, filter_frame, force_generator=False)
         self.start_write(frame, node)
-        self.visit_Filter(node.filter, filter_frame, 'concat(%s)'
-                          % filter_frame.buffer)
+        self.visit_Filter(node.filter, filter_frame)
         self.end_write(frame)
+        self.restore_shadowed(aliases)
 
     def visit_ExprStmt(self, node, frame):
         self.newline(node)
@@ -1283,7 +1293,7 @@ class CodeGenerator(NodeVisitor):
             self.write(':')
             self.visit(node.step, frame)
 
-    def visit_Filter(self, node, frame, initial=None):
+    def visit_Filter(self, node, frame):
         self.write(self.filters[node.name] + '(')
         func = self.environment.filters.get(node.name)
         if func is None:
@@ -1292,10 +1302,15 @@ class CodeGenerator(NodeVisitor):
             self.write('context, ')
         elif getattr(func, 'environmentfilter', False):
             self.write('environment, ')
-        if isinstance(node.node, nodes.Filter):
-            self.visit_Filter(node.node, frame, initial)
-        elif node.node is None:
-            self.write(initial)
+
+        # if the filter node is None we are inside a filter block
+        # and want to write to the current buffer
+        if node.node is None:
+            if self.environment.autoescape:
+                tmpl = 'Markup(concat(%s))'
+            else:
+                tmpl = 'concat(%s)'
+            self.write(tmpl % frame.buffer)
         else:
             self.visit(node.node, frame)
         self.signature(node, frame)
@@ -1327,11 +1342,12 @@ class CodeGenerator(NodeVisitor):
             self.visit(node.expr2, frame)
             self.write(')')
 
-    def visit_Call(self, node, frame, extra_kwargs=None):
+    def visit_Call(self, node, frame, forward_caller=False):
         if self.environment.sandboxed:
             self.write('environment.call(')
         self.visit(node.node, frame)
         self.write(self.environment.sandboxed and ', ' or '(')
+        extra_kwargs = forward_caller and {'caller': 'caller'} or None
         self.signature(node, frame, False, extra_kwargs)
         self.write(')')
 
index 969d785910db723e9b127d16903a3ff0c88c7173..27c9ddbcd2e6a6862bc9fe3812b58a75d1be18a6 100644 (file)
@@ -116,24 +116,26 @@ class Node(object):
             raise TypeError('unknown attribute %r' %
                             iter(attributes).next())
 
-    def iter_fields(self, exclude=()):
+    def iter_fields(self, exclude=None, only=None):
         """This method iterates over all fields that are defined and yields
         ``(key, value)`` tuples.  Optionally a parameter of ignored fields
         can be provided.
         """
         for name in self.fields:
-            if name not in exclude:
+            if (exclude is only is None) or \
+               (exclude is not None and name not in exclude) or \
+               (only is not None and name in only):
                 try:
                     yield name, getattr(self, name)
                 except AttributeError:
                     pass
 
-    def iter_child_nodes(self, exclude=()):
+    def iter_child_nodes(self, exclude=None, only=None):
         """Iterates over all direct child nodes of the node.  This iterates
         over all fields and yields the values of they are nodes.  If the value
         of a field is a list all the nodes in that list are returned.
         """
-        for field, item in self.iter_fields(exclude):
+        for field, item in self.iter_fields(exclude, only):
             if isinstance(item, list):
                 for n in item:
                     if isinstance(n, Node):
@@ -529,6 +531,9 @@ class CondExpr(Expr):
 class Filter(Expr):
     """This node applies a filter on an expression.  `name` is the name of
     the filter, the rest of the fields are the same as for :class:`Call`.
+
+    If the `node` of a filter is `None` the contents of the last buffer are
+    filtered.  Buffers are created by macros and filter blocks.
     """
     fields = ('node', 'name', 'args', 'kwargs', 'dyn_args', 'dyn_kwargs')
 
index 7469f93efebf45636563d31ca4d6ddbdb2015f16..5c0288d0ac4b06a13328919846f721b7016f9b37 100644 (file)
@@ -7,6 +7,7 @@
     :license: BSD, see LICENSE for more details.
 """
 from py.test import raises
+from jinja2.exceptions import UndefinedError
 
 
 SIMPLE = '''{% for item in seq %}{{ item }}{% endfor %}'''
@@ -30,6 +31,10 @@ LOOPLOOP = '''{% for row in table %}
         [{{ rowloop.index }}|{{ loop.index }}]
     {%- endfor %}
 {%- endfor %}'''
+LOOPERROR1 = '''\
+{% for item in [1] if loop.index == 0 %}...{% endfor %}'''
+LOOPERROR2 = '''\
+{% for item in [] %}...{% else %}{{ loop }}{% endfor %}'''
 
 
 def test_simple(env):
@@ -102,3 +107,10 @@ def test_recursive(env):
 def test_looploop(env):
     tmpl = env.from_string(LOOPLOOP)
     assert tmpl.render(table=['ab', 'cd']) == '[1|1][1|2][2|1][2|2]'
+
+
+def test_loop_errors(env):
+    tmpl = env.from_string(LOOPERROR1)
+    raises(UndefinedError, tmpl.render)
+    tmpl = env.from_string(LOOPERROR2)
+    assert tmpl.render() == ''