inheritance uses a less awkward hack for contexts now and subclassing templates is...
[jinja2.git] / jinja2 / utils.py
1 # -*- coding: utf-8 -*-
2 """
3     jinja2.utils
4     ~~~~~~~~~~~~
5
6     Utility functions.
7
8     :copyright: 2008 by Armin Ronacher.
9     :license: BSD, see LICENSE for more details.
10 """
11 import re
12 import string
13 from collections import deque
14 from copy import deepcopy
15 from itertools import imap
16
17
18 _word_split_re = re.compile(r'(\s+)')
19 _punctuation_re = re.compile(
20     '^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % (
21         '|'.join(imap(re.escape, ('(', '<', '&lt;'))),
22         '|'.join(imap(re.escape, ('.', ',', ')', '>', '\n', '&gt;')))
23     )
24 )
25 _simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
26
27
28 def contextfunction(f):
29     """Mark a callable as context callable.  A context callable is passed
30     the active context as first argument.
31     """
32     f.contextfunction = True
33     return f
34
35
36 def environmentfunction(f):
37     """Mark a callable as environment callable.  An environment callable is
38     passed the current environment as first argument.
39     """
40     f.environmentfunction = True
41     return f
42
43
44 def import_string(import_name, silent=False):
45     """Imports an object based on a string.  This use useful if you want to
46     use import paths as endpoints or something similar.  An import path can
47     be specified either in dotted notation (``xml.sax.saxutils.escape``)
48     or with a colon as object delimiter (``xml.sax.saxutils:escape``).
49
50     If the `silent` is True the return value will be `None` if the import
51     fails.
52
53     :return: imported object
54     """
55     try:
56         if ':' in import_name:
57             module, obj = import_name.split(':', 1)
58         elif '.' in import_name:
59             items = import_name.split('.')
60             module = '.'.join(items[:-1])
61             obj = items[-1]
62         else:
63             return __import__(import_name)
64         return getattr(__import__(module, None, None, [obj]), obj)
65     except (ImportError, AttributeError):
66         if not silent:
67             raise
68
69
70 def pformat(obj, verbose=False):
71     """Prettyprint an object.  Either use the `pretty` library or the
72     builtin `pprint`.
73     """
74     try:
75         from pretty import pretty
76         return pretty(obj, verbose=verbose)
77     except ImportError:
78         from pprint import pformat
79         return pformat(obj)
80
81
82 def urlize(text, trim_url_limit=None, nofollow=False):
83     """Converts any URLs in text into clickable links. Works on http://,
84     https:// and www. links. Links can have trailing punctuation (periods,
85     commas, close-parens) and leading punctuation (opening parens) and
86     it'll still do the right thing.
87
88     If trim_url_limit is not None, the URLs in link text will be limited
89     to trim_url_limit characters.
90
91     If nofollow is True, the URLs in link text will get a rel="nofollow"
92     attribute.
93     """
94     trim_url = lambda x, limit=trim_url_limit: limit is not None \
95                          and (x[:limit] + (len(x) >=limit and '...'
96                          or '')) or x
97     words = _word_split_re.split(text)
98     nofollow_attr = nofollow and ' rel="nofollow"' or ''
99     for i, word in enumerate(words):
100         match = _punctuation_re.match(word)
101         if match:
102             lead, middle, trail = match.groups()
103             if middle.startswith('www.') or (
104                 '@' not in middle and
105                 not middle.startswith('http://') and
106                 len(middle) > 0 and
107                 middle[0] in string.letters + string.digits and (
108                     middle.endswith('.org') or
109                     middle.endswith('.net') or
110                     middle.endswith('.com')
111                 )):
112                 middle = '<a href="http://%s"%s>%s</a>' % (middle,
113                     nofollow_attr, trim_url(middle))
114             if middle.startswith('http://') or \
115                middle.startswith('https://'):
116                 middle = '<a href="%s"%s>%s</a>' % (middle,
117                     nofollow_attr, trim_url(middle))
118             if '@' in middle and not middle.startswith('www.') and \
119                not ':' in middle and _simple_email_re.match(middle):
120                 middle = '<a href="mailto:%s">%s</a>' % (middle, middle)
121             if lead + middle + trail != word:
122                 words[i] = lead + middle + trail
123     return u''.join(words)
124
125
126 def generate_lorem_ipsum(n=5, html=True, min=20, max=100):
127     """Generate some lorem impsum for the template."""
128     from jinja2.constants import LOREM_IPSUM_WORDS
129     from random import choice, random, randrange
130     words = LOREM_IPSUM_WORDS.split()
131     result = []
132
133     for _ in xrange(n):
134         next_capitalized = True
135         last_comma = last_fullstop = 0
136         word = None
137         last = None
138         p = []
139
140         # each paragraph contains out of 20 to 100 words.
141         for idx, _ in enumerate(xrange(randrange(min, max))):
142             while True:
143                 word = choice(words)
144                 if word != last:
145                     last = word
146                     break
147             if next_capitalized:
148                 word = word.capitalize()
149                 next_capitalized = False
150             # add commas
151             if idx - randrange(3, 8) > last_comma:
152                 last_comma = idx
153                 last_fullstop += 2
154                 word += ','
155             # add end of sentences
156             if idx - randrange(10, 20) > last_fullstop:
157                 last_comma = last_fullstop = idx
158                 word += '.'
159                 next_capitalized = True
160             p.append(word)
161
162         # ensure that the paragraph ends with a dot.
163         p = u' '.join(p)
164         if p.endswith(','):
165             p = p[:-1] + '.'
166         elif not p.endswith('.'):
167             p += '.'
168         result.append(p)
169
170     if not html:
171         return u'\n\n'.join(result)
172     return Markup(u'\n'.join(u'<p>%s</p>' % escape(x) for x in result))
173
174
175 class Markup(unicode):
176     """Marks a string as being safe for inclusion in HTML/XML output without
177     needing to be escaped.  This implements the `__html__` interface a couple
178     of frameworks and web applications use.
179
180     The `escape` function returns markup objects so that double escaping can't
181     happen.  If you want to use autoescaping in Jinja just set the finalizer
182     of the environment to `escape`.
183     """
184     __slots__ = ()
185
186     def __html__(self):
187         return self
188
189     def __add__(self, other):
190         if hasattr(other, '__html__') or isinstance(other, basestring):
191             return self.__class__(unicode(self) + unicode(escape(other)))
192         return NotImplemented
193
194     def __radd__(self, other):
195         if hasattr(other, '__html__') or isinstance(other, basestring):
196             return self.__class__(unicode(escape(other)) + unicode(self))
197         return NotImplemented
198
199     def __mul__(self, num):
200         if not isinstance(num, (int, long)):
201             return NotImplemented
202         return self.__class__(unicode.__mul__(self, num))
203     __rmul__ = __mul__
204
205     def __mod__(self, arg):
206         if isinstance(arg, tuple):
207             arg = tuple(imap(_MarkupEscapeHelper, arg))
208         else:
209             arg = _MarkupEscapeHelper(arg)
210         return self.__class__(unicode.__mod__(self, arg))
211
212     def __repr__(self):
213         return '%s(%s)' % (
214             self.__class__.__name__,
215             unicode.__repr__(self)
216         )
217
218     def join(self, seq):
219         return self.__class__(unicode.join(self, imap(escape, seq)))
220     join.__doc__ = unicode.join.__doc__
221
222     def split(self, *args, **kwargs):
223         return map(self.__class__, unicode.split(self, *args, **kwargs))
224     split.__doc__ = unicode.split.__doc__
225
226     def rsplit(self, *args, **kwargs):
227         return map(self.__class__, unicode.rsplit(self, *args, **kwargs))
228     rsplit.__doc__ = unicode.rsplit.__doc__
229
230     def splitlines(self, *args, **kwargs):
231         return map(self.__class__, unicode.splitlines(self, *args, **kwargs))
232     splitlines.__doc__ = unicode.splitlines.__doc__
233
234     def make_wrapper(name):
235         orig = getattr(unicode, name)
236         def func(self, *args, **kwargs):
237             args = list(args)
238             for idx, arg in enumerate(args):
239                 if hasattr(arg, '__html__') or isinstance(arg, basestring):
240                     args[idx] = escape(arg)
241             for name, arg in kwargs.iteritems():
242                 if hasattr(arg, '__html__') or isinstance(arg, basestring):
243                     kwargs[name] = escape(arg)
244             return self.__class__(orig(self, *args, **kwargs))
245         func.__name__ = orig.__name__
246         func.__doc__ = orig.__doc__
247         return func
248     for method in '__getitem__', '__getslice__', 'capitalize', \
249                   'title', 'lower', 'upper', 'replace', 'ljust', \
250                   'rjust', 'lstrip', 'rstrip', 'partition', 'center', \
251                   'strip', 'translate', 'expandtabs', 'rpartition', \
252                   'swapcase', 'zfill':
253         locals()[method] = make_wrapper(method)
254     del method, make_wrapper
255
256
257 class _MarkupEscapeHelper(object):
258     """Helper for Markup.__mod__"""
259
260     def __init__(self, obj):
261         self.obj = obj
262
263     __getitem__ = lambda s, x: _MarkupEscapeHelper(s.obj[x])
264     __unicode__ = lambda s: unicode(escape(s.obj))
265     __str__ = lambda s: str(escape(s.obj))
266     __repr__ = lambda s: str(repr(escape(s.obj)))
267     __int__ = lambda s: int(s.obj)
268     __float__ = lambda s: float(s.obj)
269
270
271 class LRUCache(object):
272     """A simple LRU Cache implementation."""
273     # this is fast for small capacities (something around 200) but doesn't
274     # scale.  But as long as it's only used for the database connections in
275     # a non request fallback it's fine.
276
277     def __init__(self, capacity):
278         self.capacity = capacity
279         self._mapping = {}
280         self._queue = deque()
281
282         # alias all queue methods for faster lookup
283         self._popleft = self._queue.popleft
284         self._pop = self._queue.pop
285         if hasattr(self._queue, 'remove'):
286             self._remove = self._queue.remove
287         self._append = self._queue.append
288
289     def _remove(self, obj):
290         """Python 2.4 compatibility."""
291         for idx, item in enumerate(self._queue):
292             if item == obj:
293                 del self._queue[idx]
294                 break
295
296     def copy(self):
297         """Return an shallow copy of the instance."""
298         rv = self.__class__(self.capacity)
299         rv._mapping.update(self._mapping)
300         rv._queue = deque(self._queue)
301         return rv
302
303     def get(self, key, default=None):
304         """Return an item from the cache dict or `default`"""
305         if key in self:
306             return self[key]
307         return default
308
309     def setdefault(self, key, default=None):
310         """Set `default` if the key is not in the cache otherwise
311         leave unchanged. Return the value of this key.
312         """
313         if key in self:
314             return self[key]
315         self[key] = default
316         return default
317
318     def clear(self):
319         """Clear the cache."""
320         self._mapping.clear()
321         self._queue.clear()
322
323     def __contains__(self, key):
324         """Check if a key exists in this cache."""
325         return key in self._mapping
326
327     def __len__(self):
328         """Return the current size of the cache."""
329         return len(self._mapping)
330
331     def __repr__(self):
332         return '<%s %r>' % (
333             self.__class__.__name__,
334             self._mapping
335         )
336
337     def __getitem__(self, key):
338         """Get an item from the cache. Moves the item up so that it has the
339         highest priority then.
340
341         Raise an `KeyError` if it does not exist.
342         """
343         rv = self._mapping[key]
344         if self._queue[-1] != key:
345             self._remove(key)
346             self._append(key)
347         return rv
348
349     def __setitem__(self, key, value):
350         """Sets the value for an item. Moves the item up so that it
351         has the highest priority then.
352         """
353         if key in self._mapping:
354             self._remove(key)
355         elif len(self._mapping) == self.capacity:
356             del self._mapping[self._popleft()]
357         self._append(key)
358         self._mapping[key] = value
359
360     def __delitem__(self, key):
361         """Remove an item from the cache dict.
362         Raise an `KeyError` if it does not exist.
363         """
364         del self._mapping[key]
365         self._remove(key)
366
367     def __iter__(self):
368         """Iterate over all values in the cache dict, ordered by
369         the most recent usage.
370         """
371         return reversed(self._queue)
372
373     def __reversed__(self):
374         """Iterate over the values in the cache dict, oldest items
375         coming first.
376         """
377         return iter(self._queue)
378
379     __copy__ = copy
380
381
382 # we have to import it down here as the speedups module imports the
383 # markup type which is define above.
384 try:
385     from jinja2._speedups import escape, soft_unicode
386 except ImportError:
387     def escape(obj):
388         """Convert the characters &, <, >, and " in string s to HTML-safe
389         sequences. Use this if you need to display text that might contain
390         such characters in HTML.
391         """
392         if hasattr(obj, '__html__'):
393             return obj.__html__()
394         return Markup(unicode(obj)
395             .replace('&', '&amp;')
396             .replace('>', '&gt;')
397             .replace('<', '&lt;')
398             .replace('"', '&quot;')
399         )
400
401     def soft_unicode(s):
402         """Make a string unicode if it isn't already.  That way a markup
403         string is not converted back to unicode.
404         """
405         if not isinstance(s, unicode):
406             s = unicode(s)
407         return s
408
409
410 # partials
411 try:
412     from functools import partial
413 except ImportError:
414     class partial(object):
415         def __init__(self, _func, *args, **kwargs):
416             self._func = func
417             self._args = args
418             self._kwargs = kwargs
419         def __call__(self, *args, **kwargs):
420             kwargs.update(self._kwargs)
421             return self._func(*(self._args + args), **kwargs)