updated filters: wordwraps uses the wordwrap module and urlize marks the result as...
[jinja2.git] / jinja2 / optimizer.py
index c508727f9cafd89cd9d5824f3bed6fe1262dc30b..283d1fae5637493aec61f722a995bec87b4edfd9 100644 (file)
@@ -3,30 +3,35 @@
     jinja2.optimizer
     ~~~~~~~~~~~~~~~~
 
-    This module tries to optimize template trees by:
+    The jinja optimizer is currently trying to constant fold a few expressions
+    and modify the AST in place so that it should be easier to evaluate it.
 
-        * eliminating constant nodes
-        * evaluating filters and macros on constant nodes
-        * unroll loops on constant values
-        * replace variables which are already known (because they doesn't
-          change often and you want to prerender a template) with constants
+    Because the AST does not contain all the scoping information and the
+    compiler has to find that out, we cannot do all the optimizations we
+    want.  For example loop unrolling doesn't work because unrolled loops would
+    have a different scoping.
 
-    After the optimation you will get a new, simplier template which can
-    be saved again for later rendering. But even if you don't want to
-    prerender a template, this module might speed up your templates a bit
-    if you are using a lot of constants.
+    The solution would be a second syntax tree that has the scoping rules stored.
 
-    :copyright: Copyright 2008 by Christoph Hack.
+    :copyright: Copyright 2008 by Christoph Hack, Armin Ronacher.
     :license: GNU GPL.
 """
-from copy import deepcopy
 from jinja2 import nodes
 from jinja2.visitor import NodeVisitor, NodeTransformer
-from jinja2.runtime import subscribe
+from jinja2.runtime import LoopContext
+from jinja2.utils import concat
+
+
+def optimize(node, environment, context_hint=None):
+    """The context hint can be used to perform an static optimization
+    based on the context given."""
+    optimizer = Optimizer(environment)
+    return optimizer.visit(node, ContextStack(context_hint))
 
 
 class ContextStack(object):
     """Simple compile time context implementation."""
+    undefined = object()
 
     def __init__(self, initial=None):
         self.stack = [{}]
@@ -39,10 +44,30 @@ class ContextStack(object):
     def pop(self):
         self.stack.pop()
 
+    def get(self, key, default=None):
+        try:
+            return self[key]
+        except KeyError:
+            return default
+
+    def undef(self, name):
+        if name in self:
+            self[name] = self.undefined
+
+    def __contains__(self, key):
+        try:
+            self[key]
+        except KeyError:
+            return False
+        return True
+
     def __getitem__(self, key):
         for level in reversed(self.stack):
             if key in level:
-                return level[key]
+                rv = level[key]
+                if rv is self.undefined:
+                    raise KeyError(key)
+                return rv
         raise KeyError(key)
 
     def __setitem__(self, key, value):
@@ -58,38 +83,45 @@ class Optimizer(NodeTransformer):
     def __init__(self, environment):
         self.environment = environment
 
-    def visit_Filter(self, node, context):
-        """Try to evaluate filters if possible."""
-        # XXX: nonconstant arguments?  not-called visitors?  generic visit!
+    def visit_Block(self, node, context):
+        block_context = context.blank()
+        for name in 'super', 'self':
+            block_context.undef(name)
+        return self.generic_visit(node, block_context)
+
+    def visit_For(self, node, context):
+        context.push()
+        context.undef('loop')
         try:
-            x = self.visit(node.node, context).as_const()
-        except nodes.Impossible:
             return self.generic_visit(node, context)
-        for filter in reversed(node.filters):
-            # XXX: call filters with arguments
-            x = self.environment.filters[filter.name](self.environment, x)
-            # XXX: don't optimize context dependent filters
-        return nodes.Const(x)
+        finally:
+            context.pop()
 
-    def visit_For(self, node, context):
-        """Loop unrolling for iterable constant values."""
+    def visit_Macro(self, node, context):
+        context.push()
+        for name in 'varargs', 'kwargs', 'caller':
+            context.undef(name)
         try:
-            iterable = iter(self.visit(node.iter, context).as_const())
-        except (nodes.Impossible, TypeError):
             return self.generic_visit(node, context)
+        finally:
+            context.pop()
+
+    def visit_CallBlock(self, node, context):
         context.push()
-        result = []
-        # XXX: tuple unpacking (for key, value in foo)
-        target = node.target.name
-        iterated = False
-        for item in iterable:
-            context[target] = item
-            result.extend(self.visit(n, context) for n in deepcopy(node.body))
-            iterated = True
-        if not iterated and node.else_:
-            result.extend(self.visit(n, context) for n in deepcopy(node.else_))
-        context.pop()
-        return result
+        for name in 'varargs', 'kwargs':
+            context.undef(name)
+        try:
+            return self.generic_visit(node, context)
+        finally:
+            context.pop()
+
+    def visit_FilterBlock(self, node, context):
+        """Try to filter a block at compile time."""
+        context.push()
+        try:
+            return self.generic_visit(node, context)
+        finally:
+            context.pop()
 
     def visit_If(self, node, context):
         try:
@@ -97,62 +129,53 @@ class Optimizer(NodeTransformer):
         except nodes.Impossible:
             return self.generic_visit(node, context)
         if val:
-            return node.body
-        return node.else_
+            body = node.body
+        else:
+            body = node.else_
+        result = []
+        for node in body:
+            result.extend(self.visit_list(node, context))
+        return result
 
     def visit_Name(self, node, context):
-        if node.ctx == 'load':
-            try:
-                return nodes.Const(context[node.name], lineno=node.lineno)
-            except KeyError:
-                pass
-        return node
-
-    def visit_Assign(self, node, context):
+        if node.ctx != 'load':
+            # something overwrote the variable, we can no longer use
+            # the constant from the context
+            context.undef(node.name)
+            return node
         try:
-            target = node.target = self.generic_visit(node.target, context)
-            value = self.generic_visit(node.node, context).as_const()
-        except nodes.Impossible:
+            return nodes.Const.from_untrusted(context[node.name],
+                                              lineno=node.lineno,
+                                              environment=self.environment)
+        except (KeyError, nodes.Impossible):
             return node
 
-        result = []
-        lineno = node.lineno
-        def walk(target, value):
-            if isinstance(target, nodes.Name):
-                const_value = nodes.Const(value, lineno=lineno)
-                result.append(nodes.Assign(target, const_value, lineno=lineno))
-                context[target.name] = value
-            elif isinstance(target, nodes.Tuple):
-                try:
-                    value = tuple(value)
-                except TypeError:
-                    raise nodes.Impossible()
-                if len(target) != len(value):
-                    raise nodes.Impossible()
-                for name, val in zip(target, value):
-                    walk(name, val)
+    def visit_Import(self, node, context):
+        rv = self.generic_visit(node, context)
+        context.undef(node.target)
+        return rv
+
+    def visit_FromImport(self, node, context):
+        rv = self.generic_visit(node, context)
+        for name in node.names:
+            if isinstance(name, tuple):
+                context.undef(name[1])
             else:
-                raise AssertionError('unexpected assignable node')
+                context.undef(name)
+        return rv
 
+    def fold(self, node, context):
+        """Do constant folding."""
+        node = self.generic_visit(node, context)
         try:
-            walk(target, value)
+            return nodes.Const.from_untrusted(node.as_const(),
+                                              lineno=node.lineno,
+                                              environment=self.environment)
         except nodes.Impossible:
             return node
-        return result
-
-    def visit_Subscript(self, node, context):
-        if node.ctx == 'load':
-            try:
-                item = self.visit(node.node, context).as_const()
-                arg = self.visit(node.arg, context).as_const()
-            except nodes.Impossible:
-                return self.generic_visit(node, context)
-            return nodes.Const(subscribe(item, arg, 'load'))
-        return self.generic_visit(node, context)
-
 
-def optimize(node, environment, context_hint=None):
-    """The context hint can be used to perform an static optimization
-    based on the context given."""
-    optimizer = Optimizer(environment)
-    return optimizer.visit(node, ContextStack(context_hint))
+    visit_Add = visit_Sub = visit_Mul = visit_Div = visit_FloorDiv = \
+    visit_Pow = visit_Mod = visit_And = visit_Or = visit_Pos = visit_Neg = \
+    visit_Not = visit_Compare = visit_Subscript = visit_Call = \
+    visit_Filter = visit_Test = visit_CondExpr = fold
+    del fold