all unittests pass, the special and dependency lookups have their own visitors now...
authorArmin Ronacher <armin.ronacher@active-4.com>
Sun, 27 Apr 2008 19:28:03 +0000 (21:28 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Sun, 27 Apr 2008 19:28:03 +0000 (21:28 +0200)
--HG--
branch : trunk

examples/bench.py
jinja2/compiler.py
jinja2/environment.py
jinja2/filters.py
jinja2/nodes.py
jinja2/optimizer.py
jinja2/parser.py
jinja2/runtime.py
tests/test_filters.py
tests/test_macros.py

index 15830a0b728d4411aeb7f09120030457f9f4f6f3..e27489c31443e159963b7ead0d3f4c2216f48e56 100644 (file)
@@ -295,7 +295,9 @@ sys.stdout.write('\r' + '\n'.join((
     __doc__,
     '-' * 80
 )) + '\n')
-for test in 'jinja', 'tenjin', 'mako', 'spitfire', 'django', 'genshi', 'cheetah':
+
+
+for test in 'jinja', 'mako', 'tenjin', 'spitfire', 'django', 'genshi', 'cheetah':
     if locals()['test_' + test] is None:
         sys.stdout.write('    %-20s*not installed*\n' % test)
         continue
index a1ffdecab65837cf9e8b3a0faed18edefae38f6c..8c699f5d5bafcd31642d0c62477f99fbdc9e3ed5 100644 (file)
@@ -32,7 +32,6 @@ operators = {
     'notin':    'not in'
 }
 
-
 try:
     exec '(0 if 0 else 0)'
 except SyntaxError:
@@ -71,6 +70,19 @@ def has_safe_repr(value):
     return False
 
 
+def find_undeclared(nodes, names):
+    """Check if the names passed are accessed undeclared.  The return value
+    is a set of all the undeclared names from the sequence of names found.
+    """
+    visitor = UndeclaredNameVisitor(names)
+    try:
+        for node in nodes:
+            visitor.visit(node)
+    except VisitorExit:
+        pass
+    return visitor.undeclared
+
+
 class Identifiers(object):
     """Tracks the status of identifiers in frames."""
 
@@ -93,10 +105,6 @@ class Identifiers(object):
         # names that are declared by parameters
         self.declared_parameter = set()
 
-        # filters/tests that are referenced
-        self.filters = set()
-        self.tests = set()
-
     def add_special(self, name):
         """Register a special name like `loop`."""
         self.undeclared.discard(name)
@@ -136,10 +144,6 @@ class Frame(object):
         # buffer.
         self.buffer = None
 
-        # if a frame has name_overrides, all read access to a name in this
-        # dict is redirected to a string expression.
-        self.name_overrides = {}
-
         # the name of the block we're in, otherwise None.
         self.block = parent and parent.block or None
 
@@ -157,28 +161,21 @@ class Frame(object):
                 self.identifiers.declared
             )
             self.buffer = parent.buffer
-            self.name_overrides = parent.name_overrides.copy()
 
     def copy(self):
         """Create a copy of the current one."""
         rv = copy(self)
         rv.identifiers = copy(self.identifiers)
-        rv.name_overrides = self.name_overrides.copy()
         return rv
 
-    def inspect(self, nodes, with_depenencies=False, hard_scope=False):
+    def inspect(self, nodes, hard_scope=False):
         """Walk the node and check for identifiers.  If the scope is hard (eg:
         enforce on a python level) overrides from outer scopes are tracked
         differently.
-
-        Per default filters and tests (dependencies) are not tracked.  That's
-        the case because filters and tests are absolutely immutable and so we
-        can savely use them in closures too.  The `Template` and `Block`
-        visitor visits the frame with dependencies to collect them.
         """
         visitor = FrameIdentifierVisitor(self.identifiers, hard_scope)
         for node in nodes:
-            visitor.visit(node, True, with_depenencies)
+            visitor.visit(node)
 
     def inner(self):
         """Return an inner frame."""
@@ -194,6 +191,51 @@ class Frame(object):
         return rv
 
 
+class VisitorExit(RuntimeError):
+    """Exception used by the `UndeclaredNameVisitor` to signal a stop."""
+
+
+class DependencyFinderVisitor(NodeVisitor):
+    """A visitor that collects filter and test calls."""
+
+    def __init__(self):
+        self.filters = set()
+        self.tests = set()
+
+    def visit_Filter(self, node):
+        self.generic_visit(node)
+        self.filters.add(node.name)
+
+    def visit_Test(self, node):
+        self.generic_visit(node)
+        self.tests.add(node.name)
+
+    def visit_Block(self, node):
+        """Stop visiting at blocks."""
+
+
+class UndeclaredNameVisitor(NodeVisitor):
+    """A visitor that checks if a name is accessed without being
+    declared.  This is different from the frame visitor as it will
+    not stop at closure frames.
+    """
+
+    def __init__(self, names):
+        self.names = set(names)
+        self.undeclared = set()
+
+    def visit_Name(self, node):
+        if node.ctx == 'load' and node.name in self.names:
+            self.undeclared.add(node.name)
+            if self.undeclared == self.names:
+                raise VisitorExit()
+        else:
+            self.names.discard(node.name)
+
+    def visit_Block(self, node):
+        """Stop visiting a blocks."""
+
+
 class FrameIdentifierVisitor(NodeVisitor):
     """A visitor for `Frame.inspect`."""
 
