from jinja2.runtime import subscribe
+class ContextStack(object):
+ """Simple compile time context implementation."""
+
+ def __init__(self, initial=None):
+ self.stack = [{}]
+ if initial is not None:
+ self.stack.insert(0, initial)
+
+ def push(self):
+ self.stack.append({})
+
+ def pop(self):
+ self.stack.pop()
+
+ def __getitem__(self, key):
+ for level in reversed(self.stack):
+ if key in level:
+ return level[key]
+ raise KeyError(key)
+
+ def __setitem__(self, key, value):
+ self.stack[-1][key] = value
+
+ def blank(self):
+ """Return a new context with nothing but the root scope."""
+ return ContextStack(self.stack[0])
+
+
class Optimizer(NodeTransformer):
- def __init__(self, environment, context={}):
+ def __init__(self, environment):
self.environment = environment
- self.context = context
- def visit_Filter(self, node):
+ def visit_Filter(self, node, context):
"""Try to evaluate filters if possible."""
+ # XXX: nonconstant arguments? not-called visitors? generic visit!
try:
- x = self.visit(node.node).as_const()
+ x = self.visit(node.node, context).as_const()
except nodes.Impossible:
- return node
+ return self.generic_visit(node, context)
for filter in reversed(node.filters):
# XXX: call filters with arguments
x = self.environment.filters[filter.name](self.environment, x)
# XXX: don't optimize context dependent filters
return nodes.Const(x)
- def visit_For(self, node):
- """Loop unrolling for constant values."""
+ def visit_For(self, node, context):
+ """Loop unrolling for iterable constant values."""
try:
- iter = self.visit(node.iter).as_const()
- except nodes.Impossible:
- return node
+ iterable = iter(self.visit(node.iter, context).as_const())
+ except (nodes.Impossible, TypeError):
+ return self.generic_visit(node, context)
+ context.push()
result = []
+ # XXX: tuple unpacking (for key, value in foo)
target = node.target.name
- for item in iter:
- # XXX: take care of variable scopes
- self.context[target] = item
- result.extend(self.visit(n) for n in deepcopy(node.body))
+ iterated = False
+ for item in iterable:
+ context[target] = item
+ result.extend(self.visit(n, context) for n in deepcopy(node.body))
+ iterated = True
+ if not iterated and node.else_:
+ result.extend(self.visit(n, context) for n in deepcopy(node.else_))
+ context.pop()
return result
- def visit_Name(self, node):
- # XXX: take care of variable scopes!
- if node.name not in self.context:
+ def visit_If(self, node, context):
+ try:
+ val = self.visit(node.test, context).as_const()
+ except nodes.Impossible:
+ return self.generic_visit(node, context)
+ if val:
+ return node.body
+ return node.else_
+
+ def visit_Name(self, node, context):
+ if node.ctx == 'load':
+ try:
+ return nodes.Const(context[node.name], lineno=node.lineno)
+ except KeyError:
+ pass
+ return node
+
+ def visit_Assign(self, node, context):
+ try:
+ target = node.target = self.generic_visit(node.target, context)
+ value = self.generic_visit(node.node, context).as_const()
+ except nodes.Impossible:
return node
- return nodes.Const(self.context[node.name])
- def visit_Subscript(self, node):
+ result = []
+ lineno = node.lineno
+ def walk(target, value):
+ if isinstance(target, nodes.Name):
+ const_value = nodes.Const(value, lineno=lineno)
+ result.append(nodes.Assign(target, const_value, lineno=lineno))
+ context[target.name] = value
+ elif isinstance(target, nodes.Tuple):
+ try:
+ value = tuple(value)
+ except TypeError:
+ raise nodes.Impossible()
+ if len(target) != len(value):
+ raise nodes.Impossible()
+ for name, val in zip(target, value):
+ walk(name, val)
+ else:
+ raise AssertionError('unexpected assignable node')
+
try:
- item = self.visit(node.node).as_const()
- arg = self.visit(node.arg).as_const()
+ walk(target, value)
except nodes.Impossible:
return node
- # XXX: what does the 3rd parameter mean?
- return nodes.Const(subscribe(item, arg, None))
+ return result
+
+ def visit_Subscript(self, node, context):
+ if node.ctx == 'load':
+ try:
+ item = self.visit(node.node, context).as_const()
+ arg = self.visit(node.arg, context).as_const()
+ except nodes.Impossible:
+ return self.generic_visit(node, context)
+ return nodes.Const(subscribe(item, arg, 'load'))
+ return self.generic_visit(node, context)
-def optimize(node, environment, context={}):
- optimizer = Optimizer(environment, context=context)
- return optimizer.visit(node)
+def optimize(node, environment, context_hint=None):
+ """The context hint can be used to perform an static optimization
+ based on the context given."""
+ optimizer = Optimizer(environment)
+ return optimizer.visit(node, ContextStack(context_hint))