merged
[jinja2.git] / jinja2 / optimizer.py
1 # -*- coding: utf-8 -*-
2 """
3     jinja2.optimizer
4     ~~~~~~~~~~~~~~~~
5
6     This module tries to optimize template trees by:
7
8         * eliminating constant nodes
9         * evaluating filters and macros on constant nodes
10         * unrolling loops over constant values
11         * replacing variables whose values are already known (because they
12           don't change often and you want to prerender a template) with constants
13
14     After the optimization you will get a new, simpler template which can
15     be saved again for later rendering. But even if you don't want to
16     prerender a template, this module might speed up your templates a bit
17     if you are using a lot of constants.
18
19     :copyright: Copyright 2008 by Christoph Hack.
20     :license: GNU GPL.
21 """
22 from copy import deepcopy
23 from jinja2 import nodes
24 from jinja2.visitor import NodeVisitor, NodeTransformer
25 from jinja2.runtime import subscribe, LoopContext
26
27
class ContextStack(object):
    """Simple compile time context implementation.

    Scopes are dicts kept in ``self.stack``.  Lookups search from the
    innermost (last) scope outwards; assignments always go to the innermost
    scope.  ``stack[0]`` is the root scope: the *initial* mapping when one
    is given, an empty dict otherwise.  The root scope is kept apart from
    the first writable scope, so top-level assignments are not stored in it.
    """

    def __init__(self, initial=None):
        if initial is None:
            initial = {}
        # Root scope first, then the writable top-level scope.  Keeping them
        # separate even without an initial mapping makes ``blank()`` return
        # only the root values whether or not a hint was passed.
        self.stack = [initial, {}]

    def push(self):
        """Open a new innermost scope."""
        self.stack.append({})

    def pop(self):
        """Discard the innermost scope and everything assigned in it."""
        self.stack.pop()

    def get(self, key, default=None):
        """Return the value bound to *key*, or *default* if it is unbound."""
        try:
            return self[key]
        except KeyError:
            return default

    def __getitem__(self, key):
        for level in reversed(self.stack):
            if key in level:
                return level[key]
        raise KeyError(key)

    def __setitem__(self, key, value):
        self.stack[-1][key] = value

    def blank(self):
        """Return a new context with nothing but the root scope."""
        return ContextStack(self.stack[0])
