fixed a few bugs from the unittests
[jinja2.git] / jinja2 / utils.py
1 # -*- coding: utf-8 -*-
2 """
3     jinja2.utils
4     ~~~~~~~~~~~~
5
6     Utility functions.
7
8     :copyright: 2008 by Armin Ronacher.
9     :license: BSD, see LICENSE for more details.
10 """
11 import re
12 import string
13 from collections import deque
14 from copy import deepcopy
15 from functools import update_wrapper
16 from itertools import imap
17
18
19 _word_split_re = re.compile(r'(\s+)')
20 _punctuation_re = re.compile(
21     '^(?P<lead>(?:%s)*)(?P<middle>.*?)(?P<trail>(?:%s)*)$' % (
22         '|'.join(imap(re.escape, ('(', '<', '&lt;'))),
23         '|'.join(imap(re.escape, ('.', ',', ')', '>', '\n', '&gt;')))
24     )
25 )
26 _simple_email_re = re.compile(r'^\S+@[a-zA-Z0-9._-]+\.[a-zA-Z0-9._-]+$')
27
28
29 def soft_unicode(s):
30     """Make a string unicode if it isn't already.  That way a markup
31     string is not converted back to unicode.
32     """
33     if not isinstance(s, unicode):
34         s = unicode(s)
35     return s
36
37
38 def pformat(obj, verbose=False):
39     """Prettyprint an object.  Either use the `pretty` library or the
40     builtin `pprint`.
41     """
42     try:
43         from pretty import pretty
44         return pretty(obj, verbose=verbose)
45     except ImportError:
46         from pprint import pformat
47         return pformat(obj)
48
49
50 def urlize(text, trim_url_limit=None, nofollow=False):
51     """Converts any URLs in text into clickable links. Works on http://,
52     https:// and www. links. Links can have trailing punctuation (periods,
53     commas, close-parens) and leading punctuation (opening parens) and
54     it'll still do the right thing.
55
56     If trim_url_limit is not None, the URLs in link text will be limited
57     to trim_url_limit characters.
58
59     If nofollow is True, the URLs in link text will get a rel="nofollow"
60     attribute.
61     """
62     trim_url = lambda x, limit=trim_url_limit: limit is not None \
63                          and (x[:limit] + (len(x) >=limit and '...'
64                          or '')) or x
65     words = _word_split_re.split(text)
66     nofollow_attr = nofollow and ' rel="nofollow"' or ''
67     for i, word in enumerate(words):
68         match = _punctuation_re.match(word)
69         if match:
70             lead, middle, trail = match.groups()
71             if middle.startswith('www.') or (
72                 '@' not in middle and
73                 not middle.startswith('http://') and
74                 len(middle) > 0 and
75                 middle[0] in string.letters + string.digits and (
76                     middle.endswith('.org') or
77                     middle.endswith('.net') or
78                     middle.endswith('.com')
79                 )):
80                 middle = '<a href="http://%s"%s>%s</a>' % (middle,
81                     nofollow_attr, trim_url(middle))
82             if middle.startswith('http://') or \
83                middle.startswith('https://'):
84                 middle = '<a href="%s"%s>%s</a>' % (middle,
85                     nofollow_attr, trim_url(middle))
86             if '@' in middle and not middle.startswith('www.') and \
87                not ':' in middle and _simple_email_re.match(middle):
88                 middle = '<a href="mailto:%s">%s</a>' % (middle, middle)
89             if lead + middle + trail != word:
90                 words[i] = lead + middle + trail
91     return u''.join(words)
92
93
94 class Markup(unicode):
95     """Marks a string as being safe for inclusion in HTML/XML output without
96     needing to be escaped.  This implements the `__html__` interface a couple
97     of frameworks and web applications use.
98
99     The `escape` function returns markup objects so that double escaping can't
100     happen.  If you want to use autoescaping in Jinja just set the finalizer
101     of the environment to `escape`.
102     """
103
104     __slots__ = ()
105
106     def __html__(self):
107         return self
108
109     def __add__(self, other):
110         if hasattr(other, '__html__') or isinstance(other, basestring):
111             return self.__class__(unicode(self) + unicode(escape(other)))
112         return NotImplemented
113
114     def __radd__(self, other):
115         if hasattr(other, '__html__') or isinstance(other, basestring):
116             return self.__class__(unicode(escape(other)) + unicode(self))
117         return NotImplemented
118
119     def __mul__(self, num):
120         if not isinstance(num, (int, long)):
121             return NotImplemented
122         return self.__class__(unicode.__mul__(self, num))
123     __rmul__ = __mul__
124
125     def __mod__(self, arg):
126         if isinstance(arg, tuple):
127             arg = tuple(imap(_MarkupEscapeHelper, arg))
128         else:
129             arg = _MarkupEscapeHelper(arg)
130         return self.__class__(unicode.__mod__(self, arg))
131
132     def __repr__(self):
133         return '%s(%s)' % (
134             self.__class__.__name__,
135             unicode.__repr__(self)
136         )
137
138     def join(self, seq):
139         return self.__class__(unicode.join(self, imap(escape, seq)))
140
141     def split(self, *args, **kwargs):
142         return map(self.__class__, unicode.split(self, *args, **kwargs))
143
144     def rsplit(self, *args, **kwargs):
145         return map(self.__class__, unicode.rsplit(self, *args, **kwargs))
146
147     def splitlines(self, *args, **kwargs):
148         return map(self.__class__, unicode.splitlines(self, *args, **kwargs))
149
150     def make_wrapper(name):
151         orig = getattr(unicode, name)
152         def func(self, *args, **kwargs):
153             args = list(args)
154             for idx, arg in enumerate(args):
155                 if hasattr(arg, '__html__') or isinstance(arg, basestring):
156                     args[idx] = escape(arg)
157             for name, arg in kwargs.iteritems():
158                 if hasattr(arg, '__html__') or isinstance(arg, basestring):
159                     kwargs[name] = escape(arg)
160             return self.__class__(orig(self, *args, **kwargs))
161         return update_wrapper(func, orig, ('__name__', '__doc__'))
162     for method in '__getitem__', '__getslice__', 'capitalize', \
163                   'title', 'lower', 'upper', 'replace', 'ljust', \
164                   'rjust', 'lstrip', 'rstrip', 'partition', 'center', \
165                   'strip', 'translate', 'expandtabs', 'rpartition', \
166                   'swapcase', 'zfill':
167         locals()[method] = make_wrapper(method)
168     del method, make_wrapper
169
170
171 class _MarkupEscapeHelper(object):
172     """Helper for Markup.__mod__"""
173
174     def __init__(self, obj):
175         self.obj = obj
176
177     __getitem__ = lambda s, x: _MarkupEscapeHelper(s.obj[x])
178     __unicode__ = lambda s: unicode(escape(s.obj))
179     __str__ = lambda s: str(escape(s.obj))
180     __repr__ = lambda s: str(repr(escape(s.obj)))
181     __int__ = lambda s: int(s.obj)
182     __float__ = lambda s: float(s.obj)
183
184
185 class LRUCache(object):
186     """A simple LRU Cache implementation."""
187     # this is fast for small capacities (something around 200) but doesn't
188     # scale.  But as long as it's only used for the database connections in
189     # a non request fallback it's fine.
190
191     def __init__(self, capacity):
192         self.capacity = capacity
193         self._mapping = {}
194         self._queue = deque()
195
196         # alias all queue methods for faster lookup
197         self._popleft = self._queue.popleft
198         self._pop = self._queue.pop
199         if hasattr(self._queue, 'remove'):
200             self._remove = self._queue.remove
201         self._append = self._queue.append
202
203     def _remove(self, obj):
204         """Python 2.4 compatibility."""
205         for idx, item in enumerate(self._queue):
206             if item == obj:
207                 del self._queue[idx]
208                 break
209
210     def copy(self):
211         """Return an shallow copy of the instance."""
212         rv = self.__class__(self.capacity)
213         rv._mapping.update(self._mapping)
214         rv._queue = deque(self._queue)
215         return rv
216
217     def get(self, key, default=None):
218         """Return an item from the cache dict or `default`"""
219         if key in self:
220             return self[key]
221         return default
222
223     def setdefault(self, key, default=None):
224         """Set `default` if the key is not in the cache otherwise
225         leave unchanged. Return the value of this key.
226         """
227         if key in self:
228             return self[key]
229         self[key] = default
230         return default
231
232     def clear(self):
233         """Clear the cache."""
234         self._mapping.clear()
235         self._queue.clear()
236
237     def __contains__(self, key):
238         """Check if a key exists in this cache."""
239         return key in self._mapping
240
241     def __len__(self):
242         """Return the current size of the cache."""
243         return len(self._mapping)
244
245     def __repr__(self):
246         return '<%s %r>' % (
247             self.__class__.__name__,
248             self._mapping
249         )
250
251     def __getitem__(self, key):
252         """Get an item from the cache. Moves the item up so that it has the
253         highest priority then.
254
255         Raise an `KeyError` if it does not exist.
256         """
257         rv = self._mapping[key]
258         if self._queue[-1] != key:
259             self._remove(key)
260             self._append(key)
261         return rv
262
263     def __setitem__(self, key, value):
264         """Sets the value for an item. Moves the item up so that it
265         has the highest priority then.
266         """
267         if key in self._mapping:
268             self._remove(key)
269         elif len(self._mapping) == self.capacity:
270             del self._mapping[self._popleft()]
271         self._append(key)
272         self._mapping[key] = value
273
274     def __delitem__(self, key):
275         """Remove an item from the cache dict.
276         Raise an `KeyError` if it does not exist.
277         """
278         del self._mapping[key]
279         self._remove(key)
280
281     def __iter__(self):
282         """Iterate over all values in the cache dict, ordered by
283         the most recent usage.
284         """
285         return reversed(self._queue)
286
287     def __reversed__(self):
288         """Iterate over the values in the cache dict, oldest items
289         coming first.
290         """
291         return iter(self._queue)
292
293     __copy__ = copy
294
295
296 # we have to import it down here as the speedups module imports the
297 # markup type which is define above.
298 try:
299     from jinja2._speedups import escape
300 except ImportError:
301     def escape(obj):
302         """Convert the characters &, <, >, and " in string s to HTML-safe
303         sequences. Use this if you need to display text that might contain
304         such characters in HTML.
305         """
306         if hasattr(obj, '__html__'):
307             return obj.__html__()
308         return Markup(unicode(obj)
309             .replace('&', '&amp;')
310             .replace('>', '&gt;')
311             .replace('<', '&lt;')
312             .replace('"', '&quot;')
313         )