@@ -201,63 +243,50 @@ class FrameIdentifierVisitor(NodeVisitor):
         self.identifiers = identifiers
         self.hard_scope = hard_scope
 
-    def visit_Name(self, node, visit_ident, visit_deps):
+    def visit_Name(self, node):
         """All assignments to names go through this function."""
-        if visit_ident:
-            if node.ctx in ('store', 'param'):
-                self.identifiers.declared_locally.add(node.name)
-            elif node.ctx == 'load' and not \
-                 self.identifiers.is_declared(node.name, self.hard_scope):
-                self.identifiers.undeclared.add(node.name)
-
-    def visit_Filter(self, node, visit_ident, visit_deps):
-        if visit_deps:
-            self.generic_visit(node, visit_ident, True)
-            self.identifiers.filters.add(node.name)
-
-    def visit_Test(self, node, visit_ident, visit_deps):
-        if visit_deps:
-            self.generic_visit(node, visit_ident, True)
-            self.identifiers.tests.add(node.name)
-
-    def visit_Macro(self, node, visit_ident, visit_deps):
-        if visit_ident:
+        if node.ctx in ('store', 'param'):
             self.identifiers.declared_locally.add(node.name)
+        elif node.ctx == 'load' and not \
+             self.identifiers.is_declared(node.name, self.hard_scope):
+            self.identifiers.undeclared.add(node.name)
 
-    def visit_Import(self, node, visit_ident, visit_deps):
-        if visit_ident:
-            self.generic_visit(node, True, visit_deps)
-            self.identifiers.declared_locally.add(node.target)
-
-    def visit_FromImport(self, node, visit_ident, visit_deps):
-        if visit_ident:
-            self.generic_visit(node, True, visit_deps)
-            for name in node.names:
-                if isinstance(name, tuple):
-                    self.identifiers.declared_locally.add(name[1])
-                else:
-                    self.identifiers.declared_locally.add(name)
+    def visit_Macro(self, node):
+        self.generic_visit(node)
+        self.identifiers.declared_locally.add(node.name)
+
+    def visit_Import(self, node):
+        self.generic_visit(node)
+        self.identifiers.declared_locally.add(node.target)
+
+    def visit_FromImport(self, node):
+        self.generic_visit(node)
+        for name in node.names:
+            if isinstance(name, tuple):
+                self.identifiers.declared_locally.add(name[1])
+            else:
+                self.identifiers.declared_locally.add(name)
 
-    def visit_Assign(self, node, visit_ident, visit_deps):
+    def visit_Assign(self, node):
         """Visit assignments in the correct order."""
-        self.visit(node.node, visit_ident, visit_deps)
-        self.visit(node.target, visit_ident, visit_deps)
+        self.visit(node.node)
+        self.visit(node.target)
 
-    def visit_For(self, node, visit_ident, visit_deps):
+    def visit_For(self, node):
         """Visiting stops at for blocks.  However the block sequence
         is visited as part of the outer scope.
         """
-        if visit_ident:
-            self.visit(node.iter, True, visit_deps)
-            if visit_deps:
-                for child in node.iter_child_nodes(exclude=('iter',)):
-                    self.visit(child, False, True)
+        self.visit(node.iter)
+
+    def visit_CallBlock(self, node):
+        for child in node.iter_child_nodes(exclude=('body',)):
+            self.visit(child)
 
-    def ident_stop(self, node, visit_ident, visit_deps):
-        if visit_deps:
-            self.generic_visit(node, False, True)
-    visit_CallBlock = visit_FilterBlock = ident_stop
-    visit_Block = lambda s, n, a, b: None
+    def visit_FilterBlock(self, node):
+        self.visit(node.filter)
+
+    def visit_Block(self, node):
+        """Stop visiting at blocks."""
 
 
 class CompilerExit(Exception):
@@ -325,15 +354,11 @@ class CodeGenerator(NodeVisitor):
         """Outdent by step."""
         self._indentation -= step
 
