nodes have access to environment now
authorArmin Ronacher <armin.ronacher@active-4.com>
Wed, 9 Apr 2008 14:13:39 +0000 (16:13 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Wed, 9 Apr 2008 14:13:39 +0000 (16:13 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/nodes.py
jinja2/optimizer.py
jinja2/parser.py
jinja2/runtime.py
test_optimizer.py

index c963b8f4b256f5174094d8a2b01980502a4b390a..44087f1609d5b5769f5c9a275f8cf824dbc4610a 100644 (file)
@@ -39,11 +39,12 @@ def generate(node, environment, filename, stream=None):
 
 def has_safe_repr(value):
     """Does the node have a safe representation?"""
-    if value is None:
+    if value is None or value is NotImplemented or value is Ellipsis:
         return True
-    if isinstance(value, (int, long, float, basestring, StaticLoopContext)):
+    if isinstance(value, (bool, int, long, float, complex, basestring,
+                          StaticLoopContext)):
         return True
-    if isinstance(value, (tuple, list)):
+    if isinstance(value, (tuple, list, set, frozenset)):
         for item in value:
             if not has_safe_repr(item):
                 return False
@@ -148,7 +149,7 @@ class FrameIdentifierVisitor(NodeVisitor):
             if not self.identifiers.is_declared(node.name, self.hard_scope):
                 self.identifiers.undeclared.add(node.name)
 
-    def visit_FilterCall(self, node):
+    def visit_Filter(self, node):
         if not node.name in self.identifiers.declared_filter:
             uf = self.identifiers.undeclared_filter.get(node.name, 0) + 1
             if uf > 1:
@@ -589,15 +590,13 @@ class CodeGenerator(NodeVisitor):
             self.visit(node.step, frame)
 
     def visit_Filter(self, node, frame):
-        for filter in node.filters:
-            if filter.name in frame.identifiers.declared_filter:
-                self.write('f_%s(' % filter.name)
-            else:
-                self.write('context.filter[%r](' % filter.name)
+        if node.name in frame.identifiers.declared_filter:
+            self.write('f_%s(' % node.name)
+        else:
+            self.write('context.filter[%r](' % node.name)
         self.visit(node.node, frame)
-        for filter in reversed(node.filters):
-            self.signature(filter, frame)
-            self.write(')')
+        self.signature(node, frame)
+        self.write(')')
 
     def visit_Test(self, node, frame):
         self.write('context.tests[%r](')
index 6c16c8beac4440917e07ee7200f4a8b97d94d6ca..72edeb62a73c5a9a46cb5c44db5eee472fbfe23d 100644 (file)
@@ -60,7 +60,7 @@ class Node(object):
     """Baseclass for all Jinja nodes."""
     __metaclass__ = NodeType
     fields = ()
-    attributes = ('lineno',)
+    attributes = ('lineno', 'environment')
 
     def __init__(self, *args, **kw):
         if args:
@@ -125,6 +125,14 @@ class Node(object):
                 node.ctx = ctx
             todo.extend(node.iter_child_nodes())
 
+    def set_environment(self, environment):
+        """Set the environment for all nodes."""
+        todo = deque([self])
+        while todo:
+            node = todo.popleft()
+            node.environment = environment
+            todo.extend(node.iter_child_nodes())
+
     def __repr__(self):
         return '%s(%s)' % (
             self.__class__.__name__,
@@ -288,7 +296,7 @@ class Const(Literal):
         return self.value
 
     @classmethod
-    def from_untrusted(cls, value, lineno=None, silent=False):
+    def from_untrusted(cls, value, lineno=None, environment=None):
         """Return a const object if the value is representable as
         constant value in the generated code, otherwise it will raise
         an `Impossible` exception."""
@@ -297,7 +305,7 @@ class Const(Literal):
             if silent:
                 return
             raise Impossible()
-        return cls(value, lineno=lineno)
+        return cls(value, lineno=lineno, environment=environment)
 
 
 class Tuple(Literal):
@@ -357,12 +365,29 @@ class CondExpr(Expr):
 
 class Filter(Expr):
     """{{ foo|bar|baz }}"""
-    fields = ('node', 'filters')
-
+    fields = ('node', 'name', 'args', 'kwargs', 'dyn_args', 'dyn_kwargs')
 
-class FilterCall(Expr):
-    """{{ |bar() }}"""
-    fields = ('name', 'args', 'kwargs', 'dyn_args', 'dyn_kwargs')
+    def as_const(self):
+        filter = self.environment.filters.get(self.name)
+        if filter is None or getattr(filter, 'contextfilter', False):
+            raise nodes.Impossible()
+        obj = self.node.as_const()
+        args = [x.as_const() for x in self.args]
+        kwargs = dict(x.as_const() for x in self.kwargs)
+        if self.dyn_args is not None:
+            try:
+                args.extend(self.dyn_args.as_const())
+            except:
+                raise Impossible()
+        if self.dyn_kwargs is not None:
+            try:
+                kwargs.update(self.dyn_kwargs.as_const())
+            except:
+                raise Impossible()
+        try:
+            return filter(obj, *args, **kwargs)
+        except:
+            raise nodes.Impossible()
 
 
 class Test(Expr):
@@ -385,7 +410,7 @@ class Call(Expr):
                 raise Impossible()
         if self.dyn_kwargs is not None:
             try:
-                dyn_kwargs.update(self.dyn_kwargs.as_const())
+                kwargs.update(self.dyn_kwargs.as_const())
             except:
                 raise Impossible()
         try:
index ee2f08266ac3dc5f7ebce18cfc84ab4e3f637b5b..13b0fbc1e7d752864171e3b01eaf73b81bd418d6 100644 (file)
@@ -67,22 +67,6 @@ class Optimizer(NodeTransformer):
     def visit_Block(self, node, context):
         return self.generic_visit(node, context.blank())
 
-    def visit_Filter(self, node, context):
-        """Try to evaluate filters if possible."""
-        # XXX: nonconstant arguments?  not-called visitors?  generic visit!
-        try:
-            x = self.visit(node.node, context).as_const()
-        except nodes.Impossible:
-            return self.generic_visit(node, context)
-        for filter in reversed(node.filters):
-            # XXX: call filters with arguments
-            x = self.environment.filters[filter.name](x)
-            # XXX: don't optimize context dependent filters
-        try:
-            return nodes.Const.from_untrusted(x, lineno=node.lineno)
-        except nodes.Impossible:
-            return self.generic_visit(node)
-
     def visit_For(self, node, context):
         """Loop unrolling for iterable constant values."""
         try:
@@ -140,13 +124,14 @@ class Optimizer(NodeTransformer):
         return node.else_
 
     def visit_Name(self, node, context):
-        if node.ctx == 'load':
-            try:
-                return nodes.Const.from_untrusted(context[node.name],
-                                                  lineno=node.lineno)
-            except (KeyError, nodes.Impossible):
-                pass
-        return node
+        if node.ctx != 'load':
+            return node
+        try:
+            return nodes.Const.from_untrusted(context[node.name],
+                                              lineno=node.lineno,
+                                              environment=self.environment)
+        except (KeyError, nodes.Impossible):
+            return node
 
     def visit_Assign(self, node, context):
         try:
@@ -185,12 +170,14 @@ class Optimizer(NodeTransformer):
         node = self.generic_visit(node, context)
         try:
             return nodes.Const.from_untrusted(node.as_const(),
-                                              lineno=node.lineno)
+                                              lineno=node.lineno,
+                                              environment=self.environment)
         except nodes.Impossible:
             return node
     visit_Add = visit_Sub = visit_Mul = visit_Div = visit_FloorDiv = \
     visit_Pow = visit_Mod = visit_And = visit_Or = visit_Pos = visit_Neg = \
-    visit_Not = visit_Compare = visit_Subscribt = visit_Call = fold
+    visit_Not = visit_Compare = visit_Subscript = visit_Call = \
+    visit_Filter = visit_Test = fold
     del fold
 
 
index 426e5541aaac99c45de85c705f6dd0f9ff2740d6..a7f0e980afda4b1129ef068b86d1891a27448609 100644 (file)
@@ -23,8 +23,7 @@ _statement_end_tokens = set(['elif', 'else', 'endblock', 'endfilter',
 
 
 class Parser(object):
-    """
-    The template parser class.
+    """The template parser class.
 
     Transforms sourcecode into an abstract syntax tree.
     """
@@ -572,7 +571,6 @@ class Parser(object):
 
     def parse_filter(self, node):
         lineno = self.stream.current.type
-        filters = []
         while self.stream.current.type == 'pipe':
             self.stream.next()
             token = self.stream.expect('name')
@@ -582,10 +580,9 @@ class Parser(object):
                 args = []
                 kwargs = []
                 dyn_args = dyn_kwargs = None
-            filters.append(nodes.FilterCall(token.value, args, kwargs,
-                                            dyn_args, dyn_kwargs,
-                                            lineno=token.lineno))
-        return nodes.Filter(node, filters)
+            node = nodes.Filter(node, token.value, args, kwargs, dyn_args,
+                                dyn_kwargs, lineno=token.lineno)
+        return node
 
     def parse_test(self, node):
         token = self.stream.expect('is')
@@ -653,4 +650,6 @@ class Parser(object):
 
     def parse(self):
         """Parse the whole template into a `Template` node."""
-        return nodes.Template(self.subparse(), lineno=1)
+        result = nodes.Template(self.subparse(), lineno=1)
+        result.set_environment(self.environment)
+        return result
index 0e4200500d1f6879a435025c9bd858dc5c1baf72..40ecdf0283efabae7243d82b02c2595e726b69cb 100644 (file)
@@ -148,9 +148,7 @@ class StaticLoopContext(LoopContextBase):
 
 
 class Macro(object):
-    """
-    Wraps a macor
-    """
+    """Wraps a macro."""
 
     def __init__(self, func, name, arguments, defaults, catch_all):
         self.func = func
index acf8d73ef193c998d55f93656939391ecf37e279..d0d855110a6eedcf51464681bebc061aa6f598d4 100644 (file)
@@ -20,9 +20,8 @@ ast = env.parse("""
     {% navigation = [('#foo', 'Foo'), ('#bar', 'Bar'), ('#baz', 42 * 2 + 23)] %}
     <ul>
     {% for key, value in navigation %}
-        <li>{{ test(loop) }}: <a href="{{ key|e }}">{{ value|e }}</a></li>
+        <li>{{ loop.index }}: <a href="{{ key[1:].upper()|e }}">{{ value|e }}</a></li>
     {% endfor %}
-    {{ "Hello World".upper() }}
     </ul>
 """)
 print ast