readded support for recursive for-loops
authorArmin Ronacher <armin.ronacher@active-4.com>
Sun, 11 May 2008 21:21:16 +0000 (23:21 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Sun, 11 May 2008 21:21:16 +0000 (23:21 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/runtime.py

index df1a85eb094a99925ceb6ab6796f059a6b3bdbf3..4716de194d8d23dcb1d72b44cba2b575e2f3a5e5 100644 (file)
@@ -499,7 +499,8 @@ class CodeGenerator(NodeVisitor):
             self.writeline('%s = l_%s' % (ident, name))
         return aliases
 
-    def function_scoping(self, node, frame, children=None):
+    def function_scoping(self, node, frame, children=None,
+                         find_special=True):
         """In Jinja a few statements require the help of anonymous
         functions.  Those are currently macros and call blocks and in
         the future also recursive loops.  As there is currently
@@ -542,6 +543,10 @@ class CodeGenerator(NodeVisitor):
             func_frame.identifiers.declared
         )
 
+        # no special variables for this scope, abort early
+        if not find_special:
+            return func_frame
+
         func_frame.accesses_kwargs = False
         func_frame.accesses_varargs = False
         func_frame.accesses_caller = False
@@ -791,14 +796,35 @@ class CodeGenerator(NodeVisitor):
                     self.writeline('context.exported_vars.discard(%r)' % alias)
 
     def visit_For(self, node, frame):
-        loop_frame = frame.inner()
-        loop_frame.inspect(node.iter_child_nodes(exclude=('iter',)))
-        extended_loop = bool(node.else_) or \
+        # when calculating the nodes for the inner frame we have to exclude
+        # the iterator contents from it
+        children = node.iter_child_nodes(exclude=('iter',))
+
+        if node.recursive:
+            loop_frame = self.function_scoping(node, frame, children,
+                                               find_special=False)
+        else:
+            loop_frame = frame.inner()
+            loop_frame.inspect(children)
+
+        extended_loop = node.recursive or node.else_ or \
                         'loop' in loop_frame.identifiers.undeclared
         if extended_loop:
             loop_frame.identifiers.add_special('loop')
 
-        aliases = self.collect_shadowed(loop_frame)
+        # if we don't have an recursive loop we have to find the shadowed
+        # variables at that point
+        if not node.recursive:
+            aliases = self.collect_shadowed(loop_frame)
+
+        # otherwise we set up a buffer and add a function def
+        else:
+            loop_frame.buffer = buf = self.temporary_identifier()
+            self.writeline('def loop(reciter, loop_render_func):', node)
+            self.indent()
+            self.writeline('%s = []' % buf, node)
+            aliases = {}
+
         self.pull_locals(loop_frame)
         if node.else_:
             self.writeline('l_loop = None')
@@ -825,17 +851,25 @@ class CodeGenerator(NodeVisitor):
             self.write(' for ')
             self.visit(node.target, loop_frame)
             self.write(' in ')
-            self.visit(node.iter, loop_frame)
+            if node.recursive:
+                self.write('reciter')
+            else:
+                self.visit(node.iter, loop_frame)
             self.write(' if (')
             test_frame = loop_frame.copy()
             self.writeline('l_loop = ' + parent_loop)
             self.visit(node.test, test_frame)
             self.write('))')
 
+        elif node.recursive:
+            self.write('reciter')
         else:
             self.visit(node.iter, loop_frame)
 
-        self.write(extended_loop and '):' or ':')
+        if node.recursive:
+            self.write(', recurse=loop_render_func):')
+        else:
+            self.write(extended_loop and '):' or ':')
 
         # tests in not extended loops become a continue
         if not extended_loop and node.test is not None:
@@ -862,6 +896,23 @@ class CodeGenerator(NodeVisitor):
         for name, alias in aliases.iteritems():
             self.writeline('l_%s = %s' % (name, alias))
 
+        # if the node was recursive we have to return the buffer contents
+        # and start the iteration code
+        if node.recursive:
+            if self.environment.autoescape:
+                self.writeline('return Markup(concat(%s))' % buf)
+            else:
+                self.writeline('return concat(%s)' % buf)
+            self.outdent()
+            if frame.buffer is None:
+                self.writeline('yield loop(', node)
+            else:
+                self.writeline('%s.append(loop(' % frame.buffer, node)
+            self.visit(node.iter, frame)
+            self.write(', loop)')
+            if frame.buffer is not None:
+                self.write(')')
+
     def visit_If(self, node, frame):
         if_frame = frame.soft()
         self.writeline('if ', node)
@@ -1000,10 +1051,14 @@ class CodeGenerator(NodeVisitor):
                 body.append([const])
 
         # if we have less than 3 nodes or less than 6 and a buffer we
-        # yield or extend
+        # yield or extend/append
         if len(body) < 3 or (frame.buffer is not None and len(body) < 6):
             if frame.buffer is not None:
-                self.writeline('%s.extend((' % frame.buffer)
+                # for one item we append, for more we extend
+                if len(body) == 1:
+                    self.writeline('%s.append(' % frame.buffer)
+                else:
+                    self.writeline('%s.extend((' % frame.buffer)
             for item in body:
                 if isinstance(item, list):
                     val = repr(concat(item))
@@ -1027,7 +1082,8 @@ class CodeGenerator(NodeVisitor):
                     if frame.buffer is not None:
                         self.write(', ')
             if frame.buffer is not None:
-                self.write('))')
+                # close the open parentheses
+                self.write(len(body) == 1 and ')' or '))')
 
         # otherwise we create a format string as this is faster in that case
         else:
index 8fb3f0b74ebd485a7dc10c09d0265fb78c048552..d734182c3450abafbbcd4681cd123fa9eed4f10e 100644 (file)
@@ -178,10 +178,11 @@ class TemplateReference(object):
 class LoopContext(object):
     """A loop context for dynamic iteration."""
 
-    def __init__(self, iterable, enforce_length=False):
+    def __init__(self, iterable, enforce_length=False, recurse=None):
         self._iterable = iterable
         self._next = iter(iterable).next
         self._length = None
+        self._recurse = recurse
         self.index0 = -1
         if enforce_length:
             len(self)
@@ -204,6 +205,12 @@ class LoopContext(object):
     def __iter__(self):
         return self
 
+    def __call__(self, iterable):
+        if self._recurse is None:
+            raise TypeError('Tried to call non recursive loop.  Maybe you '
+                            'forgot the "recursive" keyword.')
+        return self._recurse(iterable, self._recurse)
+
     def next(self):
         self.index0 += 1
         return self._next(), self