simplified undefined behavior for better compile time processing
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 9 Apr 2008 13:03:29 +0000 (15:03 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 9 Apr 2008 13:03:29 +0000 (15:03 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/environment.py
jinja2/nodes.py
jinja2/optimizer.py
jinja2/runtime.py
test_optimizer.py

index 4f5ff0b5e1ceee05fccf0d9a1c9f1e83e72d3889..c963b8f4b256f5174094d8a2b01980502a4b390a 100644 (file)
@@ -14,6 +14,7 @@ from cStringIO import StringIO
 from jinja2 import nodes
 from jinja2.visitor import NodeVisitor, NodeTransformer
 from jinja2.exceptions import TemplateAssertionError
+from jinja2.runtime import StaticLoopContext
 
 
 operators = {
@@ -36,6 +37,27 @@ def generate(node, environment, filename, stream=None):
         return generator.stream.getvalue()
 
 
+def has_safe_repr(value):
+    """Does the node have a safe representation?"""
+    if value is None:
+        return True
+    if isinstance(value, (int, long, float, basestring, StaticLoopContext)):
+        return True
+    if isinstance(value, (tuple, list)):
+        for item in value:
+            if not has_safe_repr(item):
+                return False
+        return True
+    elif isinstance(value, dict):
+        for key, value in value.iteritems():
+            if not has_safe_repr(key):
+                return False
+            if not has_safe_repr(value):
+                return False
+        return True
+    return False
+
+
 class Identifiers(object):
     """Tracks the status of identifiers in frames."""
 
@@ -235,7 +257,7 @@ class CodeGenerator(NodeVisitor):
         self.writeline('from jinja2.runtime import *')
         self.writeline('filename = %r' % self.filename)
         self.writeline('template_context = TemplateContext(global_context, '
-                       'make_undefined, filename)')
+                       'filename)')
 
         # generate the root render function.
         self.writeline('def root(context=template_context):', extra=1)
@@ -397,7 +419,7 @@ class CodeGenerator(NodeVisitor):
         for arg in node.defaults:
             self.visit(arg)
             self.write(', ')
-        self.write('), %r, make_undefined)' % accesses_arguments)
+        self.write('), %r)' % accesses_arguments)
 
     def visit_ExprStmt(self, node, frame):
         self.newline(node)
@@ -554,7 +576,7 @@ class CodeGenerator(NodeVisitor):
             self.write(repr(const))
         else:
             self.visit(node.arg, frame)
-        self.write(', make_undefined)')
+        self.write(')')
 
     def visit_Slice(self, node, frame):
         if node.start is not None:
index 77f6047e8292f268a9975bae6b7540404c8049a5..77e0d03453699cd66b218458015a1314a1e4d1c8 100644 (file)
@@ -10,6 +10,7 @@
 """
 from jinja2.lexer import Lexer
 from jinja2.parser import Parser
+from jinja2.runtime import Undefined
 from jinja2.defaults import DEFAULT_FILTERS, DEFAULT_TESTS, DEFAULT_NAMESPACE
 
 
@@ -67,6 +68,9 @@ class Environment(object):
         self.tests = DEFAULT_TESTS.copy()
         self.globals = DEFAULT_NAMESPACE.copy()
 
+        # the factory that creates the undefined object
+        self.undefined_factory = Undefined
+
         # create lexer
         self.lexer = Lexer(self)
 
index dc8cc0b8819806a3ba251ba3a6bd1fcb0f7edc33..6c16c8beac4440917e07ee7200f4a8b97d94d6ca 100644 (file)
@@ -16,6 +16,7 @@ import operator
 from itertools import chain, izip
 from collections import deque
 from copy import copy
+from jinja2.runtime import Undefined, subscribe
 
 
 _binop_to_func = {
@@ -286,6 +287,18 @@ class Const(Literal):
     def as_const(self):
         return self.value
 
+    @classmethod
+    def from_untrusted(cls, value, lineno=None, silent=False):
+        """Return a const object if the value is representable as
+        constant value in the generated code, otherwise it will raise
+        an `Impossible` exception."""
+        from compiler import has_safe_repr
+        if not has_safe_repr(value):
+            if silent:
+                return
+            raise Impossible()
+        return cls(value, lineno=lineno)
+
 
 class Tuple(Literal):
     """For loop unpacking and some other things like multiple arguments
@@ -361,14 +374,35 @@ class Call(Expr):
     """{{ foo(bar) }}"""
     fields = ('node', 'args', 'kwargs', 'dyn_args', 'dyn_kwargs')
 
+    def as_const(self):
+        obj = self.node.as_const()
+        args = [x.as_const() for x in self.args]
+        kwargs = dict(x.as_const() for x in self.kwargs)
+        if self.dyn_args is not None:
+            try:
+                args.extend(self.dyn_args.as_const())
+            except:
+                raise Impossible()
+        if self.dyn_kwargs is not None:
+            try:
+                dyn_kwargs.update(self.dyn_kwargs.as_const())
+            except:
+                raise Impossible()
+        try:
+            return obj(*args, **kwargs)
+        except:
+            raise nodes.Impossible()
+
 
 class Subscript(Expr):
     """{{ foo.bar }} and {{ foo['bar'] }} etc."""
     fields = ('node', 'arg', 'ctx')
 
     def as_const(self):
+        if self.ctx != 'load':
+            raise Impossible()
         try:
-            return self.node.as_const()[self.node.as_const()]
+            return subscribe(self.node.as_const(), self.arg.as_const())
         except:
             raise Impossible()
 
@@ -380,6 +414,13 @@ class Slice(Expr):
     """1:2:3 etc."""
     fields = ('start', 'stop', 'step')
 
+    def as_const(self):
+        def const(obj):
+            if obj is None:
+                return obj
+            return obj.as_const()
+        return slice(const(self.start), const(self.stop), const(self.step))
+
 
 class Concat(Expr):
     """For {{ foo ~ bar }}.  Concatenates strings."""
index ee2969cc9682fa67f63449da91514d47fb59ebaf..592a2948408aa7e8ad92ffd490215092de847ece 100644 (file)
@@ -78,7 +78,10 @@ class Optimizer(NodeTransformer):
             # XXX: call filters with arguments
             x = self.environment.filters[filter.name](self.environment, x)
             # XXX: don't optimize context dependent filters
-        return nodes.Const(x)
+        try:
+            return nodes.Const.from_untrusted(x, lineno=node.lineno)
+        except nodes.Impossible:
+            return self.generic_visit(node)
 
     def visit_For(self, node, context):
         """Loop unrolling for iterable constant values."""
@@ -139,8 +142,9 @@ class Optimizer(NodeTransformer):
     def visit_Name(self, node, context):
         if node.ctx == 'load':
             try:
-                return nodes.Const(context[node.name], lineno=node.lineno)
-            except KeyError:
+                return nodes.Const.from_untrusted(context[node.name],
+                                                  lineno=node.lineno)
+            except (KeyError, nodes.Impossible):
                 pass
         return node
 
@@ -155,8 +159,8 @@ class Optimizer(NodeTransformer):
         lineno = node.lineno
         def walk(target, value):
             if isinstance(target, nodes.Name):
-                const_value = nodes.Const(value, lineno=lineno)
-                result.append(nodes.Assign(target, const_value, lineno=lineno))
+                const = nodes.Const.from_untrusted(value, lineno=lineno)
+                result.append(nodes.Assign(target, const, lineno=lineno))
                 context[target.name] = value
             elif isinstance(target, nodes.Tuple):
                 try:
@@ -180,22 +184,14 @@ class Optimizer(NodeTransformer):
         """Do constant folding."""
         node = self.generic_visit(node, context)
         try:
-            return nodes.Const(node.as_const(), lineno=node.lineno)
+            return nodes.Const.from_untrusted(node.as_const(),
+                                              lineno=node.lineno)
         except nodes.Impossible:
             return node
     visit_Add = visit_Sub = visit_Mul = visit_Div = visit_FloorDiv = \
     visit_Pow = visit_Mod = visit_And = visit_Or = visit_Pos = visit_Neg = \
-    visit_Not = visit_Compare = fold
-
-    def visit_Subscript(self, node, context):
-        if node.ctx == 'load':
-            try:
-                item = self.visit(node.node, context).as_const()
-                arg = self.visit(node.arg, context).as_const()
-            except nodes.Impossible:
-                return self.generic_visit(node, context)
-            return nodes.Const(subscribe(item, arg, 'load'))
-        return self.generic_visit(node, context)
+    visit_Not = visit_Compare = visit_Subscribt = visit_Call = fold
+    del fold
 
 
 def optimize(node, environment, context_hint=None):
index fd22395a7c6b33145b45106b6c2fdd2c182aaaff..0e4200500d1f6879a435025c9bd858dc5c1baf72 100644 (file)
@@ -14,17 +14,15 @@ except ImportError:
     defaultdict = None
 
 
-# contains only the variables the template will import automatically, not the
-# objects injected by the evaluation loop (such as undefined objects)
 __all__ = ['extends', 'subscribe', 'LoopContext', 'StaticLoopContext',
-           'TemplateContext', 'Macro']
+           'TemplateContext', 'Macro', 'Undefined']
 
 
 def extends(template, namespace):
     """This loads a template (and evaluates it) and replaces the blocks."""
 
 
-def subscribe(obj, argument, undefined_factory):
+def subscribe(obj, argument):
     """Get an item or attribute of an object."""
     try:
         return getattr(obj, str(argument))
@@ -32,7 +30,7 @@ def subscribe(obj, argument, undefined_factory):
         try:
             return obj[argument]
         except LookupError:
-            return undefined_factory(attr=argument)
+            return Undefined(obj, argument)
 
 
 class TemplateContext(dict):
@@ -43,10 +41,9 @@ class TemplateContext(dict):
     the exported variables for example).
     """
 
-    def __init__(self, globals, undefined_factory, filename):
+    def __init__(self, globals, filename):
         dict.__init__(self, globals)
         self.exported = set()
-        self.undefined_factory = undefined_factory
         self.filename = filename
         self.filters = {}
         self.tests = {}
@@ -71,10 +68,10 @@ class TemplateContext(dict):
         def __getitem__(self, name):
             if name in self:
                 return self[name]
-            return self.undefined_factory(name)
+            return Undefined(name)
     else:
         def __missing__(self, key):
-            return self.undefined_factory(key)
+            return Undefined(key)
 
 
 class LoopContextBase(object):
@@ -128,6 +125,10 @@ class LoopContext(LoopContextBase):
 
 
 class StaticLoopContext(LoopContextBase):
+    """The static loop context is used in the optimizer to "freeze" the
+    status of an iteration.  The only reason for this object is if the
+    loop object is accessed in a non static way (eg: becomes part of a
+    function call)."""
 
     def __init__(self, index0, length, parent):
         self.index0 = index0
@@ -135,6 +136,7 @@ class StaticLoopContext(LoopContextBase):
         self._length = length
 
     def __repr__(self):
+        """The repr is used by the optimizer to dump the object."""
         return 'StaticLoopContext(%r, %r, %r)' % (
             self.index0,
             self._length,
@@ -150,14 +152,12 @@ class Macro(object):
     Wraps a macor
     """
 
