live up to @mitsuhiko's ridiculous expectations
[jinja2.git] / jinja2 / filters.py
index e8ab29681303ec772b21a722ea941e500d6ec9c5..69e67e2913974f6fe7a3f5d088d2c266bcc8de21 100644 (file)
@@ -5,52 +5,91 @@
 
     Bundled jinja filters.
 
-    :copyright: 2008 by Armin Ronacher, Christoph Hack.
+    :copyright: (c) 2010 by the Jinja Team.
     :license: BSD, see LICENSE for more details.
 """
 import re
 import math
-import textwrap
+import urllib
 from random import choice
 from operator import itemgetter
 from itertools import imap, groupby
 from jinja2.utils import Markup, escape, pformat, urlize, soft_unicode
 from jinja2.runtime import Undefined
-from jinja2.exceptions import FilterArgumentError, SecurityError
+from jinja2.exceptions import FilterArgumentError
 
 
-_word_re = re.compile(r'\w+')
+_word_re = re.compile(r'\w+(?u)')
 
 
 def contextfilter(f):
     """Decorator for marking context dependent filters. The current
     :class:`Context` will be passed as first argument.
     """
-    if getattr(f, 'environmentfilter', False):
-        raise TypeError('filter already marked as environment filter')
     f.contextfilter = True
     return f
 
 
+def evalcontextfilter(f):
+    """Decorator for marking eval-context dependent filters.  An eval
+    context object is passed as first argument.  For more information
+    about the eval context, see :ref:`eval-context`.
+
+    .. versionadded:: 2.4
+    """
+    f.evalcontextfilter = True
+    return f
+
+
 def environmentfilter(f):
     """Decorator for marking evironment dependent filters.  The current
     :class:`Environment` is passed to the filter as first argument.
     """
-    if getattr(f, 'contextfilter', False):
-        raise TypeError('filter already marked as context filter')
     f.environmentfilter = True
     return f
 
 
+def make_attrgetter(environment, attribute):
+    """Returns a callable that looks up the given attribute from a
+    passed object with the rules of the environment.  Dots are allowed
+    to access attributes of attributes.
+    """
+    if not isinstance(attribute, basestring) or '.' not in attribute:
+        return lambda x: environment.getitem(x, attribute)
+    attribute = attribute.split('.')
+    def attrgetter(item):
+        for part in attribute:
+            item = environment.getitem(item, part)
+        return item
+    return attrgetter
+
+
 def do_forceescape(value):
     """Enforce HTML escaping.  This will probably double escape variables."""
     if hasattr(value, '__html__'):
         value = value.__html__()
     return escape(unicode(value))
 
+def do_urlescape(value):
+    """Escape strings for use in URLs (uses UTF-8 encoding)."""
+    def utf8(o):
+        return unicode(o).encode('utf8')
+    
+    if isinstance(value, basestring):
+        return urllib.quote(utf8(value))
+    
+    if hasattr(value, 'items'):
+        # convert dictionaries to list of 2-tuples
+        value = value.items()
+    
+    if hasattr(value, 'next'):
+        # convert generators to list
+        value = list(value)
+    
+    return urllib.urlencode([(utf8(k), utf8(v)) for (k, v) in value])
 
-@environmentfilter
-def do_replace(environment, s, old, new, count=None):
+@evalcontextfilter
+def do_replace(eval_ctx, s, old, new, count=None):
     """Return a copy of the value with all occurrences of a substring
     replaced with a new one. The first argument is the substring
     that should be replaced, the second is the replacement string.
@@ -67,7 +106,7 @@ def do_replace(environment, s, old, new, count=None):
     """
     if count is None:
         count = -1
-    if not environment.autoescape:
+    if not eval_ctx.autoescape:
         return unicode(s).replace(unicode(old), unicode(new), count)
     if hasattr(old, '__html__') or hasattr(new, '__html__') and \
        not hasattr(s, '__html__'):
@@ -87,8 +126,8 @@ def do_lower(s):
     return soft_unicode(s).lower()
 
 
-@environmentfilter
-def do_xmlattr(_environment, d, autospace=True):
+@evalcontextfilter
+def do_xmlattr(_eval_ctx, d, autospace=True):
     """Create an SGML/XML attribute string based on the items in a dict.
     All values that are neither `none` nor `undefined` are automatically
     escaped:
@@ -118,7 +157,7 @@ def do_xmlattr(_environment, d, autospace=True):
     )
     if autospace and rv:
         rv = u' ' + rv
-    if _environment.autoescape:
+    if _eval_ctx.autoescape:
         rv = Markup(rv)
     return rv
 
@@ -138,7 +177,7 @@ def do_title(s):
 
 
 def do_dictsort(value, case_sensitive=False, by='key'):
-    """ Sort a dict and yield (key, value) pairs. Because python dicts are
+    """Sort a dict and yield (key, value) pairs. Because python dicts are
     unsorted you may want to use this function to order them by either
     key or value:
 
@@ -163,15 +202,55 @@ def do_dictsort(value, case_sensitive=False, by='key'):
                                   '"key" or "value"')
     def sort_func(item):
         value = item[pos]
-        if isinstance(value, basestring):
-            value = unicode(value)
-            if not case_sensitive:
-                value = value.lower()
+        if isinstance(value, basestring) and not case_sensitive:
+            value = value.lower()
         return value
 
     return sorted(value.items(), key=sort_func)
 
 
+@environmentfilter
+def do_sort(environment, value, reverse=False, case_sensitive=False,
+            attribute=None):
+    """Sort an iterable.  Per default it sorts ascending, if you pass it
+    true as first argument it will reverse the sorting.
+
+    If the iterable is made of strings the third parameter can be used to
+    control the case sensitiveness of the comparison which is disabled by
+    default.
+
+    .. sourcecode:: jinja
+
+        {% for item in iterable|sort %}
+            ...
+        {% endfor %}
+
+    It is also possible to sort by an attribute (for example to sort
+    by the date of an object) by specifying the `attribute` parameter:
+
+    .. sourcecode:: jinja
+
+        {% for item in iterable|sort(attribute='date') %}
+            ...
+        {% endfor %}
+
+    .. versionchanged:: 2.6
+       The `attribute` parameter was added.
+    """
+    if not case_sensitive:
+        def sort_func(item):
+            if isinstance(item, basestring):
+                item = item.lower()
+            return item
+    else:
+        sort_func = None
+    if attribute is not None:
+        getter = make_attrgetter(environment, attribute)
+        def sort_func(item, processor=sort_func or (lambda x: x)):
+            return processor(getter(item))
+    return sorted(value, key=sort_func, reverse=reverse)
+
+
 def do_default(value, default_value=u'', boolean=False):
     """If the value is undefined it will return the passed default value,
     otherwise the value of the variable:
@@ -194,8 +273,8 @@ def do_default(value, default_value=u'', boolean=False):
     return value
 
 
-@environmentfilter
-def do_join(environment, value, d=u''):
+@evalcontextfilter
+def do_join(eval_ctx, value, d=u'', attribute=None):
     """Return a string which is the concatenation of the strings in the
     sequence. The separator between elements is an empty string per
     default, you can define it with the optional parameter:
@@ -207,9 +286,21 @@ def do_join(environment, value, d=u''):
 
         {{ [1, 2, 3]|join }}
             -> 123
+
+    It is also possible to join certain attributes of an object:
+
+    .. sourcecode:: jinja
+
+        {{ users|join(', ', attribute='username') }}
+
+    .. versionadded:: 2.6
+       The `attribute` parameter was added.
     """
+    if attribute is not None:
+        value = imap(make_attrgetter(eval_ctx.environment, attribute), value)
+
     # no automatic escaping?  joining is a lot eaiser then
-    if not environment.autoescape:
+    if not eval_ctx.autoescape:
         return unicode(d).join(imap(unicode, value))
 
     # if the delimiter doesn't have an html representation we check
@@ -264,23 +355,34 @@ def do_random(environment, seq):
         return environment.undefined('No random item, sequence was empty.')
 
 
-def do_filesizeformat(value):
-    """Format the value like a 'human-readable' file size (i.e. 13 KB,
-    4.1 MB, 102 bytes, etc).
+def do_filesizeformat(value, binary=False):
+    """Format the value like a 'human-readable' file size (i.e. 13 kB,
+    4.1 MB, 102 Bytes, etc).  Per default decimal prefixes are used (Mega,
+    Giga, etc.), if the second parameter is set to `True` the binary
+    prefixes are used (Mebi, Gibi).
     """
