Automated merge with ssh://team@pocoo.org/jinja2-main
[jinja2.git] / jinja2 / optimizer.py
1 # -*- coding: utf-8 -*-
2 """
3     jinja2.optimizer
4     ~~~~~~~~~~~~~~~~
5
6     This module tries to optimize template trees by:
7
8         * eliminating constant nodes
9         * evaluating filters and macros on constant nodes
10         * unrolling loops over constant values
11         * replacing variables which are already known (because they don't
12           change often and you want to prerender a template) with constants
13
14     After the optimization you will get a new, simpler template which can
15     be saved again for later rendering. But even if you don't want to
16     prerender a template, this module might speed up your templates a bit
17     if you are using a lot of constants.
18
19     :copyright: Copyright 2008 by Christoph Hack.
20     :license: GNU GPL.
21 """
22 from copy import deepcopy
23 from jinja2 import nodes
24 from jinja2.visitor import NodeVisitor, NodeTransformer
25 from jinja2.runtime import subscribe, LoopContext
26
27
class ContextStack(object):
    """Simple compile time context implementation.

    ``stack[0]`` is the root scope: the context hint, or an empty dict when
    no hint is given.  Every later entry is a writable scope.  Lookups search
    from the innermost scope outwards; assignments always go to the
    innermost scope.
    """

    def __init__(self, initial=None):
        # Always keep the root scope separate from the first writable scope.
        # Previously, without a hint the stack was a single dict, so top-level
        # assignments were stored in ``stack[0]`` and leaked into every context
        # created by `blank`.  With a hint they did not, which made the
        # optimizer behave differently for a ``None`` hint and a ``{}`` hint.
        if initial is None:
            initial = {}
        self.stack = [initial, {}]

    def push(self):
        """Open a new innermost scope."""
        self.stack.append({})

    def pop(self):
        """Discard the innermost scope and every name assigned in it."""
        self.stack.pop()

    def get(self, key, default=None):
        """Return the value bound to `key`, or `default` if no scope has it."""
        try:
            return self[key]
        except KeyError:
            return default

    def __getitem__(self, key):
        # The innermost scope wins.
        for level in reversed(self.stack):
            if key in level:
                return level[key]
        raise KeyError(key)

    def __setitem__(self, key, value):
        self.stack[-1][key] = value

    def blank(self):
        """Return a new context with nothing but the root scope."""
        return ContextStack(self.stack[0])
