From 81b881709326684091fecbf2d25999cbf48e805e Mon Sep 17 00:00:00 2001 From: Armin Ronacher Date: Wed, 9 Apr 2008 00:40:05 +0200 Subject: [PATCH] improved static optimizer --HG-- branch : trunk --- jinja2/optimizer.py | 132 +++++++++++++++++++++++++++++++++++--------- jinja2/runtime.py | 4 +- test_optimizer.py | 9 ++- 3 files changed, 116 insertions(+), 29 deletions(-) diff --git a/jinja2/optimizer.py b/jinja2/optimizer.py index e7f6ac7..c508727 100644 --- a/jinja2/optimizer.py +++ b/jinja2/optimizer.py @@ -25,54 +25,134 @@ from jinja2.visitor import NodeVisitor, NodeTransformer from jinja2.runtime import subscribe +class ContextStack(object): + """Simple compile time context implementation.""" + + def __init__(self, initial=None): + self.stack = [{}] + if initial is not None: + self.stack.insert(0, initial) + + def push(self): + self.stack.append({}) + + def pop(self): + self.stack.pop() + + def __getitem__(self, key): + for level in reversed(self.stack): + if key in level: + return level[key] + raise KeyError(key) + + def __setitem__(self, key, value): + self.stack[-1][key] = value + + def blank(self): + """Return a new context with nothing but the root scope.""" + return ContextStack(self.stack[0]) + + class Optimizer(NodeTransformer): - def __init__(self, environment, context={}): + def __init__(self, environment): self.environment = environment - self.context = context - def visit_Filter(self, node): + def visit_Filter(self, node, context): """Try to evaluate filters if possible.""" + # XXX: nonconstant arguments? not-called visitors? generic visit! try: - x = self.visit(node.node).as_const() + x = self.visit(node.node, context).as_const() except nodes.Impossible: - return node + return self.generic_visit(node, context) for filter in reversed(node.filters): # XXX: call filters with arguments x = self.environment.filters[filter.name](self.environment, x) # XXX: don't optimize context dependent filters return nodes.Const(x) - def visit_For(self, node): - """Loop unrolling for constant values.""" + def visit_For(self, node, context): + """Loop unrolling for iterable constant values.""" try: - iter = self.visit(node.iter).as_const() - except nodes.Impossible: - return node + iterable = iter(self.visit(node.iter, context).as_const()) + except (nodes.Impossible, TypeError): + return self.generic_visit(node, context) + context.push() result = [] + # XXX: tuple unpacking (for key, value in foo) target = node.target.name - for item in iter: - # XXX: take care of variable scopes - self.context[target] = item - result.extend(self.visit(n) for n in deepcopy(node.body)) + iterated = False + for item in iterable: + context[target] = item + result.extend(self.visit(n, context) for n in deepcopy(node.body)) + iterated = True + if not iterated and node.else_: + result.extend(self.visit(n, context) for n in deepcopy(node.else_)) + context.pop() return result - def visit_Name(self, node): - # XXX: take care of variable scopes! - if node.name not in self.context: + def visit_If(self, node, context): + try: + val = self.visit(node.test, context).as_const() + except nodes.Impossible: + return self.generic_visit(node, context) + if val: + return node.body + return node.else_ + + def visit_Name(self, node, context): + if node.ctx == 'load': + try: + return nodes.Const(context[node.name], lineno=node.lineno) + except KeyError: + pass + return node + + def visit_Assign(self, node, context): + try: + target = node.target = self.generic_visit(node.target, context) + value = self.generic_visit(node.node, context).as_const() + except nodes.Impossible: return node - return nodes.Const(self.context[node.name]) - def visit_Subscript(self, node): + result = [] + lineno = node.lineno + def walk(target, value): + if isinstance(target, nodes.Name): + const_value = nodes.Const(value, lineno=lineno) + result.append(nodes.Assign(target, const_value, lineno=lineno)) + context[target.name] = value + elif isinstance(target, nodes.Tuple): + try: + value = tuple(value) + except TypeError: + raise nodes.Impossible() + if len(target) != len(value): + raise nodes.Impossible() + for name, val in zip(target, value): + walk(name, val) + else: + raise AssertionError('unexpected assignable node') + try: - item = self.visit(node.node).as_const() - arg = self.visit(node.arg).as_const() + walk(target, value) except nodes.Impossible: return node - # XXX: what does the 3rd parameter mean? - return nodes.Const(subscribe(item, arg, None)) + return result + + def visit_Subscript(self, node, context): + if node.ctx == 'load': + try: + item = self.visit(node.node, context).as_const() + arg = self.visit(node.arg, context).as_const() + except nodes.Impossible: + return self.generic_visit(node, context) + return nodes.Const(subscribe(item, arg, 'load')) + return self.generic_visit(node, context) -def optimize(node, environment, context={}): - optimizer = Optimizer(environment, context=context) - return optimizer.visit(node) +def optimize(node, environment, context_hint=None): + """The context hint can be used to perform an static optimization + based on the context given.""" + optimizer = Optimizer(environment) + return optimizer.visit(node, ContextStack(context_hint)) diff --git a/jinja2/runtime.py b/jinja2/runtime.py index c1c34e1..5a9764e 100644 --- a/jinja2/runtime.py +++ b/jinja2/runtime.py @@ -24,8 +24,8 @@ def extends(template, namespace): def subscribe(obj, argument, undefined_factory): """Get an item or attribute of an object.""" try: - return getattr(obj, argument) - except AttributeError: + return getattr(obj, str(argument)) + except (AttributeError, UnicodeError): try: return obj[argument] except LookupError: diff --git a/test_optimizer.py b/test_optimizer.py index 06ee440..ddb0fa0 100644 --- a/test_optimizer.py +++ b/test_optimizer.py @@ -16,12 +16,19 @@ ast = env.parse(""" {% for forum in forums %} {{ readstatus(forum.id) }} {{ forum.id|e }} {{ forum.name|e }} {% endfor %} + + {% navigation = [('#foo', 'Foo'), ('#bar', 'Bar')] %} + """) print ast print print generate(ast, env, "foo.html") print -ast = optimize(ast, env, context={'forums': forums}) +ast = optimize(ast, env, context_hint={'forums': forums}) print ast print print generate(ast, env, "foo.html") -- 2.26.2