-    # fail silently
-    try:
-        bytes = float(value)
-    except TypeError:
-        bytes = 0
-
-    if bytes < 1024:
-        return "%d Byte%s" % (bytes, bytes != 1 and 's' or '')
-    elif bytes < 1024 * 1024:
-        return "%.1f KB" % (bytes / 1024)
-    elif bytes < 1024 * 1024 * 1024:
-        return "%.1f MB" % (bytes / (1024 * 1024))
-    return "%.1f GB" % (bytes / (1024 * 1024 * 1024))
+    bytes = float(value)
+    base = binary and 1024 or 1000
+    prefixes = [
+        (binary and 'KiB' or 'kB'),
+        (binary and 'MiB' or 'MB'),
+        (binary and 'GiB' or 'GB'),
+        (binary and 'TiB' or 'TB'),
+        (binary and 'PiB' or 'PB'),
+        (binary and 'EiB' or 'EB'),
+        (binary and 'ZiB' or 'ZB'),
+        (binary and 'YiB' or 'YB')
+    ]
+    if bytes == 1:
+        return '1 Byte'
+    elif bytes < base:
+        return '%d Bytes' % bytes
+    else:
+        for i, prefix in enumerate(prefixes):
+            unit = base ** (i + 2)
+            if bytes < unit:
+                return '%.1f %s' % ((base * bytes / unit), prefix)
+        return '%.1f %s' % ((base * bytes / unit), prefix)
 
 
 def do_pprint(value, verbose=False):
@@ -292,8 +394,8 @@ def do_pprint(value, verbose=False):
     return pformat(value, verbose=verbose)
 
 
-@environmentfilter
-def do_urlize(environment, value, trim_url_limit=None, nofollow=False):
+@evalcontextfilter
+def do_urlize(eval_ctx, value, trim_url_limit=None, nofollow=False):
     """Converts URLs in plain text into clickable links.
 
     If you pass the filter an additional integer it will shorten the urls
@@ -305,8 +407,8 @@ def do_urlize(environment, value, trim_url_limit=None, nofollow=False):
         {{ mytext|urlize(40, true) }}
             links are shortened to 40 chars and defined with rel="nofollow"
     """
-    rv = urlize(soft_unicode(value), trim_url_limit, nofollow)
-    if environment.autoescape:
+    rv = urlize(value, trim_url_limit, nofollow)
+    if eval_ctx.autoescape:
         rv = Markup(rv)
     return rv
 
@@ -322,10 +424,11 @@ def do_indent(s, width=4, indentfirst=False):
         {{ mytext|indent(2, true) }}
             indent by two spaces and indent the first line too.
     """
-    indention = ' ' * width
+    indention = u' ' * width
+    rv = (u'\n' + indention).join(s.splitlines())
     if indentfirst:
-        return u'\n'.join(indention + line for line in s.splitlines())
-    return s.replace('\n', '\n' + indention)
+        rv = indention + rv
+    return rv
 
 
 def do_truncate(s, length=255, killwords=False, end='...'):
@@ -358,15 +461,16 @@ def do_truncate(s, length=255, killwords=False, end='...'):
     result.append(end)
     return u' '.join(result)
 
-
-def do_wordwrap(s, width=79, break_long_words=True):
+@environmentfilter
+def do_wordwrap(environment, s, width=79, break_long_words=True):
     """
     Return a copy of the string passed to the filter wrapped after
     ``79`` characters.  You can override this default using the first
     parameter.  If you set the second parameter to `false` Jinja will not
     split words apart if they are longer than `width`.
     """
-    return u'\n'.join(textwrap.wrap(s, width=width, expand_tabs=False,
+    import textwrap
+    return environment.newline_sequence.join(textwrap.wrap(s, width=width, expand_tabs=False,
                                    replace_whitespace=False,
                                    break_long_words=break_long_words))
 
@@ -433,7 +537,7 @@ def do_striptags(value):
 def do_slice(value, slices, fill_with=None):
     """Slice an iterator and return a list of lists containing
     those items. Useful if you want to create a div containing
-    three div tags that represent columns:
+    three ul tags that represent columns:
 
     .. sourcecode:: html+jinja
 
@@ -479,7 +583,7 @@ def do_batch(value, linecount, fill_with=None):
         {%- for row in items|batch(3, '&nbsp;') %}
           <tr>
           {%- for column in row %}
-            <tr>{{ column }}</td>
+            <td>{{ column }}</td>
           {%- endfor %}
           </tr>
         {%- endfor %}