-    def blockvisit(self, nodes, frame, indent=True, force_generator=True):
-        """Visit a list of nodes as block in a frame.  Per default the
-        code is indented, but this can be disabled by setting the indent
-        parameter to False.  If the current frame is no buffer a dummy
-        ``if 0: yield None`` is written automatically unless the
-        force_generator parameter is set to False.
+    def blockvisit(self, nodes, frame, force_generator=True):
+        """Visit a list of nodes as block in a frame.  If the current frame
+        is no buffer a dummy ``if 0: yield None`` is written automatically
+        unless the force_generator parameter is set to False.
         """
-        if indent:
-            self.indent()
         if frame.buffer is None and force_generator:
             self.writeline('if 0: yield None')
         try:
@@ -341,8 +366,6 @@ class CodeGenerator(NodeVisitor):
                 self.visit(node, frame)
         except CompilerExit:
             pass
-        if indent:
-            self.outdent()
 
     def write(self, x):
         """Write a string into the output stream."""
@@ -423,7 +446,6 @@ class CodeGenerator(NodeVisitor):
                 self.write(', ')
             if extra_kwargs is not None:
                 for key, value in extra_kwargs.iteritems():
-                    touch_comma()
                     self.write('%r: %s, ' % (key, value))
             if node.dyn_kwargs is not None:
                 self.write('}, **')
@@ -437,21 +459,20 @@ class CodeGenerator(NodeVisitor):
             self.write('**')
             self.visit(node.dyn_kwargs, frame)
 
-    def pull_locals(self, frame, indent=True):
-        """Pull all the references identifiers into the local scope.
-        This affects regular names, filters and tests.  If indent is
-        set to False, no automatic indentation will take place.
-        """
-        if indent:
-            self.indent()
+    def pull_locals(self, frame):
+        """Pull all the references identifiers into the local scope."""
         for name in frame.identifiers.undeclared:
             self.writeline('l_%s = context[%r]' % (name, name))
-        for name in frame.identifiers.filters:
+
+    def pull_dependencies(self, nodes):
+        """Pull all the dependencies."""
+        visitor = DependencyFinderVisitor()
+        for node in nodes:
+            visitor.visit(node)
+        for name in visitor.filters:
             self.writeline('f_%s = environment.filters[%r]' % (name, name))
-        for name in frame.identifiers.tests:
+        for name in visitor.tests:
             self.writeline('t_%s = environment.tests[%r]' % (name, name))
-        if indent:
-            self.outdent()
 
     def collect_shadowed(self, frame):
         """This function returns all the shadowed variables in a dict
@@ -467,7 +488,7 @@ class CodeGenerator(NodeVisitor):
             self.writeline('%s = l_%s' % (ident, name))
         return aliases
 
-    def function_scoping(self, node, frame):
+    def function_scoping(self, node, frame, children=None):
         """In Jinja a few statements require the help of anonymous
         functions.  Those are currently macros and call blocks and in
         the future also recursive loops.  As there is currently
@@ -479,8 +500,12 @@ class CodeGenerator(NodeVisitor):
 
         This will return the modified frame.
         """
+        # we have to iterate twice over it, make sure that works
+        if children is None:
+            children = node.iter_child_nodes()
+        children = list(children)
         func_frame = frame.inner()
-        func_frame.inspect(node.iter_child_nodes(), hard_scope=True)
+        func_frame.inspect(children, hard_scope=True)
 
         # variables that are undeclared (accessed before declaration) and
         # declared locally *and* part of an outside scope raise a template
@@ -511,15 +536,17 @@ class CodeGenerator(NodeVisitor):
         func_frame.accesses_caller = False
         func_frame.arguments = args = ['l_' + x.name for x in node.args]
 
-        if 'caller' in func_frame.identifiers.undeclared:
+        undeclared = find_undeclared(children, ('caller', 'kwargs', 'varargs'))
+
+        if 'caller' in undeclared:
             func_frame.accesses_caller = True
             func_frame.identifiers.add_special('caller')
             args.append('l_caller')
-        if 'kwargs' in func_frame.identifiers.undeclared:
+        if 'kwargs' in undeclared:
             func_frame.accesses_kwargs = True
             func_frame.identifiers.add_special('kwargs')
             args.append('l_kwargs')
-        if 'varargs' in func_frame.identifiers.undeclared:
+        if 'varargs' in undeclared:
             func_frame.accesses_varargs = True
             func_frame.identifiers.add_special('varargs')
             args.append('l_varargs')
@@ -547,18 +574,20 @@ class CodeGenerator(NodeVisitor):
 
         # generate the root render function.
         self.writeline('def root(context, environment=environment):', extra=1)
-        if have_extends:
-            self.indent()
-            self.writeline('parent_template = None')
-            self.outdent()
 
         # process the root
         frame = Frame()
-        frame.inspect(node.body, with_depenencies=True)
+        frame.inspect(node.body)
         frame.toplevel = frame.rootlevel = True
         self.indent()