60
61
class Optimizer(NodeTransformer):
    """Transformer that folds constants and unrolls loops in a template tree.

    Every visitor receives the node and a ``ContextStack`` holding the
    compile-time values known at that point.  A visitor returns a
    replacement node, a list of nodes (spliced into the parent's node list)
    or the (possibly modified) node itself.
    """

    # Marker stored in the context for a name that is bound at this point
    # but whose value is only known at render time: a loop target, the
    # target of a ``set`` with a non-constant value, or a name assigned in
    # an ``if`` branch that may not run.  It hides an outer constant of the
    # same name so that constant is not substituted by mistake.
    _unknown = object()

    def __init__(self, environment):
        self.environment = environment

    # -- helpers -----------------------------------------------------------

    def _lookup(self, context, name):
        """Return the compile-time value of *name*.

        Raises ``KeyError`` if the name is unbound or marked unknown.
        """
        value = context[name]
        if value is self._unknown:
            raise KeyError(name)
        return value

    def _shadow(self, target, context):
        """Mark every name bound by the assignment *target* as unknown."""
        if isinstance(target, nodes.Name):
            context[target.name] = self._unknown
        elif isinstance(target, nodes.Tuple):
            for item in target.items:
                self._shadow(item, context)

    def _unpack(self, target, value):
        """Pair each ``Name`` in *target* with its part of *value*.

        Returns a list of ``(name_node, value)`` tuples without touching the
        context, so a failure leaves no partial assignment behind.  Raises
        ``nodes.Impossible`` if a tuple target cannot be unpacked from
        *value* at compile time.
        """
        if isinstance(target, nodes.Name):
            return [(target, value)]
        if isinstance(target, nodes.Tuple):
            try:
                value = tuple(value)
            except TypeError:
                raise nodes.Impossible()
            if len(target.items) != len(value):
                raise nodes.Impossible()
            pairs = []
            for item, item_value in zip(target.items, value):
                pairs.extend(self._unpack(item, item_value))
            return pairs
        raise AssertionError('unexpected assignable node')

    def _visit_nodes(self, node_list, context):
        """Visit a list of nodes and return the flattened replacement list.

        A visitor result that is a list is spliced in; ``None`` drops the
        node.  A falsy *node_list* (e.g. a missing ``else_``) yields ``[]``.
        """
        result = []
        for child in node_list or ():
            new = self.visit(child, context)
            if new is None:
                continue
            if isinstance(new, (list, tuple)):
                result.extend(new)
            else:
                result.append(new)
        return result

    def _visit_conditional(self, node_list, context):
        """Visit nodes that may or may not run at render time.

        They are visited in a scope of their own.  Returns the visited list
        and the names assigned in that scope; the caller must treat those
        names as unknown afterwards.
        """
        context.push()
        try:
            visited = self._visit_nodes(node_list, context)
            # ``stack[-1]`` is the scope pushed above.
            assigned = list(context.stack[-1])
        finally:
            context.pop()
        return visited, assigned

    def _visit_loop(self, node, context):
        """Optimize the children of a ``for`` loop that is kept as is."""
        context.push()
        try:
            # Inside the loop the target names and ``loop`` hold render-time
            # values, so hide any outer constants with the same names.
            self._shadow(node.target, context)
            context['loop'] = self._unknown
            node.body = self._visit_nodes(node.body, context)
            if node.else_:
                node.else_ = self._visit_nodes(node.else_, context)
        finally:
            context.pop()
        return node

    def _has_arguments(self, filter_node):
        """Return True if *filter_node* is called with any arguments."""
        # NOTE(review): assumes call arguments live in these attributes, as
        # on jinja2 call-like nodes -- confirm against jinja2.nodes.  Missing
        # attributes count as "no arguments".
        for attr in ('args', 'kwargs', 'dyn_args', 'dyn_kwargs'):
            if getattr(filter_node, attr, None):
                return True
        return False

    # -- visitors ----------------------------------------------------------

    def visit_Block(self, node, context):
        """Optimize a block with only the root scope visible."""
        return self.generic_visit(node, context.blank())

    def visit_Filter(self, node, context):
        """Evaluate a filter chain applied to a constant, if possible.

        The chain is kept (its children are still optimized) when the
        filtered expression is not constant, when a filter has arguments
        (they are not evaluated here), or when looking up or calling a
        filter raises.  In those cases evaluation is left to render time
        instead of failing here.
        """
        # XXX: don't optimize context dependent filters
        try:
            value = self.visit(node.node, context).as_const()
        except nodes.Impossible:
            return self.generic_visit(node, context)
        for filter_node in reversed(node.filters):
            if self._has_arguments(filter_node):
                return self.generic_visit(node, context)
            try:
                value = self.environment.filters[filter_node.name](value)
            except Exception:
                # Unknown filter or a filter failing on this value.
                return self.generic_visit(node, context)
        return nodes.Const(value, lineno=node.lineno)

    def visit_For(self, node, context):
        """Unroll a loop over a constant iterable.

        For each item, the loop target (and ``loop``, set to a
        ``LoopContext``) is bound and a fresh copy of the body is optimized.
        The result is the concatenated node list, or the optimized ``else_``
        body when nothing was iterated.  If the iterable is not constant, or
        the body cannot be unrolled, the loop is kept and only its children
        are optimized.
        """
        node.iter = self.visit(node.iter, context)
        try:
            iterable = iter(node.iter.as_const())
        except (nodes.Impossible, TypeError):
            return self._visit_loop(node, context)

        parent = context.get('loop')
        if parent is self._unknown:
            # The enclosing loop is kept at runtime, so its state is unknown.
            parent = None

        # XXX: not covered cases:
        #       - item is accessed by dynamic part in the iteration
        result = []
        iterated = False
        context.push()
        try:
            try:
                for loop, item in LoopContext(iterable, parent):
                    context['loop'] = loop
                    for name, value in self._unpack(node.target, item):
                        context[name.name] = value
                    result.extend(self._visit_nodes(deepcopy(node.body),
                                                    context))
                    iterated = True
                if not iterated and node.else_:
                    result.extend(self._visit_nodes(deepcopy(node.else_),
                                                    context))
            except nodes.Impossible:
                result = None
        finally:
            context.pop()
        if result is None:
            # Unrolling failed part-way; keep the loop instead.
            return self._visit_loop(node, context)
        return result

    def visit_If(self, node, context):
        """Replace an ``if`` with a constant test by the branch it selects.

        The selected branch is optimized in the current scope and returned
        as a list of nodes.  With a non-constant test both branches are
        optimized, and names assigned in either branch become unknown
        afterwards because that branch may not run.
        """
        node.test = self.visit(node.test, context)
        try:
            val = node.test.as_const()
        except nodes.Impossible:
            node.body, assigned = self._visit_conditional(node.body, context)
            if node.else_:
                node.else_, else_assigned = \
                    self._visit_conditional(node.else_, context)
                assigned.extend(else_assigned)
            for name in assigned:
                context[name] = self._unknown
            return node
        if val:
            return self._visit_nodes(node.body, context)
        return self._visit_nodes(node.else_, context)

    def visit_Name(self, node, context):
        """Replace a read of a name with its known compile-time value."""
        if node.ctx == 'load':
            try:
                return nodes.Const(self._lookup(context, node.name),
                                   lineno=node.lineno)
            except KeyError:
                pass
        return node

    def visit_Assign(self, node, context):
        """Record constant assignments so later reads can be folded.

        A constant value is split into one ``Assign`` per target name, each
        with a ``Const`` value, and remembered in *context*.  Otherwise the
        assignment is kept with its optimized value, and its target names
        are marked unknown, hiding outer constants of the same name.
        """
        try:
            node.node = self.visit(node.node, context)
            pairs = self._unpack(node.target, node.node.as_const())
        except nodes.Impossible:
            self._shadow(node.target, context)
            return node

        lineno = node.lineno
        result = []
        for name, value in pairs:
            const_value = nodes.Const(value, lineno=lineno)
            result.append(nodes.Assign(name, const_value, lineno=lineno))
            context[name.name] = value
        return result

    def fold(self, node, context):
        """Do constant folding."""
        node = self.generic_visit(node, context)
        try:
            return nodes.Const(node.as_const(), lineno=node.lineno)
        except nodes.Impossible:
            return node
    visit_Add = visit_Sub = visit_Mul = visit_Div = visit_FloorDiv = \
    visit_Pow = visit_Mod = visit_And = visit_Or = visit_Pos = visit_Neg = \
    visit_Not = visit_Compare = fold

    def visit_Subscript(self, node, context):
        """Fold a subscript of a constant by a constant key."""
        if node.ctx != 'load':
            return self.generic_visit(node, context)
        try:
            item = self.visit(node.node, context).as_const()
            arg = self.visit(node.arg, context).as_const()
        except nodes.Impossible:
            return self.generic_visit(node, context)
        try:
            value = subscribe(item, arg, 'load')
        except Exception:
            # Leave a failing lookup to render time instead of failing here.
            return self.generic_visit(node, context)
        return nodes.Const(value, lineno=node.lineno)
198
199
def optimize(node, environment, context_hint=None):
    """Statically optimize the template tree *node* for *environment*.

    *context_hint* is an optional mapping of names to values already known
    at compile time.  It becomes the root scope of the optimizer's context,
    so references to those names can be replaced by constants.  Returns
    whatever the optimizer's visit produces for *node*.
    """
    root_context = ContextStack(context_hint)
    transformer = Optimizer(environment)
    return transformer.visit(node, root_context)