From 3d8b784a7b443659574bbd35af385ab6ae4b5189 Mon Sep 17 00:00:00 2001 From: Armin Ronacher Date: Sun, 13 Apr 2008 13:16:50 +0200 Subject: [PATCH] added loop filtering --HG-- branch : trunk --- jinja2/compiler.py | 90 +++++++++++++++++++++++++++++++++++++++++---- jinja2/nodes.py | 2 +- jinja2/optimizer.py | 5 ++- jinja2/parser.py | 24 +++++++----- jinja2/runtime.py | 27 ++++++++------ test_loop_filter.py | 12 ++++++ 6 files changed, 131 insertions(+), 29 deletions(-) create mode 100644 test_loop_filter.py diff --git a/jinja2/compiler.py b/jinja2/compiler.py index e692184..e16617e 100644 --- a/jinja2/compiler.py +++ b/jinja2/compiler.py @@ -29,6 +29,14 @@ operators = { } +try: + exec '(0 if 0 else 0)' +except SyntaxError: + have_condexpr = False +else: + have_condexpr = True + + def generate(node, environment, filename, stream=None): """Generate the python source for a node tree.""" is_child = node.find(nodes.Extends) is not None @@ -114,6 +122,7 @@ class Frame(object): self.rootlevel = False self.parent = parent self.buffer = None + self.name_overrides = {} self.block = parent and parent.block or None if parent is not None: self.identifiers.declared.update( @@ -122,11 +131,13 @@ class Frame(object): parent.identifiers.declared_parameter ) self.buffer = parent.buffer + self.name_overrides = parent.name_overrides.copy() def copy(self): """Create a copy of the current one.""" rv = copy(self) rv.identifiers = copy(self.identifiers) + rv.name_overrides = self.name_overrides.copy() return rv def inspect(self, nodes, hard_scope=False): @@ -516,21 +527,62 @@ class CodeGenerator(NodeVisitor): aliases = self.collect_shadowed(loop_frame) self.pull_locals(loop_frame, True) - - self.newline(node) if node.else_: self.writeline('l_loop = None') - self.write('for ') + + self.newline(node) + self.writeline('for ') self.visit(node.target, loop_frame) self.write(extended_loop and ', l_loop in LoopContext(' or ' in ') - self.visit(node.iter, loop_frame) + + # the expression pointing to the parent loop. We make the + # undefined a bit more debug friendly at the same time. + parent_loop = 'loop' in aliases and aliases['loop'] \ + or "Undefined('loop', extra=%r)" % \ + 'the filter section of a loop as well as the ' \ + 'else block doesn\'t have access to the special ' \ + "'loop' variable of the current loop. Because " \ + 'there is no parent loop it\'s undefined.' + + # if we have an extened loop and a node test, we filter in the + # "outer frame". + if extended_loop and node.test is not None: + self.write('(') + self.visit(node.target, loop_frame) + self.write(' for ') + self.visit(node.target, loop_frame) + self.write(' in ') + self.visit(node.iter, loop_frame) + self.write(' if (') + test_frame = loop_frame.copy() + test_frame.name_overrides['loop'] = parent_loop + self.visit(node.test, test_frame) + self.write('))') + + else: + self.visit(node.iter, loop_frame) + if 'loop' in aliases: self.write(', ' + aliases['loop']) self.write(extended_loop and '):' or ':') + + # tests in not extended loops become a continue + if not extended_loop and node.test is not None: + self.indent() + self.writeline('if ') + self.visit(node.test) + self.write(':') + self.indent() + self.writeline('continue') + self.outdent(2) + self.blockvisit(node.body, loop_frame) if node.else_: self.writeline('if l_loop is None:') + self.indent() + self.writeline('l_loop = ' + parent_loop) + self.outdent() self.blockvisit(node.else_, loop_frame) # reset the aliases if there are any. @@ -667,7 +719,7 @@ class CodeGenerator(NodeVisitor): self.write('%s.append(' % frame.buffer) self.write(finalizer + '(') self.visit(item, frame) - self.write(')' * (1 + frame.buffer is not None)) + self.write(')' * (1 + (frame.buffer is not None))) # otherwise we create a format string as this is faster in that case else: @@ -721,8 +773,14 @@ class CodeGenerator(NodeVisitor): self.writeline('context[%r] = l_%s' % (name, name)) def visit_Name(self, node, frame): - if frame.toplevel and node.ctx == 'store': - frame.assigned_names.add(node.name) + if node.ctx == 'store': + if frame.toplevel: + frame.assigned_names.add(node.name) + frame.name_overrides.pop(node.name, None) + elif node.ctx == 'load': + if node.name in frame.name_overrides: + self.write(frame.name_overrides[node.name]) + return self.write('l_' + node.name) def visit_Const(self, node, frame): @@ -856,6 +914,24 @@ class CodeGenerator(NodeVisitor): self.signature(node, frame) self.write(')') + def visit_CondExpr(self, node, frame): + if not have_condexpr: + self.write('((') + self.visit(node.test, frame) + self.write(') and (') + self.visit(node.expr1, frame) + self.write(',) or (') + self.visit(node.expr2, frame) + self.write(',))[0]') + else: + self.write('(') + self.visit(node.expr1, frame) + self.write(' if ') + self.visit(node.test, frame) + self.write(' else ') + self.visit(node.expr2, frame) + self.write(')') + def visit_Call(self, node, frame, extra_kwargs=None): self.visit(node.node, frame) self.write('(') diff --git a/jinja2/nodes.py b/jinja2/nodes.py index e1857d2..89c6fa2 100644 --- a/jinja2/nodes.py +++ b/jinja2/nodes.py @@ -207,7 +207,7 @@ class Extends(Stmt): class For(Stmt): """A node that represents a for loop""" - fields = ('target', 'iter', 'body', 'else_', 'recursive') + fields = ('target', 'iter', 'body', 'else_', 'test') class If(Stmt): diff --git a/jinja2/optimizer.py b/jinja2/optimizer.py index d2550f4..5877c6f 100644 --- a/jinja2/optimizer.py +++ b/jinja2/optimizer.py @@ -131,6 +131,9 @@ class Optimizer(NodeTransformer): # we also don't want unrolling if macros are defined in it if node.find(nodes.Macro) is not None: raise TypeError() + # XXX: add support for loop test clauses in the optimizer + if node.test is not None: + raise TypeError() except (nodes.Impossible, TypeError): return self.generic_visit(node, context) @@ -156,7 +159,7 @@ class Optimizer(NodeTransformer): try: try: - for loop, item in LoopContext(iterable, parent, True): + for item, loop in LoopContext(iterable, parent, True): context['loop'] = loop.make_static() assign(node.target, item) result.extend(self.visit(n.copy(), context) diff --git a/jinja2/parser.py b/jinja2/parser.py index 2d04211..5185df3 100644 --- a/jinja2/parser.py +++ b/jinja2/parser.py @@ -110,18 +110,17 @@ class Parser(object): self.filename) target.set_ctx('store') self.stream.expect('in') - iter = self.parse_tuple() - if self.stream.current.type is 'recursive': + iter = self.parse_tuple(no_condexpr=True) + test = None + if self.stream.current.type is 'if': self.stream.next() - recursive = True - else: - recursive = False + test = self.parse_expression() body = self.parse_statements(('endfor', 'else')) if self.stream.next().type is 'endfor': else_ = [] else: else_ = self.parse_statements(('endfor',), drop_needle=True) - return nodes.For(target, iter, body, else_, False, lineno=lineno) + return nodes.For(target, iter, body, else_, test, lineno=lineno) def parse_if(self): """Parse an if construct.""" @@ -236,8 +235,10 @@ class Parser(object): self.end_statement() return node - def parse_expression(self): + def parse_expression(self, no_condexpr=False): """Parse an expression.""" + if no_condexpr: + return self.parse_or() return self.parse_condexpr() def parse_condexpr(self): @@ -417,14 +418,19 @@ class Parser(object): node = self.parse_postfix(node) return node - def parse_tuple(self, enforce=False, simplified=False): + def parse_tuple(self, enforce=False, simplified=False, no_condexpr=False): """ Parse multiple expressions into a tuple. This can also return just one expression which is not a tuple. If you want to enforce a tuple, pass it enforce=True (currently unused). """ lineno = self.stream.current.lineno - parse = simplified and self.parse_primary or self.parse_expression + if simplified: + parse = self.parse_primary + elif no_condexpr: + parse = lambda: self.parse_expression(no_condexpr=True) + else: + parse = self.parse_expression args = [] is_tuple = False while 1: diff --git a/jinja2/runtime.py b/jinja2/runtime.py index 238d4cf..676c1f5 100644 --- a/jinja2/runtime.py +++ b/jinja2/runtime.py @@ -25,7 +25,7 @@ def subscribe(obj, argument): except (AttributeError, UnicodeError): try: return obj[argument] - except LookupError: + except (TypeError, LookupError): return Undefined(obj, argument) @@ -138,8 +138,9 @@ class LoopContext(LoopContextBase): def __init__(self, iterable, parent=None, enforce_length=False): self._iterable = iterable + self._next = iter(iterable).next self._length = None - self.index0 = 0 + self.index0 = -1 self.parent = parent if enforce_length: len(self) @@ -152,9 +153,11 @@ class LoopContext(LoopContextBase): return StaticLoopContext(self.index0, self.length, parent) def __iter__(self): - for item in self._iterable: - yield self, item - self.index0 += 1 + return self + + def next(self): + self.index0 += 1 + return self._next(), self def __len__(self): if self._length is None: @@ -162,7 +165,8 @@ class LoopContext(LoopContextBase): length = len(self._iterable) except TypeError: self._iterable = tuple(self._iterable) - length = self.index0 + len(tuple(self._iterable)) + self._next = iter(self._iterable).next + length = len(tuple(self._iterable)) + self.index0 + 1 self._length = length return self._length @@ -217,12 +221,13 @@ class Macro(object): try: value = self.defaults[idx - arg_count] except IndexError: - value = Undefined(name) + value = Undefined(name, extra='parameter not provided') arguments['l_' + name] = value if self.caller: caller = kwargs.pop('caller', None) if caller is None: - caller = Undefined('caller') + caller = Undefined('caller', extra='The macro was called ' + 'from an expression and not a call block.') arguments['l_caller'] = caller if self.catch_all: arguments['l_arguments'] = kwargs @@ -238,14 +243,14 @@ class Macro(object): class Undefined(object): """The object for undefined values.""" - def __init__(self, name=None, attr=None): + def __init__(self, name=None, attr=None, extra=None): if attr is None: self._undefined_hint = '%r is undefined' % name - elif name is None: - self._undefined_hint = 'attribute %r is undefined' % name else: self._undefined_hint = 'attribute %r of %r is undefined' \ % (attr, name) + if extra is not None: + self._undefined_hint += ' (' + extra + ')' def fail(self, *args, **kwargs): raise TypeError(self._undefined_hint) diff --git a/test_loop_filter.py b/test_loop_filter.py new file mode 100644 index 0000000..64f32d3 --- /dev/null +++ b/test_loop_filter.py @@ -0,0 +1,12 @@ +from jinja2 import Environment + +tmpl = Environment().from_string("""\ + +{{ 1 if foo else 0 }} +""") + +print tmpl.render(foo=True) -- 2.26.2