-        self.pull_locals(frame, indent=False)
-        self.blockvisit(node.body, frame, indent=False)
+        if have_extends:
+            self.writeline('parent_template = None')
+        self.pull_locals(frame)
+        self.pull_dependencies(node.body)
+        if 'self' in find_undeclared(node.body, ('self',)):
+            frame.identifiers.add_special('self')
+            self.writeline('l_self = TemplateReference(context)')
+        self.blockvisit(node.body, frame)
         self.outdent()
 
         # make sure that the parent root is called.
@@ -576,15 +605,23 @@ class CodeGenerator(NodeVisitor):
         # at this point we now have the blocks collected and can visit them too.
         for name, block in self.blocks.iteritems():
             block_frame = Frame()
-            block_frame.inspect(block.body, with_depenencies=True)
+            block_frame.inspect(block.body)
             block_frame.block = name
-            block_frame.identifiers.add_special('super')
-            block_frame.name_overrides['super'] = 'context.super(%r, ' \
-                'block_%s)' % (name, name)
             self.writeline('def block_%s(context, environment=environment):'
                            % name, block, 1)
+            self.indent()
+            undeclared = find_undeclared(block.body, ('self', 'super'))
+            if 'self' in undeclared:
+                block_frame.identifiers.add_special('self')
+                self.writeline('l_self = TemplateReference(context)')
+            if 'super' in undeclared:
+                block_frame.identifiers.add_special('super')
+                self.writeline('l_super = context.super(%r, '
+                               'block_%s)' % (name, name))
             self.pull_locals(block_frame)
+            self.pull_dependencies(block.body)
             self.blockvisit(block.body, block_frame)
+            self.outdent()
 
         self.writeline('blocks = {%s}' % ', '.join('%r: block_%s' % (x, x)
                                                    for x in self.blocks),
@@ -606,7 +643,8 @@ class CodeGenerator(NodeVisitor):
                 self.writeline('if parent_template is None:')
                 self.indent()
                 level += 1
-        self.writeline('for event in context.blocks[%r][-1](context):' % node.name)
+        self.writeline('for event in context.blocks[%r][-1](context):' %
+                       node.name, node)
         self.indent()
         if frame.buffer is None:
             self.writeline('yield event')
@@ -666,7 +704,7 @@ class CodeGenerator(NodeVisitor):
         self.visit(node.template, frame)
         self.write(', %r)' % self.name)
         self.writeline('for event in included_template.root_render_func('
-                       'included_template.new_context(context.get_root())):')
+                       'included_template.new_context(context.parent, True)):')
         self.indent()
         if frame.buffer is None:
             self.writeline('yield event')
@@ -682,7 +720,7 @@ class CodeGenerator(NodeVisitor):
         self.write('environment.get_template(')
         self.visit(node.template, frame)
         self.write(', %r).include(context)' % self.name)
-        if frame.toplevel:
+        if frame.toplevel and not node.target.startswith('__'):
             self.writeline('context.exported_vars.discard(%r)' % node.target)
 
     def visit_FromImport(self, node, frame):
@@ -707,7 +745,8 @@ class CodeGenerator(NodeVisitor):
             self.outdent()
             if frame.toplevel:
                 self.writeline('context.vars[%r] = l_%s' % (alias, alias))
-                self.writeline('context.exported_vars.discard(%r)' % alias)
+                if not alias.startswith('__'):
+                    self.writeline('context.exported_vars.discard(%r)' % alias)
 
     def visit_For(self, node, frame):
         loop_frame = frame.inner()
@@ -718,7 +757,7 @@ class CodeGenerator(NodeVisitor):
             loop_frame.identifiers.add_special('loop')
 
         aliases = self.collect_shadowed(loop_frame)
-        self.pull_locals(loop_frame, indent=False)
+        self.pull_locals(loop_frame)
         if node.else_:
             self.writeline('l_loop = None')
 
@@ -747,7 +786,7 @@ class CodeGenerator(NodeVisitor):
             self.visit(node.iter, loop_frame)
             self.write(' if (')
             test_frame = loop_frame.copy()
-            test_frame.name_overrides['loop'] = parent_loop
+            self.writeline('l_loop = ' + parent_loop)
             self.visit(node.test, test_frame)
             self.write('))')
 
@@ -766,14 +805,16 @@ class CodeGenerator(NodeVisitor):
             self.writeline('continue')
             self.outdent(2)
 
+        self.indent()
         self.blockvisit(node.body, loop_frame, force_generator=True)
+        self.outdent()
 
         if node.else_:
             self.writeline('if l_loop is None:')
             self.indent()
             self.writeline('l_loop = ' + parent_loop)
-            self.outdent()
             self.blockvisit(node.else_, loop_frame, force_generator=False)
+            self.outdent()
 
         # reset the aliases if there are any.
         for name, alias in aliases.iteritems():
@@ -784,10 +825,14 @@ class CodeGenerator(NodeVisitor):
         self.writeline('if ', node)
         self.visit(node.test, if_frame)
         self.write(':')