-    def __init__(self, func, name, arguments, defaults, catch_all, \
-                 undefined_factory):
+    def __init__(self, func, name, arguments, defaults, catch_all):
         self.func = func
         self.name = name
         self.arguments = arguments
         self.defaults = defaults
         self.catch_all = catch_all
-        self.undefined_factory = undefined_factory
 
     def __call__(self, *args, **kwargs):
         arg_count = len(self.arguments)
@@ -175,7 +175,7 @@ class Macro(object):
                     try:
                         value = self.defaults[idx - arg_count]
                     except IndexError:
-                        value = self.undefined_factory(name)
+                        value = Undefined(name)
             arguments['l_' + name] = arg
         if self.catch_all:
             arguments['l_arguments'] = kwargs
@@ -183,7 +183,7 @@ class Macro(object):
 
 
 class Undefined(object):
-    """The default undefined behavior."""
+    """The object for undefined values."""
 
     def __init__(self, name=None, attr=None):
         if attr is None:
index 36f2011fe1b42d38aa897e9d2135a5f90f636265..acf8d73ef193c998d55f93656939391ecf37e279 100644 (file)
@@ -22,6 +22,7 @@ ast = env.parse("""
     {% for key, value in navigation %}
         <li>{{ test(loop) }}: <a href="{{ key|e }}">{{ value|e }}</a></li>
     {% endfor %}
+    {{ "Hello World".upper() }}
     </ul>
 """)
 print ast