60
61
class Optimizer(NodeTransformer):
    """Compile time optimizer for template trees.

    Every visitor receives the node plus a `ContextStack` holding the names
    whose values are known while optimizing.  Statement visitors may return a
    list of nodes that replaces the visited statement: an unrolled loop, the
    taken branch of a constant ``if``, or split-up assignments.
    """

    # Stored in the context for names whose value cannot be known at compile
    # time (assigned from a dynamic expression, or inside a branch that may or
    # may not run).  It shadows any constant an outer scope holds for the same
    # name, so such names are never replaced with a stale value.
    _unknown = object()

    def __init__(self, environment):
        self.environment = environment

    # -- helpers -------------------------------------------------------------

    def _target_names(self, target):
        """Return the names bound by an assignment or loop target.

        Only `Name` and (nested) `Tuple` targets are recognized; any other
        node yields no names.
        """
        if isinstance(target, nodes.Name):
            return [target.name]
        if isinstance(target, nodes.Tuple):
            names = []
            for item in target.items:
                names.extend(self._target_names(item))
            return names
        return []

    def _forget(self, target, context):
        """Mark every name bound by `target` as unknown in the current scope."""
        for name in self._target_names(target):
            context[name] = self._unknown

    def _assign(self, target, value, context):
        """Bind the constant `value` to `target` in `context`.

        Raises `nodes.Impossible` if a tuple target cannot be unpacked from
        `value` (not iterable, or wrong length).
        """
        if isinstance(target, nodes.Name):
            context[target.name] = value
        elif isinstance(target, nodes.Tuple):
            try:
                value = tuple(value)
            except TypeError:
                raise nodes.Impossible()
            if len(target.items) != len(value):
                raise nodes.Impossible()
            for name, val in zip(target.items, value):
                self._assign(name, val, context)
        else:
            raise AssertionError('unexpected assignable node')

    def _visit_nodes(self, body, context):
        """Visit each statement of `body` and return a flat list of results.

        Visitors in this class may replace a statement by a list of nodes, so
        those lists are spliced in here rather than nested.  ``None`` results
        are dropped.
        """
        result = []
        for child in body or ():
            rv = self.visit(child, context)
            if rv is None:
                continue
            if isinstance(rv, list):
                result.extend(rv)
            else:
                result.append(rv)
        return result

    def _visit_branch(self, body, context, hidden=()):
        """Optimize `body`, which may or may not be executed at runtime.

        The body is visited in a scratch scope in which the names in `hidden`
        are unknown, so outer constants for them are not substituted.
        Afterwards every other name the body assigned is marked unknown in the
        enclosing scope, because at runtime it may or may not have changed.
        Empty bodies (and ``None``) are returned unchanged.
        """
        if not body:
            return body
        context.push()
        try:
            for name in hidden:
                context[name] = self._unknown
            result = self._visit_nodes(body, context)
            assigned = [name for name in context.stack[-1]
                        if name not in hidden]
        finally:
            context.pop()
        for name in assigned:
            context[name] = self._unknown
        return result

    # -- visitors ------------------------------------------------------------

    def visit_Block(self, node, context):
        # Blocks only see the root scope of the context.
        return self.generic_visit(node, context.blank())

    def visit_Filter(self, node, context):
        """Try to evaluate filters if possible."""
        # XXX: don't optimize context dependent filters
        try:
            value = self.visit(node.node, context).as_const()
        except nodes.Impossible:
            return self.generic_visit(node, context)
        # Filter arguments are not passed yet, so calling a filter that has
        # any would silently compute a wrong value; leave those to runtime.
        # NOTE(review): attribute names assumed from the filter node
        # definition; getattr keeps this a no-op if they differ -- confirm.
        for filter_node in node.filters:
            if any(getattr(filter_node, attr, None)
                   for attr in ('args', 'kwargs', 'dyn_args', 'dyn_kwargs')):
                return self.generic_visit(node, context)
        # NOTE(review): application order kept as-is; confirm it matches how
        # the parser stores the filter chain.
        for filter_node in reversed(node.filters):
            try:
                func = self.environment.filters[filter_node.name]
                value = func(value)
            except Exception:
                # Unknown filters, or filters failing on this value, are left
                # for the runtime to report instead of crashing the optimizer.
                return self.generic_visit(node, context)
        try:
            return nodes.Const.from_untrusted(value, lineno=node.lineno)
        except nodes.Impossible:
            return self.generic_visit(node, context)

    def visit_For(self, node, context):
        """Unroll loops over constant, sized iterables.

        Loops that cannot be unrolled stay in the tree; only their bodies
        are optimized.
        """
        node.iter = self.visit(node.iter, context)
        try:
            iterable = node.iter.as_const()
            # we only unroll them if they have a length and are iterable
            iter(iterable)
            len(iterable)
        except (nodes.Impossible, TypeError):
            return self._visit_dynamic_loop(node, context)
        try:
            return self._unroll_loop(node, iterable, context)
        except nodes.Impossible:
            # The body was only visited on deep copies, so `node` is intact.
            return self._visit_dynamic_loop(node, context)

    def _visit_dynamic_loop(self, node, context):
        """Optimize the bodies of a loop that is kept as a loop."""
        # Inside the body the loop target and `loop` take runtime values, so
        # constants an outer scope holds for those names must not be used.
        hidden = self._target_names(node.target) + ['loop']
        node.body = self._visit_branch(node.body, context, hidden)
        node.else_ = self._visit_branch(node.else_, context)
        return node

    def _unroll_loop(self, node, iterable, context):
        """Return the statements of `node` repeated once per item.

        Raises `nodes.Impossible` if an item cannot be bound to the target.
        NOTE(review): the unrolled statements, including assignments, replace
        the loop in the enclosing body and therefore no longer run inside a
        loop scope at runtime -- confirm this is acceptable.
        """
        parent = context.get('loop')
        if parent is self._unknown:
            # Inside a loop that is not unrolled the parent loop is only known
            # at runtime.
            parent = None
        result = []
        iterated = False
        context.push()
        try:
            for loop, item in LoopContext(iterable, parent, True):
                context['loop'] = loop.make_static()
                self._assign(node.target, item, context)
                # Deep copies, because each iteration rewrites its own nodes.
                result.extend(self._visit_nodes(deepcopy(node.body), context))
                iterated = True
            if not iterated and node.else_:
                result.extend(self._visit_nodes(deepcopy(node.else_), context))
        finally:
            context.pop()
        return result

    def visit_If(self, node, context):
        """Replace an ``if`` with a constant test by its optimized branch."""
        node.test = self.visit(node.test, context)
        try:
            val = node.test.as_const()
        except nodes.Impossible:
            # Either branch may run, so each one is visited in isolation and
            # names they assign become unknown afterwards.
            node.body = self._visit_branch(node.body, context)
            node.else_ = self._visit_branch(node.else_, context)
            return node
        # The taken branch definitely runs, so it is optimized in place.
        return self._visit_nodes(node.body if val else node.else_, context)

    def visit_Name(self, node, context):
        """Replace loads of names with a known constant value by `Const`."""
        if node.ctx == 'load':
            value = context.get(node.name, self._unknown)
            if value is not self._unknown:
                try:
                    return nodes.Const.from_untrusted(value,
                                                      lineno=node.lineno)
                except nodes.Impossible:
                    pass
        return node

    def visit_Assign(self, node, context):
        """Record constant assignments and split tuple unpacking.

        On success the assignment is replaced by one `Assign` of a `Const` per
        bound name, and the values are recorded in `context`.  Otherwise the
        node is kept and the target names are marked unknown, so later loads
        are not replaced with a stale constant.
        """
        node.target = self.generic_visit(node.target, context)
        # `visit` (not `generic_visit`) so the value node itself is optimized
        # too, e.g. ``{% set y = x %}`` with a known ``x``.
        node.node = self.visit(node.node, context)
        try:
            value = node.node.as_const()
        except nodes.Impossible:
            self._forget(node.target, context)
            return node

        result = []
        lineno = node.lineno

        def walk(target, value):
            if isinstance(target, nodes.Name):
                const = nodes.Const.from_untrusted(value, lineno=lineno)
                result.append(nodes.Assign(target, const, lineno=lineno))
                context[target.name] = value
            elif isinstance(target, nodes.Tuple):
                try:
                    value = tuple(value)
                except TypeError:
                    raise nodes.Impossible()
                if len(target.items) != len(value):
                    raise nodes.Impossible()
                for name, val in zip(target.items, value):
                    walk(name, val)
            else:
                raise AssertionError('unexpected assignable node')

        try:
            walk(node.target, value)
        except nodes.Impossible:
            # `walk` may already have recorded some names; none are reliable.
            self._forget(node.target, context)
            return node
        return result

    def fold(self, node, context):
        """Do constant folding."""
        node = self.generic_visit(node, context)
        try:
            return nodes.Const.from_untrusted(node.as_const(),
                                              lineno=node.lineno)
        except nodes.Impossible:
            return node
    visit_Add = visit_Sub = visit_Mul = visit_Div = visit_FloorDiv = \
    visit_Pow = visit_Mod = visit_And = visit_Or = visit_Pos = visit_Neg = \
    visit_Not = visit_Compare = visit_Subscript = visit_Call = fold
    del fold
195
196
def optimize(node, environment, context_hint=None):
    """Return an optimized version of the template tree `node`.

    `context_hint` optionally maps variable names to values known at compile
    time.  It is used for static optimization: constant substitution,
    folding and loop unrolling.  `environment` provides the filters used when
    folding filter expressions.
    """
    initial_context = ContextStack(context_hint)
    return Optimizer(environment).visit(node, initial_context)