@@ -512,29 +616,24 @@ def do_round(value, precision=0, method='common'):
     .. sourcecode:: jinja
 
         {{ 42.55|round }}
-            -> 43
+            -> 43.0
         {{ 42.55|round(1, 'floor') }}
             -> 42.5
+
+    Note that even if rounded to 0 precision, a float is returned.  If
+    you need a real integer, pipe it through `int`:
+
+    .. sourcecode:: jinja
+
+        {{ 42.55|round|int }}
+            -> 43
     """
     if not method in ('common', 'ceil', 'floor'):
         raise FilterArgumentError('method must be common, ceil or floor')
-    if precision < 0:
-        raise FilterArgumentError('precision must be a postive integer '
-                                  'or zero.')
     if method == 'common':
         return round(value, precision)
     func = getattr(math, method)
-    if precision:
-        return func(value * 10 * precision) / (10 * precision)
-    else:
-        return func(value)
-
-
-def do_sort(value, reverse=False):
-    """Sort a sequence. Per default it sorts ascending, if you pass it
-    true as first argument it will reverse the sorting.
-    """
-    return sorted(value, reverse=reverse)
+    return func(value * (10 ** precision)) / (10 ** precision)
 
 
 @environmentfilter
@@ -571,8 +670,12 @@ def do_groupby(environment, value, attribute):
     As you can see the item we're grouping by is stored in the `grouper`
     attribute and the `list` contains all the objects that have this grouper
     in common.
+
+    .. versionchanged:: 2.6
+       It's now possible to use dotted notation to group by the child
+       attribute of another attribute.
     """
-    expr = lambda x: environment.getitem(x, attribute)
+    expr = make_attrgetter(environment, attribute)
     return sorted(map(_GroupTuple, groupby(sorted(value, key=expr), expr)))
 
 
@@ -585,6 +688,27 @@ class _GroupTuple(tuple):
         return tuple.__new__(cls, (key, list(value)))
 
 
+@environmentfilter
+def do_sum(environment, iterable, attribute=None, start=0):
+    """Returns the sum of a sequence of numbers plus the value of parameter
+    'start' (which defaults to 0).  When the sequence is empty it returns
+    start.
+
+    It is also possible to sum up only certain attributes:
+
+    .. sourcecode:: jinja
+
+        Total: {{ items|sum(attribute='price') }}
+
+    .. versionchanged:: 2.6
+       The `attribute` parameter was added to allow suming up over
+       attributes.  Also the `start` parameter was moved on to the right.
+    """
+    if attribute is not None:
+        iterable = imap(make_attrgetter(environment, attribute), iterable)
+    return sum(iterable, start)
+
+
 def do_list(value):
     """Convert the value into a list.  If it was a string the returned list
     will be a list of characters.
@@ -630,13 +754,20 @@ def do_attr(environment, obj, name):
     See :ref:`Notes on subscriptions <notes-on-subscriptions>` for more details.
     """
     try:
-        value = getattr(obj, name)
-    except AttributeError:
-        return environment.undefined(obj=obj, name=name)
-    if environment.sandboxed and not \
-       environment.is_safe_attribute(obj, name, value):
-        return environment.unsafe_undefined(obj, name)
-    return value
+        name = str(name)
+    except UnicodeError:
+        pass
+    else:
+        try:
+            value = getattr(obj, name)
+        except AttributeError:
+            pass
+        else:
+            if environment.sandboxed and not \
+               environment.is_safe_attribute(obj, name, value):
+                return environment.unsafe_undefined(obj, name)
+            return value
+    return environment.undefined(obj=obj, name=name)
 
 
 FILTERS = {
@@ -654,6 +785,7 @@ FILTERS = {
     'join':                 do_join,
     'count':                len,
     'dictsort':             do_dictsort,
+    'sort':                 do_sort,
     'length':               len,
     'reverse':              do_reverse,
     'center':               do_center,
@@ -678,11 +810,11 @@ FILTERS = {
     'striptags':            do_striptags,
     'slice':                do_slice,
     'batch':                do_batch,
-    'sum':                  sum,
+    'sum':                  do_sum,
     'abs':                  abs,
     'round':                do_round,
-    'sort':                 do_sort,
     'groupby':              do_groupby,
     'safe':                 do_mark_safe,
-    'xmlattr':              do_xmlattr
+    'xmlattr':              do_xmlattr,
+    'urlescape':            do_urlescape
 }