+        self.indent()
         self.blockvisit(node.body, if_frame)
+        self.outdent()
         if node.else_:
             self.writeline('else:')
+            self.indent()
             self.blockvisit(node.else_, if_frame)
+            self.outdent()
 
     def visit_Macro(self, node, frame):
         macro_frame = self.function_scoping(node, frame)
@@ -795,14 +840,15 @@ class CodeGenerator(NodeVisitor):
         self.writeline('def macro(%s):' % ', '.join(args), node)
         macro_frame.buffer = buf = self.temporary_identifier()
         self.indent()
-        self.pull_locals(macro_frame, indent=False)
+        self.pull_locals(macro_frame)
         self.writeline('%s = []' % buf)
-        self.blockvisit(node.body, macro_frame, indent=False)
+        self.blockvisit(node.body, macro_frame)
         self.writeline("return Markup(concat(%s))" % buf)
         self.outdent()
         self.newline()
         if frame.toplevel:
-            self.write('context.exported_vars.add(%r)' % node.name)
+            if not node.name.startswith('__'):
+                self.write('context.exported_vars.add(%r)' % node.name)
             self.writeline('context.vars[%r] = ' % node.name)
         arg_tuple = ', '.join(repr(x.name) for x in node.args)
         if len(node.args) == 1:
@@ -819,14 +865,15 @@ class CodeGenerator(NodeVisitor):
         ))
 
     def visit_CallBlock(self, node, frame):
-        call_frame = self.function_scoping(node, frame)
+        call_frame = self.function_scoping(node, frame, node.iter_child_nodes
+                                           (exclude=('call',)))
         args = call_frame.arguments
         self.writeline('def call(%s):' % ', '.join(args), node)
         call_frame.buffer = buf = self.temporary_identifier()
         self.indent()
-        self.pull_locals(call_frame, indent=False)
+        self.pull_locals(call_frame)
         self.writeline('%s = []' % buf)
-        self.blockvisit(node.body, call_frame, indent=False)
+        self.blockvisit(node.body, call_frame)
         self.writeline("return Markup(concat(%s))" % buf)
         self.outdent()
         arg_tuple = ', '.join(repr(x.name) for x in node.args)
@@ -835,7 +882,7 @@ class CodeGenerator(NodeVisitor):
         self.writeline('caller = Macro(environment, call, None, (%s), (' %
                        arg_tuple)
         for arg in node.defaults:
-            self.visit(arg)
+            self.visit(arg, call_frame)
             self.write(', ')
         self.write('), %s, %s, 0)' % (
             call_frame.accesses_kwargs and '1' or '0',
@@ -855,7 +902,7 @@ class CodeGenerator(NodeVisitor):
         filter_frame.inspect(node.iter_child_nodes())
 
         aliases = self.collect_shadowed(filter_frame)
-        self.pull_locals(filter_frame, indent=False)
+        self.pull_locals(filter_frame)
         filter_frame.buffer = buf = self.temporary_identifier()
 
         self.writeline('%s = []' % buf, node)
@@ -985,17 +1032,12 @@ class CodeGenerator(NodeVisitor):
         if frame.toplevel:
             for name in assignment_frame.assigned_names:
                 self.writeline('context.vars[%r] = l_%s' % (name, name))
-                self.writeline('context.exported_vars.add(%r)' % name)
+                if not name.startswith('__'):
+                    self.writeline('context.exported_vars.add(%r)' % name)
 
     def visit_Name(self, node, frame):
-        if node.ctx == 'store':
-            if frame.toplevel:
-                frame.assigned_names.add(node.name)
-            frame.name_overrides.pop(node.name, None)
-        elif node.ctx == 'load':
-            if node.name in frame.name_overrides:
-                self.write(frame.name_overrides[node.name])
-                return
+        if node.ctx == 'store' and frame.toplevel:
+            frame.assigned_names.add(node.name)
         self.write('l_' + node.name)
 
     def visit_Const(self, node, frame):
index ef66d1b35997b8de98074cd11daa2d5cc5ebace2..6a00fda2c27dcd96790200f7686cc4274871406b 100644 (file)
@@ -35,6 +35,7 @@ def get_spontaneous_environment(*args):
     if env is not None:
         return env
     _spontaneous_environments[args] = env = Environment(*args)
+    env.shared = True
     return env
 
 
@@ -76,6 +77,10 @@ class Environment(object):
     #: have a look at jinja2.sandbox
     sandboxed = False
 
+    #: shared environments have this set to `True`.  A shared environment
+    #: must not be modified
+    shared = False
+
     def __init__(self,
                  block_start_string='{%',
                  block_end_string='%}',
@@ -341,24 +346,35 @@ class Template(object):
             exc_type, exc_value, tb = translate_exception(sys.exc_info())
             raise exc_type, exc_value, tb
 
-    def new_context(self, vars):
-        """Create a new template context for this template."""
-        return TemplateContext(self.environment, dict(self.globals, **vars),
-                               self.name, self.blocks)
+    def new_context(self, vars=None, shared=False):
+        """Create a new template context for this template.  The vars
+        provided will be passed to the template.  Per default the globals
+        are added to the context, if shared is set to `True` the data
+        provided is used as parent namespace.  This is used to share the
+        same globals in multiple contexts without consuming more memory.
+        (This works because the context does not modify the parent dict)
+        """
+        if vars is None:
+            vars = {}
+        if shared:
+            parent = vars
+        else:
+            parent = dict(self.globals, **vars)
+        return TemplateContext(self.environment, parent, self.name,
+                               self.blocks)
 
-    def include(self, context=None):
+    def include(self, vars=None):
         """Include this template.  When passed a template context or dict
         the template is evaluated in that context and an `IncludedTemplate`
         object is returned.  This object then exposes all the exported
         variables as attributes and renders the contents of the template
         when converted to unicode.
         """
-        if context is None:
-            context = self.new_context({})
-        elif isinstance(context, TemplateContext):
-            context = self.new_context(context.get_root())
+        if isinstance(vars, TemplateContext):
+            context = TemplateContext(self.environment, vars.parent,
+                                      self.name, self.blocks)
         else:
-            context = self.new_context(context)
+            context = self.new_context(vars)
         return IncludedTemplate(self, context)
 
     def get_corresponding_lineno(self, lineno):
index 801b3502658b675597dc3093ca929a0274f0ca88..c4c108ee3a6d215661d3a8209b8734c8490681d2 100644 (file)
@@ -581,8 +581,7 @@ def do_groupby(environment, value, attribute):
     in common.
     """
     expr = lambda x: environment.subscribe(x, attribute)
-    return sorted(map(_GroupTuple, groupby(sorted(value, key=expr), expr)),
-                  key=itemgetter('grouper'))
+    return sorted(map(_GroupTuple, groupby(sorted(value, key=expr), expr)))
 
 
 class _GroupTuple(tuple):
@@ -590,6 +589,9 @@ class _GroupTuple(tuple):
     grouper = property(itemgetter(0))
     list = property(itemgetter(1))
 
+    def __new__(cls, (key, value)):
+        return tuple.__new__(cls, (key, list(value)))
+
 
 FILTERS = {
     'replace':              do_replace,
index 69a156f6d562a3e3576a2f69bdde8d623eee5068..f87bbb8fc906d822efcd98f245e7a7329dcec86d 100644 (file)
@@ -244,7 +244,7 @@ class Macro(Stmt):
 
 class CallBlock(Stmt):
     """A node that represents am extended macro call."""
-    fields = ('call', 'body')
+    fields = ('call', 'args', 'defaults', 'body')
 
 
 class Set(Stmt):
index f52b77f8e5b0c84b7218a73a06780cedb6732a4b..784c3a8c5d59f9ee75ecbc159e8de56ac0b21ec8 100644 (file)
@@ -83,41 +83,44 @@ class Optimizer(NodeTransformer):
         self.environment = environment
 
     def visit_Block(self, node, context):
-        return self.generic_visit(node, context.blank())
+        block_context = context.blank()
+        for name in 'super', 'self':
+            block_context.undef(name)
+        return self.generic_visit(node, block_context)
 
-    def scoped_section(self, node, context):
+    def visit_For(self, node, context):
         context.push()
+        context.undef('loop')
         try:
             return self.generic_visit(node, context)
         finally:
             context.pop()
-    visit_For = visit_Macro = scoped_section
 
-    def visit_FilterBlock(self, node, context):
-        """Try to filter a block at compile time."""
-        node = self.generic_visit(node, context)
+    def visit_Macro(self, node, context):
         context.push()
+        for name in 'varargs', 'kwargs', 'caller':
+            context.undef(name)
+        try:
+            return self.generic_visit(node, context)
+        finally:
+            context.pop()
 
-        # check if we can evaluate the wrapper body into a string
-        # at compile time
-        buffer = []
-        for child in node.body:
-            if not isinstance(child, nodes.Output):
-                return node
-            for item in child.optimized_nodes():
-                if isinstance(item, nodes.Node):
-                    return node
-                buffer.append(item)
-
-        # now check if we can evaluate the filter at compile time.
+    def visit_CallBlock(self, node, context):
+        context.push()
+        for name in 'varargs', 'kwargs':
+            context.undef(name)
         try:
-            data = node.filter.as_const(concat(buffer))
-        except nodes.Impossible:
-            return node
+            return self.generic_visit(node, context)
+        finally:
+            context.pop()
 
-        context.pop()
-        const = nodes.Const(data, lineno=node.lineno)
-        return nodes.Output([const], lineno=node.lineno)
+    def visit_FilterBlock(self, node, context):
+        """Try to filter a block at compile time."""
+        context.push()
+        try:
+            return self.generic_visit(node, context)
+        finally:
+            context.pop()
 
     def visit_If(self, node, context):
         try:
@@ -125,8 +128,13 @@ class Optimizer(NodeTransformer):
         except nodes.Impossible:
             return self.generic_visit(node, context)
         if val:
-            return node.body
-        return node.else_
+            body = node.body
+        else:
+            body = node.else_
+        result = []
+        for node in body:
+            result.extend(self.visit_list(node, context))
+        return result
 
     def visit_Name(self, node, context):
         if node.ctx != 'load':
@@ -141,38 +149,6 @@ class Optimizer(NodeTransformer):
         except (KeyError, nodes.Impossible):
             return node
 
-    def visit_Assign(self, node, context):
-        try:
-            target = node.target = self.generic_visit(node.target, context)
-            value = self.generic_visit(node.node, context).as_const()
-        except nodes.Impossible:
-            return node
-
-        result = []
-        lineno = node.lineno
-        def walk(target, value):
-            if isinstance(target, nodes.Name):
-                const = nodes.Const.from_untrusted(value, lineno=lineno)
-                result.append(nodes.Assign(target, const, lineno=lineno))
-                context[target.name] = value
-            elif isinstance(target, nodes.Tuple):
-                try:
-                    value = tuple(value)
-                except TypeError:
-                    raise nodes.Impossible()
-                if len(target.items) != len(value):
-                    raise nodes.Impossible()
-                for name, val in zip(target.items, value):
-                    walk(name, val)
-            else:
-                raise AssertionError('unexpected assignable node')
-
-        try:
-            walk(target, value)
-        except nodes.Impossible:
-            return node
-        return result
-
     def visit_Import(self, node, context):
         rv = self.generic_visit(node, context)
         context.undef(node.target)
@@ -181,7 +157,10 @@ class Optimizer(NodeTransformer):
     def visit_FromImport(self, node, context):
         rv = self.generic_visit(node, context)
         for name in node.names:
-            context.undef(name)
+            if isinstance(name, tuple):
+                context.undef(name[1])
+            else:
+                context.undef(name)
         return rv
 
     def fold(self, node, context):
index 2b26fc7739dcca782021f8dcd2b1b45ed6ff65b4..8a0eb6a53ea0f845b330954a8b0d868accd36576 100644 (file)
@@ -220,6 +220,9 @@ class Parser(object):
         node = nodes.CallBlock(lineno=self.stream.expect('call').lineno)
         if self.stream.current.type is 'lparen':
             self.parse_signature(node)
+        else:
+            node.args = []
+            node.defaults = []
 
         node.call = self.parse_expression()
         if not isinstance(node.call, nodes.Call):
index 20fa098487c849e0079f7eed9774c9a77e42a3ef..5f8de1f1226acf5153583789693ab698cdab7a8b 100644 (file)
@@ -14,8 +14,8 @@ from jinja2.exceptions import UndefinedError
 
 
 # these variables are exported to the template runtime
-__all__ = ['LoopContext', 'TemplateContext', 'Macro', 'Markup', 'missing',
-           'concat']
+__all__ = ['LoopContext', 'TemplateContext', 'TemplateReference', 'Macro',
+           'Markup', 'missing', 'concat']
 
 
 # special singleton representing missing values for the runtime
@@ -31,6 +31,11 @@ class TemplateContext(object):
     not save to use this class outside of the compiled code.  For example
     update and other methods will not work as they seem (they don't update
     the exported variables for example).
+
+    The context is immutable.  Modifications on `parent` must not happen and
+    modifications on `vars` are allowed from generated template code.  However
+    functions that are passed the template context may not modify the context
+    in any way.
     """
 
     def __init__(self, environment, parent, name, blocks):
@@ -55,15 +60,17 @@ class TemplateContext(object):
 
     def super(self, name, current):
         """Render a parent block."""
-        last = None
-        for block in self.blocks[name]:
-            if block is current:
-                break
-            last = block
-        if last is None:
+        try:
+            blocks = self.blocks[name]
+            pos = blocks.index(current) - 1
+            if pos < 0:
+                raise IndexError()
+        except LookupError:
             return self.environment.undefined('there is no parent block '
-                                              'called %r.' % block)
-        return SuperBlock(block, self, last)
+                                              'called %r.' % name)
+        render = lambda: Markup(concat(blocks[pos](self)))
+        render.__name__ = render.name = name
+        return render
 
     def get(self, key, default=None):
         """For dict compatibility"""
@@ -75,8 +82,7 @@ class TemplateContext(object):
 
     def get_exported(self):
         """Get a new dict with the exported variables."""
-        return dict((k, self.vars[k]) for k in self.exported_vars
-                    if not k.startswith('__'))
+        return dict((k, self.vars[k]) for k in self.exported_vars)
 
     def get_root(self):
         """Return a new dict with all the non local variables."""
@@ -86,6 +92,11 @@ class TemplateContext(object):
         """Return a copy of the complete context as dict."""
         return dict(self.parent, **self.vars)
 
+    def clone(self):
+        """Return a copy of the context without the locals."""
+        return self.__class__(self.environment, self.parent,
+                              self.name, self.blocks)
+
     def __contains__(self, name):
         return name in self.vars or name in self.parent
 
@@ -104,21 +115,22 @@ class TemplateContext(object):
         )
 
 
-class SuperBlock(object):
-    """When called this renders a parent block."""
+class TemplateReference(object):
+    """The `self` in templates."""
 
-    def __init__(self, name, context, render_func):
-        self.name = name
-        self._context = context
-        self._render_func = render_func
+    def __init__(self, context):
+        self.__context = context
 
-    def __call__(self):
-        return Markup(concat(self._render_func(self._context)))
+    def __getitem__(self, name):
+        func = self.__context.blocks[name][-1]
+        render = lambda: Markup(concat(func(self.__context)))
+        render.__name__ = render.name = name
+        return render
 
     def __repr__(self):
         return '<%s %r>' % (
             self.__class__.__name__,
-            self.name
+            self._context.name
         )
 
 
@@ -168,7 +180,11 @@ class LoopContext(object):
         return self._length
 
     def __repr__(self):
-        return 'LoopContext(%r)' % self.index0
+        return '<%s %r/%r>' % (
+            self.__class__.__name__,
+            self.index,
+            self.length
+        )
 
 
 class Macro(object):
index 9a1b89dcf8b8c368b94713c0d96a4d3d132b5c8d..da3d2000dc2c139fb80b213f1da550dac4316563 100644 (file)
@@ -60,10 +60,12 @@ ROUND = '''{{ 2.7|round }}|{{ 2.1|round }}|\
 XMLATTR = '''{{ {'foo': 42, 'bar': 23, 'fish': none,
 'spam': missing, 'blub:blub': '<?>'}|xmlattr }}'''
 SORT = '''{{ [2, 3, 1]|sort }}|{{ [2, 3, 1]|sort(true) }}'''
-GROUPBY = '''{{ [{'foo': 1, 'bar': 2},
+GROUPBY = '''{% for grouper, list in [{'foo': 1, 'bar': 2},
                  {'foo': 2, 'bar': 3},
                  {'foo': 1, 'bar': 1},
-                 {'foo': 3, 'bar': 4}]|groupby('foo') }}'''
+                 {'foo': 3, 'bar': 4}]|groupby('foo') -%}
+{{ grouper }}: {{ list|join(', ') }}
+{% endfor %}'''
 FILTERTAG = '''{% filter upper|replace('FOO', 'foo') %}foobar{% endfilter %}'''
 
 
@@ -279,11 +281,11 @@ def test_sort(env):
 
 def test_groupby(env):
     tmpl = env.from_string(GROUPBY)
-    assert tmpl.render() == (
-        "[{'list': [{'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}], "
-        "'grouper': 1}, {'list': [{'foo': 2, 'bar': 3}], 'grouper': 2}, "
-        "{'list': [{'foo': 3, 'bar': 4}], 'grouper': 3}]"
-    )
+    assert tmpl.render().splitlines() == [
+        "1: {'foo': 1, 'bar': 2}, {'foo': 1, 'bar': 1}",
+        "2: {'foo': 2, 'bar': 3}",
+        "3: {'foo': 3, 'bar': 4}"
+    ]
 
 
 def test_filtertag(env):
index aa3546c4533d873f7b87f3c98ba01dd61bac59e4..74594d3754299da2eaab67d12e6ce901d45a9b52 100644 (file)
@@ -35,12 +35,12 @@ SIMPLECALL = '''\
 '''
 
 COMPLEXCALL = '''\
-{% macro test() %}[[{{ caller(data='data') }}]]{% endmacro %}\
-{% call test() %}{{ data }}{% endcall %}\
+{% macro test() %}[[{{ caller('data') }}]]{% endmacro %}\
+{% call(data) test() %}{{ data }}{% endcall %}\
 '''
 
 CALLERUNDEFINED = '''\
-{% set caller = 42 %}\
+{% caller = 42 %}\
 {% macro test() %}{{ caller is not defined }}{% endmacro %}\
 {{ test() }}\
 '''
@@ -84,5 +84,5 @@ def test_caller_undefined(env):
 
 
 def test_include(env):
-    tmpl = env.from_string('{% include "include" %}{{ test("foo") }}')
+    tmpl = env.from_string('{% from "include" import test %}{{ test("foo") }}')
     assert tmpl.render() == '[foo]'