optimizer can optimize filtered for loops now
authorArmin Ronacher <armin.ronacher@active-4.com>
Sun, 13 Apr 2008 14:31:08 +0000 (16:31 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Sun, 13 Apr 2008 14:31:08 +0000 (16:31 +0200)
--HG--
branch : trunk

examples/test_loop_filter.py
jinja2/compiler.py
jinja2/nodes.py
jinja2/optimizer.py
jinja2/runtime.py

index 64f32d357083174e76ae7b2dff224889782bbf72..49c2efcfe2b36365cb9e1ba44a57418de5af6c67 100644 (file)
@@ -2,11 +2,11 @@ from jinja2 import Environment
 
 tmpl = Environment().from_string("""\
 <ul>
-{% for item in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] if item % 2 == 0 %}
+{%- for item in [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] if item % 2 == 0 %}
     <li>{{ loop.index }} / {{ loop.length }}: {{ item }}</li>
-{% endfor %}
+{%- endfor %}
 </ul>
-{{ 1 if foo else 0 }}
+if condition: {{ 1 if foo else 0 }}
 """)
 
 print tmpl.render(foo=True)
index e16617e93ca4aa115ae38dfeee94d325690061e5..3c2347d65c5dda2ba8dccb1630d5359e4a47778a 100644 (file)
@@ -240,16 +240,18 @@ class CodeGenerator(NodeVisitor):
     def outdent(self, step=1):
         self.indentation -= step
 
-    def blockvisit(self, nodes, frame, force_generator=False):
-        self.indent()
-        if force_generator and frame.buffer is None:
+    def blockvisit(self, nodes, frame, indent=True, force_generator=True):
+        if indent:
+            self.indent()
+        if frame.buffer is None and force_generator:
             self.writeline('if 0: yield None')
         try:
             for node in nodes:
                 self.visit(node, frame)
         except CompilerExit:
             pass
-        self.outdent()
+        if indent:
+            self.outdent()
 
     def write(self, x):
         if self.new_lines:
@@ -297,8 +299,8 @@ class CodeGenerator(NodeVisitor):
             self.write('**')
             self.visit(node.dyn_kwargs, frame)
 
-    def pull_locals(self, frame, no_indent=False):
-        if not no_indent:
+    def pull_locals(self, frame, indent=True):
+        if indent:
             self.indent()
         for name in frame.identifiers.undeclared:
             self.writeline('l_%s = context[%r]' % (name, name))
@@ -306,7 +308,7 @@ class CodeGenerator(NodeVisitor):
             self.writeline('f_%s = environment.filters[%r]' % (name, name))
         for name in frame.identifiers.tests:
             self.writeline('t_%s = environment.tests[%r]' % (name, name))
-        if not no_indent:
+        if indent:
             self.outdent()
 
     def collect_shadowed(self, frame):
@@ -394,11 +396,11 @@ class CodeGenerator(NodeVisitor):
         frame = Frame()
         frame.inspect(node.body)
         frame.toplevel = frame.rootlevel = True
-        self.pull_locals(frame)
         self.indent()
+        self.pull_locals(frame, indent=False)
         self.writeline('yield context')
+        self.blockvisit(node.body, frame, indent=False)
         self.outdent()
-        self.blockvisit(node.body, frame)
 
         # make sure that the parent root is called.
         if have_extends:
@@ -421,7 +423,7 @@ class CodeGenerator(NodeVisitor):
             self.writeline('def block_%s(context, environment=environment):'
                            % name, block, 1)
             self.pull_locals(block_frame)
-            self.blockvisit(block.body, block_frame, True)
+            self.blockvisit(block.body, block_frame)
 
         self.writeline('blocks = {%s}' % ', '.join('%r: block_%s' % (x, x)
                                                    for x in self.blocks), extra=1)
@@ -526,7 +528,7 @@ class CodeGenerator(NodeVisitor):
             loop_frame.identifiers.add_special('loop')
 
         aliases = self.collect_shadowed(loop_frame)
-        self.pull_locals(loop_frame, True)
+        self.pull_locals(loop_frame, indent=False)
         if node.else_:
             self.writeline('l_loop = None')
 
@@ -576,14 +578,14 @@ class CodeGenerator(NodeVisitor):
             self.writeline('continue')
             self.outdent(2)
 
-        self.blockvisit(node.body, loop_frame)
+        self.blockvisit(node.body, loop_frame, force_generator=False)
 
         if node.else_:
             self.writeline('if l_loop is None:')
             self.indent()
             self.writeline('l_loop = ' + parent_loop)
             self.outdent()
-            self.blockvisit(node.else_, loop_frame)
+            self.blockvisit(node.else_, loop_frame, force_generator=False)
 
         # reset the aliases if there are any.
         for name, alias in aliases.iteritems():
@@ -603,8 +605,13 @@ class CodeGenerator(NodeVisitor):
         macro_frame = self.function_scoping(node, frame)
         args = macro_frame.arguments
         self.writeline('def macro(%s):' % ', '.join(args), node)
-        self.pull_locals(macro_frame)
-        self.blockvisit(node.body, macro_frame, True)
+        macro_frame.buffer = buf = self.temporary_identifier()
+        self.indent()
+        self.pull_locals(macro_frame, indent=False)
+        self.writeline('%s = []' % buf)
+        self.blockvisit(node.body, macro_frame, indent=False)
+        self.writeline("return TemplateData(u''.join(%s))" % buf)
+        self.outdent()
         self.newline()
         if frame.toplevel:
             self.write('context[%r] = ' % node.name)
@@ -614,7 +621,7 @@ class CodeGenerator(NodeVisitor):
         self.write('l_%s = Macro(macro, %r, (%s), (' % (node.name, node.name,
                                                        arg_tuple))
         for arg in node.defaults:
-            self.visit(arg)
+            self.visit(arg, macro_frame)
             self.write(', ')
         self.write('), %s, %s)' % (
             macro_frame.accesses_arguments and '1' or '0',
@@ -625,7 +632,13 @@ class CodeGenerator(NodeVisitor):
         call_frame = self.function_scoping(node, frame)
         args = call_frame.arguments
         self.writeline('def call(%s):' % ', '.join(args), node)
-        self.blockvisit(node.body, call_frame, node)
+        call_frame.buffer = buf = self.temporary_identifier()
+        self.indent()
+        self.pull_locals(call_frame, indent=False)
+        self.writeline('%s = []' % buf)
+        self.blockvisit(node.body, call_frame, indent=False)
+        self.writeline("return TemplateData(u''.join(%s))" % buf)
+        self.outdent()
         arg_tuple = ', '.join(repr(x.name) for x in node.args)
         if len(node.args) == 1:
             arg_tuple += ','
@@ -647,7 +660,7 @@ class CodeGenerator(NodeVisitor):
         filter_frame.inspect(node.iter_child_nodes())
 
         aliases = self.collect_shadowed(filter_frame)
-        self.pull_locals(filter_frame, True)
+        self.pull_locals(filter_frame, indent=False)
         filter_frame.buffer = buf = self.temporary_identifier()
 
         self.writeline('%s = []' % buf, node)
index 89c6fa24415e2636888fcfd6a445fc3848a2220d..ecc3f1e9f51ea1fe5afa1763b3930e12b55cb503 100644 (file)
@@ -35,6 +35,17 @@ _uaop_to_func = {
     '-':        operator.neg
 }
 
+_cmpop_to_func = {
+    'eq':       operator.eq,
+    'ne':       operator.ne,
+    'gt':       operator.gt,
+    'gteq':     operator.ge,
+    'lt':       operator.lt,
+    'lteq':     operator.le,
+    'in':       operator.contains,
+    'notin':    lambda a, b: not operator.contains(a, b)
+}
+
 
 class Impossible(Exception):
     """Raised if the node could not perform a requested action."""
@@ -484,6 +495,14 @@ class Compare(Expr):
     """{{ foo == bar }}, {{ foo >= bar }} etc."""
     fields = ('expr', 'ops')
 
+    def as_const(self):
+        result = value = self.expr.as_const()
+        for op in self.ops:
+            new_value = op.expr.as_const()
+            result = _cmpop_to_func[op.op](value, new_value)
+            value = new_value
+        return result
+
 
 class Operand(Helper):
     """Operator + expression."""
index 5877c6f0eb3c54606d2749476b80f7e038759bdc..4dbd5d9d27f399d262d728bcc53c1dc04151a988 100644 (file)
@@ -123,6 +123,7 @@ class Optimizer(NodeTransformer):
 
     def visit_For(self, node, context):
         """Loop unrolling for iterable constant values."""
+        fallback = self.generic_visit(node.copy(), context)
         try:
             iterable = self.visit(node.iter, context).as_const()
             # we only unroll them if they have a length and are iterable
@@ -131,11 +132,8 @@ class Optimizer(NodeTransformer):
             # we also don't want unrolling if macros are defined in it
             if node.find(nodes.Macro) is not None:
                 raise TypeError()
-            # XXX: add support for loop test clauses in the optimizer
-            if node.test is not None:
-                raise TypeError()
         except (nodes.Impossible, TypeError):
-            return self.generic_visit(node, context)
+            return fallback
 
         parent = context.get('loop')
         context.push()
@@ -157,6 +155,20 @@ class Optimizer(NodeTransformer):
             else:
                 raise AssertionError('unexpected assignable node')
 
+        if node.test is not None:
+            filtered_sequence = []
+            for item in iterable:
+                context.push()
+                assign(node.target, item)
+                try:
+                    rv = self.visit(node.test.copy(), context).as_const()
+                except:
+                    return fallback
+                context.pop()
+                if rv:
+                    filtered_sequence.append(item)
+            iterable = filtered_sequence
+
         try:
             try:
                 for item, loop in LoopContext(iterable, parent, True):
index 676c1f5842aef94a1c84d1028042f8af10f9e98d..fb2802ee94538f7eb2ea1a0fcce448a302530722 100644 (file)
@@ -15,7 +15,7 @@ except ImportError:
 
 
 __all__ = ['subscribe', 'LoopContext', 'StaticLoopContext', 'TemplateContext',
-           'Macro', 'IncludedTemplate', 'Undefined']
+           'Macro', 'IncludedTemplate', 'Undefined', 'TemplateData']
 
 
 def subscribe(obj, argument):
@@ -190,6 +190,9 @@ class StaticLoopContext(LoopContextBase):
             self.parent
         )
 
+    def __len__(self):
+        return self._length
+
     def make_static(self):
         return self
 
@@ -231,7 +234,7 @@ class Macro(object):
             arguments['l_caller'] = caller
         if self.catch_all:
             arguments['l_arguments'] = kwargs
-        return TemplateData(u''.join(self._func(**arguments)))
+        return self._func(**arguments)
 
     def __repr__(self):
         return '<%s %s>' % (