improved static optimizer
authorArmin Ronacher <armin.ronacher@active-4.com>
Tue, 8 Apr 2008 22:40:05 +0000 (00:40 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Tue, 8 Apr 2008 22:40:05 +0000 (00:40 +0200)
--HG--
branch : trunk

jinja2/optimizer.py
jinja2/runtime.py
test_optimizer.py

index e7f6ac700c90144a0ef9bd8d17a8caf7a6eb10f9..c508727f9cafd89cd9d5824f3bed6fe1262dc30b 100644 (file)
@@ -25,54 +25,134 @@ from jinja2.visitor import NodeVisitor, NodeTransformer
 from jinja2.runtime import subscribe
 
 
+class ContextStack(object):
+    """Simple compile time context implementation."""
+
+    def __init__(self, initial=None):
+        self.stack = [{}]
+        if initial is not None:
+            self.stack.insert(0, initial)
+
+    def push(self):
+        self.stack.append({})
+
+    def pop(self):
+        self.stack.pop()
+
+    def __getitem__(self, key):
+        for level in reversed(self.stack):
+            if key in level:
+                return level[key]
+        raise KeyError(key)
+
+    def __setitem__(self, key, value):
+        self.stack[-1][key] = value
+
+    def blank(self):
+        """Return a new context with nothing but the root scope."""
+        return ContextStack(self.stack[0])
+
+
 class Optimizer(NodeTransformer):
 
-    def __init__(self, environment, context={}):
+    def __init__(self, environment):
         self.environment = environment
-        self.context = context
 
-    def visit_Filter(self, node):
+    def visit_Filter(self, node, context):
         """Try to evaluate filters if possible."""
+        # XXX: nonconstant arguments?  not-called visitors?  generic visit!
         try:
-            x = self.visit(node.node).as_const()
+            x = self.visit(node.node, context).as_const()
         except nodes.Impossible:
-            return node
+            return self.generic_visit(node, context)
         for filter in reversed(node.filters):
             # XXX: call filters with arguments
             x = self.environment.filters[filter.name](self.environment, x)
             # XXX: don't optimize context dependent filters
         return nodes.Const(x)
 
-    def visit_For(self, node):
-        """Loop unrolling for constant values."""
+    def visit_For(self, node, context):
+        """Loop unrolling for iterable constant values."""
         try:
-            iter = self.visit(node.iter).as_const()
-        except nodes.Impossible:
-            return node
+            iterable = iter(self.visit(node.iter, context).as_const())
+        except (nodes.Impossible, TypeError):
+            return self.generic_visit(node, context)
+        context.push()
         result = []
+        # XXX: tuple unpacking (for key, value in foo)
         target = node.target.name
-        for item in iter:
-            # XXX: take care of variable scopes
-            self.context[target] = item
-            result.extend(self.visit(n) for n in deepcopy(node.body))
+        iterated = False
+        for item in iterable:
+            context[target] = item
+            result.extend(self.visit(n, context) for n in deepcopy(node.body))
+            iterated = True
+        if not iterated and node.else_:
+            result.extend(self.visit(n, context) for n in deepcopy(node.else_))
+        context.pop()
         return result
 
-    def visit_Name(self, node):
-        # XXX: take care of variable scopes!
-        if node.name not in self.context:
+    def visit_If(self, node, context):
+        try:
+            val = self.visit(node.test, context).as_const()
+        except nodes.Impossible:
+            return self.generic_visit(node, context)
+        if val:
+            return node.body
+        return node.else_
+
+    def visit_Name(self, node, context):
+        if node.ctx == 'load':
+            try:
+                return nodes.Const(context[node.name], lineno=node.lineno)
+            except KeyError:
+                pass
+        return node
+
+    def visit_Assign(self, node, context):
+        try:
+            target = node.target = self.generic_visit(node.target, context)
+            value = self.generic_visit(node.node, context).as_const()
+        except nodes.Impossible:
             return node
-        return nodes.Const(self.context[node.name])
 
-    def visit_Subscript(self, node):
+        result = []
+        lineno = node.lineno
+        def walk(target, value):
+            if isinstance(target, nodes.Name):
+                const_value = nodes.Const(value, lineno=lineno)
+                result.append(nodes.Assign(target, const_value, lineno=lineno))
+                context[target.name] = value
+            elif isinstance(target, nodes.Tuple):
+                try:
+                    value = tuple(value)
+                except TypeError:
+                    raise nodes.Impossible()
+                if len(target) != len(value):
+                    raise nodes.Impossible()
+                for name, val in zip(target, value):
+                    walk(name, val)
+            else:
+                raise AssertionError('unexpected assignable node')
+
         try:
-            item = self.visit(node.node).as_const()
-            arg = self.visit(node.arg).as_const()
+            walk(target, value)
         except nodes.Impossible:
             return node
-        # XXX: what does the 3rd parameter mean?
-        return nodes.Const(subscribe(item, arg, None))
+        return result
+
+    def visit_Subscript(self, node, context):
+        if node.ctx == 'load':
+            try:
+                item = self.visit(node.node, context).as_const()
+                arg = self.visit(node.arg, context).as_const()
+            except nodes.Impossible:
+                return self.generic_visit(node, context)
+            return nodes.Const(subscribe(item, arg, 'load'))
+        return self.generic_visit(node, context)
 
 
-def optimize(node, environment, context={}):
-    optimizer = Optimizer(environment, context=context)
-    return optimizer.visit(node)
+def optimize(node, environment, context_hint=None):
+    """The context hint can be used to perform an static optimization
+    based on the context given."""
+    optimizer = Optimizer(environment)
+    return optimizer.visit(node, ContextStack(context_hint))
index c1c34e18d6f2d65f14a5f1ac99bd537e882e1439..5a9764edc33504f2127dee5c0e4ceeb7b1af19a2 100644 (file)
@@ -24,8 +24,8 @@ def extends(template, namespace):
 def subscribe(obj, argument, undefined_factory):
     """Get an item or attribute of an object."""
     try:
-        return getattr(obj, argument)
-    except AttributeError:
+        return getattr(obj, str(argument))
+    except (AttributeError, UnicodeError):
         try:
             return obj[argument]
         except LookupError:
index 06ee4406404c87de09fc68379a3879a64a0dca61..ddb0fa0aba4b9f77f0c26186e98b48ad57c27487 100644 (file)
@@ -16,12 +16,19 @@ ast = env.parse("""
     {% for forum in forums %}
         {{ readstatus(forum.id) }} {{ forum.id|e }} {{ forum.name|e }}
     {% endfor %}
+
+    {% navigation = [('#foo', 'Foo'), ('#bar', 'Bar')] %}
+    <ul>
+    {% for item in navigation %}
+        <li><a href="{{ item[0] }}">{{ item[1] }}</a></li>
+    {% endfor %}
+    </ul>
 """)
 print ast
 print
 print generate(ast, env, "foo.html")
 print
-ast = optimize(ast, env, context={'forums': forums})
+ast = optimize(ast, env, context_hint={'forums': forums})
 print ast
 print
 print generate(ast, env, "foo.html")