The solution would be a second syntax tree that has the scoping rules stored.
- :copyright: Copyright 2008 by Christoph Hack, Armin Ronacher.
- :license: GNU GPL.
+ :copyright: (c) 2010 by the Jinja Team.
+ :license: BSD.
"""
from jinja2 import nodes
-from jinja2.visitor import NodeVisitor, NodeTransformer
-from jinja2.runtime import LoopContext
+from jinja2.visitor import NodeTransformer
-def optimize(node, environment, context_hint=None):
+def optimize(node, environment):
"""The context hint can be used to perform an static optimization
based on the context given."""
optimizer = Optimizer(environment)
- return optimizer.visit(node, ContextStack(context_hint))
-
-
-class ContextStack(object):
- """Simple compile time context implementation."""
- undefined = object()
-
- def __init__(self, initial=None):
- self.stack = [{}]
- if initial is not None:
- self.stack.insert(0, initial)
-
- def push(self):
- self.stack.append({})
-
- def pop(self):
- self.stack.pop()
-
- def get(self, key, default=None):
- try:
- return self[key]
- except KeyError:
- return default
-
- def undef(self, name):
- if name in self:
- self[name] = self.undefined
-
- def __contains__(self, key):
- try:
- self[key]
- except KeyError:
- return False
- return True
-
- def __getitem__(self, key):
- for level in reversed(self.stack):
- if key in level:
- rv = level[key]
- if rv is self.undefined:
- raise KeyError(key)
- return rv
- raise KeyError(key)
-
- def __setitem__(self, key, value):
- self.stack[-1][key] = value
-
- def blank(self):
- """Return a new context with nothing but the root scope."""
- return ContextStack(self.stack[0])
+ return optimizer.visit(node)
class Optimizer(NodeTransformer):
def __init__(self, environment):
self.environment = environment
- def visit_Block(self, node, context):
- return self.generic_visit(node, context.blank())
-
- def scoped_section(self, node, context):
- context.push()
- try:
- return self.generic_visit(node, context)
- finally:
- context.pop()
- visit_For = visit_Macro = scoped_section
-
- def visit_FilterBlock(self, node, context):
- """Try to filter a block at compile time."""
- node = self.generic_visit(node, context)
- context.push()
-
- # check if we can evaluate the wrapper body into a string
- # at compile time
- buffer = []
- for child in node.body:
- if not isinstance(child, nodes.Output):
- return node
- for item in child.optimized_nodes():
- if isinstance(item, nodes.Node):
- return node
- buffer.append(item)
-
- # now check if we can evaluate the filter at compile time.
- try:
- data = node.filter.as_const(u''.join(buffer))
- except nodes.Impossible:
- return node
-
- context.pop()
- const = nodes.Const(data, lineno=node.lineno)
- return nodes.Output([const], lineno=node.lineno)
-
- def visit_If(self, node, context):
+ def visit_If(self, node):
+ """Eliminate dead code."""
+ # do not optimize ifs that have a block inside so that it doesn't
+ # break super().
+ if node.find(nodes.Block) is not None:
+ return self.generic_visit(node)
try:
- val = self.visit(node.test, context).as_const()
+ val = self.visit(node.test).as_const()
except nodes.Impossible:
- return self.generic_visit(node, context)
+ return self.generic_visit(node)
if val:
- return node.body
- return node.else_
-
- def visit_Name(self, node, context):
- if node.ctx != 'load':
- # something overwrote the variable, we can no longer use
- # the constant from the context
- context.undef(node.name)
- return node
- try:
- return nodes.Const.from_untrusted(context[node.name],
- lineno=node.lineno,
- environment=self.environment)
- except (KeyError, nodes.Impossible):
- return node
-
- def visit_Assign(self, node, context):
- try:
- target = node.target = self.generic_visit(node.target, context)
- value = self.generic_visit(node.node, context).as_const()
- except nodes.Impossible:
- return node
-
+ body = node.body
+ else:
+ body = node.else_
result = []
- lineno = node.lineno
- def walk(target, value):
- if isinstance(target, nodes.Name):
- const = nodes.Const.from_untrusted(value, lineno=lineno)
- result.append(nodes.Assign(target, const, lineno=lineno))
- context[target.name] = value
- elif isinstance(target, nodes.Tuple):
- try:
- value = tuple(value)
- except TypeError:
- raise nodes.Impossible()
- if len(target.items) != len(value):
- raise nodes.Impossible()
- for name, val in zip(target.items, value):
- walk(name, val)
- else:
- raise AssertionError('unexpected assignable node')
-
- try:
- walk(target, value)
- except nodes.Impossible:
- return node
+ for node in body:
+ result.extend(self.visit_list(node))
return result
- def visit_Import(self, node, context):
- rv = self.generic_visit(node, context)
- context.undef(node.target)
- return rv
-
- def visit_FromImport(self, node, context):
- rv = self.generic_visit(node, context)
- for name in node.names:
- context.undef(name)
- return rv
-
- def fold(self, node, context):
+ def fold(self, node):
"""Do constant folding."""
- node = self.generic_visit(node, context)
+ node = self.generic_visit(node)
try:
return nodes.Const.from_untrusted(node.as_const(),
lineno=node.lineno,
visit_Add = visit_Sub = visit_Mul = visit_Div = visit_FloorDiv = \
visit_Pow = visit_Mod = visit_And = visit_Or = visit_Pos = visit_Neg = \
- visit_Not = visit_Compare = visit_Subscript = visit_Call = \
+ visit_Not = visit_Compare = visit_Getitem = visit_Getattr = visit_Call = \
visit_Filter = visit_Test = visit_CondExpr = fold
del fold