added loop filtering
authorArmin Ronacher <armin.ronacher@active-4.com>
Sun, 13 Apr 2008 11:16:50 +0000 (13:16 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Sun, 13 Apr 2008 11:16:50 +0000 (13:16 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/nodes.py
jinja2/optimizer.py
jinja2/parser.py
jinja2/runtime.py
test_loop_filter.py [new file with mode: 0644]

index e692184868b9abd69dd98d2ac4724e3a9d148dcf..e16617e93ca4aa115ae38dfeee94d325690061e5 100644 (file)
@@ -29,6 +29,14 @@ operators = {
 }
 
 
+try:
+    exec '(0 if 0 else 0)'
+except SyntaxError:
+    have_condexpr = False
+else:
+    have_condexpr = True
+
+
 def generate(node, environment, filename, stream=None):
     """Generate the python source for a node tree."""
     is_child = node.find(nodes.Extends) is not None
@@ -114,6 +122,7 @@ class Frame(object):
         self.rootlevel = False
         self.parent = parent
         self.buffer = None
+        self.name_overrides = {}
         self.block = parent and parent.block or None
         if parent is not None:
             self.identifiers.declared.update(
@@ -122,11 +131,13 @@ class Frame(object):
                 parent.identifiers.declared_parameter
             )
             self.buffer = parent.buffer
+            self.name_overrides = parent.name_overrides.copy()
 
     def copy(self):
         """Create a copy of the current one."""
         rv = copy(self)
         rv.identifiers = copy(self.identifiers)
+        rv.name_overrides = self.name_overrides.copy()
         return rv
 
     def inspect(self, nodes, hard_scope=False):
@@ -516,21 +527,62 @@ class CodeGenerator(NodeVisitor):
 
         aliases = self.collect_shadowed(loop_frame)
         self.pull_locals(loop_frame, True)
-
-        self.newline(node)
         if node.else_:
             self.writeline('l_loop = None')
-        self.write('for ')
+
+        self.newline(node)
+        self.writeline('for ')
         self.visit(node.target, loop_frame)
         self.write(extended_loop and ', l_loop in LoopContext(' or ' in ')
-        self.visit(node.iter, loop_frame)
+
+        # the expression pointing to the parent loop.  We make the
+        # undefined a bit more debug friendly at the same time.
+        parent_loop = 'loop' in aliases and aliases['loop'] \
+                      or "Undefined('loop', extra=%r)" % \
+                         'the filter section of a loop as well as the ' \
+                         'else block doesn\'t have access to the special ' \
+                         "'loop' variable of the current loop.  Because " \
+                         'there is no parent loop it\'s undefined.'
+
+        # if we have an extened loop and a node test, we filter in the
+        # "outer frame".
+        if extended_loop and node.test is not None:
+            self.write('(')
+            self.visit(node.target, loop_frame)
+            self.write(' for ')
+            self.visit(node.target, loop_frame)
+            self.write(' in ')
+            self.visit(node.iter, loop_frame)
+            self.write(' if (')
+            test_frame = loop_frame.copy()
+            test_frame.name_overrides['loop'] = parent_loop
+            self.visit(node.test, test_frame)
+            self.write('))')
+
+        else:
+            self.visit(node.iter, loop_frame)
+
         if 'loop' in aliases:
             self.write(', ' + aliases['loop'])
         self.write(extended_loop and '):' or ':')
+
+        # tests in not extended loops become a continue
+        if not extended_loop and node.test is not None:
+            self.indent()
+            self.writeline('if ')
+            self.visit(node.test)
+            self.write(':')
+            self.indent()
+            self.writeline('continue')
+            self.outdent(2)
+
         self.blockvisit(node.body, loop_frame)
 
         if node.else_:
             self.writeline('if l_loop is None:')
+            self.indent()
+            self.writeline('l_loop = ' + parent_loop)
+            self.outdent()
             self.blockvisit(node.else_, loop_frame)
 
         # reset the aliases if there are any.
@@ -667,7 +719,7 @@ class CodeGenerator(NodeVisitor):
                         self.write('%s.append(' % frame.buffer)
                     self.write(finalizer + '(')
                     self.visit(item, frame)
-                    self.write(')' * (1 + frame.buffer is not None))
+                    self.write(')' * (1 + (frame.buffer is not None)))
 
         # otherwise we create a format string as this is faster in that case
         else:
@@ -721,8 +773,14 @@ class CodeGenerator(NodeVisitor):
                 self.writeline('context[%r] = l_%s' % (name, name))
 
     def visit_Name(self, node, frame):
-        if frame.toplevel and node.ctx == 'store':
-            frame.assigned_names.add(node.name)
+        if node.ctx == 'store':
+            if frame.toplevel:
+                frame.assigned_names.add(node.name)
+            frame.name_overrides.pop(node.name, None)
+        elif node.ctx == 'load':
+            if node.name in frame.name_overrides:
+                self.write(frame.name_overrides[node.name])
+                return
         self.write('l_' + node.name)
 
     def visit_Const(self, node, frame):
@@ -856,6 +914,24 @@ class CodeGenerator(NodeVisitor):
         self.signature(node, frame)
         self.write(')')
 
+    def visit_CondExpr(self, node, frame):
+        if not have_condexpr:
+            self.write('((')
+            self.visit(node.test, frame)
+            self.write(') and (')
+            self.visit(node.expr1, frame)
+            self.write(',) or (')
+            self.visit(node.expr2, frame)
+            self.write(',))[0]')
+        else:
+            self.write('(')
+            self.visit(node.expr1, frame)
+            self.write(' if ')
+            self.visit(node.test, frame)
+            self.write(' else ')
+            self.visit(node.expr2, frame)
+            self.write(')')
+
     def visit_Call(self, node, frame, extra_kwargs=None):
         self.visit(node.node, frame)
         self.write('(')
index e1857d2290ab2b03f30f1e3f2c0f8641ba8ba723..89c6fa24415e2636888fcfd6a445fc3848a2220d 100644 (file)
@@ -207,7 +207,7 @@ class Extends(Stmt):
 
 class For(Stmt):
     """A node that represents a for loop"""
-    fields = ('target', 'iter', 'body', 'else_', 'recursive')
+    fields = ('target', 'iter', 'body', 'else_', 'test')
 
 
 class If(Stmt):
index d2550f49ef72493ef6505ab94f0fb6ac1a20449f..5877c6f0eb3c54606d2749476b80f7e038759bdc 100644 (file)
@@ -131,6 +131,9 @@ class Optimizer(NodeTransformer):
             # we also don't want unrolling if macros are defined in it
             if node.find(nodes.Macro) is not None:
                 raise TypeError()
+            # XXX: add support for loop test clauses in the optimizer
+            if node.test is not None:
+                raise TypeError()
         except (nodes.Impossible, TypeError):
             return self.generic_visit(node, context)
 
@@ -156,7 +159,7 @@ class Optimizer(NodeTransformer):
 
         try:
             try:
-                for loop, item in LoopContext(iterable, parent, True):
+                for item, loop in LoopContext(iterable, parent, True):
                     context['loop'] = loop.make_static()
                     assign(node.target, item)
                     result.extend(self.visit(n.copy(), context)
index 2d04211bd55d4d6ce9e923d9e97fdceb07a3de15..5185df37221b57f99272b901f41eaf81f97ed8de 100644 (file)
@@ -110,18 +110,17 @@ class Parser(object):
                                       self.filename)
         target.set_ctx('store')
         self.stream.expect('in')
-        iter = self.parse_tuple()
-        if self.stream.current.type is 'recursive':
+        iter = self.parse_tuple(no_condexpr=True)
+        test = None
+        if self.stream.current.type is 'if':
             self.stream.next()
-            recursive = True
-        else:
-            recursive = False
+            test = self.parse_expression()
         body = self.parse_statements(('endfor', 'else'))
         if self.stream.next().type is 'endfor':
             else_ = []
         else:
             else_ = self.parse_statements(('endfor',), drop_needle=True)
-        return nodes.For(target, iter, body, else_, False, lineno=lineno)
+        return nodes.For(target, iter, body, else_, test, lineno=lineno)
 
     def parse_if(self):
         """Parse an if construct."""
@@ -236,8 +235,10 @@ class Parser(object):
         self.end_statement()
         return node
 
-    def parse_expression(self):
+    def parse_expression(self, no_condexpr=False):
         """Parse an expression."""
+        if no_condexpr:
+            return self.parse_or()
         return self.parse_condexpr()
 
     def parse_condexpr(self):
@@ -417,14 +418,19 @@ class Parser(object):
             node = self.parse_postfix(node)
         return node
 
-    def parse_tuple(self, enforce=False, simplified=False):
+    def parse_tuple(self, enforce=False, simplified=False, no_condexpr=False):
         """
         Parse multiple expressions into a tuple. This can also return
         just one expression which is not a tuple. If you want to enforce
         a tuple, pass it enforce=True (currently unused).
         """
         lineno = self.stream.current.lineno
-        parse = simplified and self.parse_primary or self.parse_expression
+        if simplified:
+            parse = self.parse_primary
+        elif no_condexpr:
+            parse = lambda: self.parse_expression(no_condexpr=True)
+        else:
+            parse = self.parse_expression
         args = []
         is_tuple = False
         while 1:
index 238d4cfcf3c64ae676517b6daca4ac5712deabf0..676c1f5842aef94a1c84d1028042f8af10f9e98d 100644 (file)
@@ -25,7 +25,7 @@ def subscribe(obj, argument):
     except (AttributeError, UnicodeError):
         try:
             return obj[argument]
-        except LookupError:
+        except (TypeError, LookupError):
             return Undefined(obj, argument)
 
 
@@ -138,8 +138,9 @@ class LoopContext(LoopContextBase):
 
     def __init__(self, iterable, parent=None, enforce_length=False):
         self._iterable = iterable
+        self._next = iter(iterable).next
         self._length = None
-        self.index0 = 0
+        self.index0 = -1
         self.parent = parent
         if enforce_length:
             len(self)
@@ -152,9 +153,11 @@ class LoopContext(LoopContextBase):
         return StaticLoopContext(self.index0, self.length, parent)
 
     def __iter__(self):
-        for item in self._iterable:
-            yield self, item
-            self.index0 += 1
+        return self
+
+    def next(self):
+        self.index0 += 1
+        return self._next(), self
 
     def __len__(self):
         if self._length is None:
@@ -162,7 +165,8 @@ class LoopContext(LoopContextBase):
                 length = len(self._iterable)
             except TypeError:
                 self._iterable = tuple(self._iterable)
-                length = self.index0 + len(tuple(self._iterable))
+                self._next = iter(self._iterable).next
+                length = len(tuple(self._iterable)) + self.index0 + 1
             self._length = length
         return self._length
 
@@ -217,12 +221,13 @@ class Macro(object):
                     try:
                         value = self.defaults[idx - arg_count]
                     except IndexError:
-                        value = Undefined(name)
+                        value = Undefined(name, extra='parameter not provided')
             arguments['l_' + name] = value
         if self.caller:
             caller = kwargs.pop('caller', None)
             if caller is None:
-                caller = Undefined('caller')
+                caller = Undefined('caller', extra='The macro was called '
+                                   'from an expression and not a call block.')
             arguments['l_caller'] = caller
         if self.catch_all:
             arguments['l_arguments'] = kwargs
@@ -238,14 +243,14 @@ class Macro(object):
 class Undefined(object):
     """The object for undefined values."""
 
-    def __init__(self, name=None, attr=None):
+    def __init__(self, name=None, attr=None, extra=None):
         if attr is None:
             self._undefined_hint = '%r is undefined' % name
-        elif name is None:
-            self._undefined_hint = 'attribute %r is undefined' % name
         else:
             self._undefined_hint = 'attribute %r of %r is undefined' \
                                    % (attr, name)
+        if extra is not None:
+            self._undefined_hint += ' (' + extra + ')'
 
     def fail(self, *args, **kwargs):
         raise TypeError(self._undefined_hint)
diff --git a/test_loop_filter.py b/test_loop_filter.py
new file mode 100644 (file)
index 0000000..64f32d3
--- /dev/null
@@ -0,0 +1,12 @@
+from jinja2 import Environment
+
+tmpl = Environment().from_string("""\
+<ul>
+{% for item in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] if item % 2 == 0 %}
+    <li>{{ loop.index }} / {{ loop.length }}: {{ item }}</li>
+{% endfor %}
+</ul>
+{{ 1 if foo else 0 }}
+""")
+
+print tmpl.render(foo=True)