improved filters
authorArmin Ronacher <armin.ronacher@active-4.com>
Thu, 17 Apr 2008 09:13:40 +0000 (11:13 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Thu, 17 Apr 2008 09:13:40 +0000 (11:13 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/filters.py
jinja2/nodes.py
jinja2/utils.py

index 64524eda38f229cad6fb97e90570d70998e971be..61f39179cbaa56b8b66d71d54cd741250fa3095a 100644 (file)
@@ -1005,6 +1005,8 @@ class CodeGenerator(NodeVisitor):
         func = self.environment.filters.get(node.name)
         if getattr(func, 'contextfilter', False):
             self.write('context, ')
+        elif getattr(func, 'environmentfilter', False):
+            self.write('environment, ')
         if isinstance(node.node, nodes.Filter):
             self.visit_Filter(node.node, frame, initial)
         elif node.node is None:
index 46a0e09dbf71babdd059d3d94d6207008ccfd533..c042a5bbea5d0cd81e79a650d90580c09bb3f6ac 100644 (file)
@@ -9,13 +9,15 @@
     :license: BSD, see LICENSE for more details.
 """
 import re
+import math
 from random import choice
 try:
     from operator import itemgetter
 except ImportError:
     itemgetter = lambda a: lambda b: b[a]
 from urllib import urlencode, quote
-from jinja2.utils import Markup, escape, pformat, urlize
+from itertools import imap
+from jinja2.utils import Markup, escape, pformat, urlize, soft_unicode
 from jinja2.runtime import Undefined
 
 
@@ -27,13 +29,24 @@ def contextfilter(f):
     """Decorator for marking context dependent filters. The current context
     argument will be passed as first argument.
     """
+    if getattr(f, 'environmentfilter', False):
+        raise TypeError('filter already marked as environment filter')
     f.contextfilter = True
     return f
 
 
-def do_replace(s, old, new, count=None):
+def environmentfilter(f):
+    """Decorator for marking evironment dependent filters.  The environment
+    used for the template is passed to the filter as first argument.
     """
-    Return a copy of the value with all occurrences of a substring
+    if getattr(f, 'contextfilter', False):
+        raise TypeError('filter already marked as context filter')
+    f.environmentfilter = True
+    return f
+
+
+def do_replace(s, old, new, count=None):
+    """Return a copy of the value with all occurrences of a substring
     replaced with a new one. The first argument is the substring
     that should be replaced, the second is the replacement string.
     If the optional third argument ``count`` is given, only the first
@@ -47,44 +60,36 @@ def do_replace(s, old, new, count=None):
         {{ "aaaaargh"|replace("a", "d'oh, ", 2) }}
             -> d'oh, d'oh, aaargh
     """
-    if not isinstance(old, basestring) or \
-       not isinstance(new, basestring):
-        raise FilterArgumentError('the replace filter requires '
-                                  'string replacement arguments')
     if count is None:
-        return s.replace(old, new)
-    if not isinstance(count, (int, long)):
-        raise FilterArgumentError('the count parameter of the '
-                                   'replace filter requires '
-                                   'an integer')
+        count = -1
+    if hasattr(old, '__html__') or hasattr(new, '__html__') and \
+       not hasattr(s, '__html__'):
+        s = escape(s)
+    else:
+        s = soft_unicode(s)
     return s.replace(old, new, count)
 
 
 def do_upper(s):
     """Convert a value to uppercase."""
-    return unicode(s).upper()
+    return soft_unicode(s).upper()
 
 
 def do_lower(s):
     """Convert a value to lowercase."""
-    return unicode(s).lower()
+    return soft_unicode(s).lower()
 
 
-def do_escape(s, attribute=False):
-    """
-    XML escape ``&``, ``<``, and ``>`` in a string of data. If the
-    optional parameter is `true` this filter will also convert
-    ``"`` to ``&quot;``. This filter is just used if the environment
-    was configured with disabled `auto_escape`.
+def do_escape(s):
+    """XML escape ``&``, ``<``, ``>``, and ``"`` in a string of data.
 
     This method will have no effect it the value is already escaped.
     """
-    return escape(unicode(s), attribute)
+    return escape(s)
 
 
 def do_xmlattr(d, autospace=False):
-    """
-    Create an SGML/XML attribute string based on the items in a dict.
+    """Create an SGML/XML attribute string based on the items in a dict.
     All values that are neither `none` nor `undefined` are automatically
     escaped:
 
@@ -106,8 +111,6 @@ def do_xmlattr(d, autospace=False):
     As you can see it automatically prepends a space in front of the item
     if the filter returned something. You can disable this by passing
     `false` as only argument to the filter.
-
-    *New in Jinja 1.1*
     """
     if not hasattr(d, 'iteritems'):
         raise TypeError('a dict is required')
@@ -118,31 +121,32 @@ def do_xmlattr(d, autospace=False):
                 escape(env.to_unicode(key)),
                 escape(env.to_unicode(value), True)
             ))
-    rv = u' '.join(result)
+    rv = u' '.join(
+        u'%s="%s"' % (escape(key), escape(value))
+        for key, value in d.iteritems()
+        if value is not None and not isinstance(value, Undefined)
+    )
     if autospace:
         rv = ' ' + rv
-    return rv
+    return Markup(rv)
 
 
 def do_capitalize(s):
-    """
-    Capitalize a value. The first character will be uppercase, all others
+    """Capitalize a value. The first character will be uppercase, all others
     lowercase.
     """
-    return unicode(s).capitalize()
+    return soft_unicode(s).capitalize()
 
 
 def do_title(s):
-    """
-    Return a titlecased version of the value. I.e. words will start with
+    """Return a titlecased version of the value. I.e. words will start with
     uppercase letters, all remaining characters are lowercase.
     """
-    return unicode(s).title()
+    return soft_unicode(s).title()
 
 
 def do_dictsort(value, case_sensitive=False, by='key'):
-    """
-    Sort a dict and yield (key, value) pairs. Because python dicts are
+    """ Sort a dict and yield (key, value) pairs. Because python dicts are
     unsorted you may want to use this function to order them by either
     key or value:
 
@@ -165,21 +169,19 @@ def do_dictsort(value, case_sensitive=False, by='key'):
     else:
         raise FilterArgumentError('You can only sort by either '
                                   '"key" or "value"')
-    def sort_func(value):
+    def sort_func(item):
+        value = item[pos]
         if isinstance(value, basestring):
             value = unicode(value)
             if not case_sensitive:
                 value = value.lower()
         return value
 
-    items = value.items()
-    items.sort(lambda a, b: cmp(sort_func(a[pos]), sort_func(b[pos])))
-    return items
+    return sorted(value.items(), key=sort_func)
 
 
 def do_default(value, default_value=u'', boolean=False):
-    """
-    If the value is undefined it will return the passed default value,
+    """If the value is undefined it will return the passed default value,
     otherwise the value of the variable:
 
     .. sourcecode:: jinja
@@ -201,8 +203,7 @@ def do_default(value, default_value=u'', boolean=False):
 
 
 def do_join(value, d=u''):
-    """
-    Return a string which is the concatenation of the strings in the
+    """Return a string which is the concatenation of the strings in the
     sequence. The separator between elements is an empty string per
     default, you can define ith with the optional parameter:
 
@@ -214,72 +215,63 @@ def do_join(value, d=u''):
         {{ [1, 2, 3]|join }}
             -> 123
     """
-    return unicode(d).join(unicode(x) for x in value)
+    # if the delimiter doesn't have an html representation we check
+    # if any of the items has.  If yes we do a coercion to Markup
+    if not hasttr(d, '__html__'):
+        value = list(value)
+        do_escape = False
+        for idx, item in enumerate(value):
+            if hasattr(item, '__html__'):
+                do_escape = True
+            else:
+                value[idx] = unicode(item)
+        if do_escape:
+            d = escape(d)
+        else:
+            d = unicode(d)
+        return d.join(value)
+
+    # no html involved, to normal joining
+    return soft_unicode(d).join(imap(soft_unicode, value))
 
 
 def do_center(value, width=80):
-    """
-    Centers the value in a field of a given width.
-    """
+    """Centers the value in a field of a given width."""
     return unicode(value).center(width)
 
 
-@contextfilter
-def do_first(context, seq):
-    """
-    Return the frist item of a sequence.
-    """
+@environmentfilter
+def do_first(environment, seq):
+    """Return the frist item of a sequence."""
     try:
         return iter(seq).next()
     except StopIteration:
-        return context.environment.undefined('seq|first',
+        return environment.undefined('seq|first',
             extra='the sequence was empty')
 
 
-@contextfilter
-def do_last(context, seq):
-    """
-    Return the last item of a sequence.
-    """
+@environmentfilter
+def do_last(environment, seq):
+    """Return the last item of a sequence."""
     try:
         return iter(reversed(seq)).next()
     except StopIteration:
-        return context.environment.undefined('seq|last',
+        return environment.undefined('seq|last',
             extra='the sequence was empty')
 
 
-@contextfilter
-def do_random(context, seq):
-    """
-    Return a random item from the sequence.
-    """
+@environmentfilter
+def do_random(environment, seq):
+    """Return a random item from the sequence."""
     try:
         return choice(seq)
     except IndexError:
-        return context.environment.undefined('seq|random',
+        return environment.undefined('seq|random',
             extra='the sequence was empty')
 
 
-def do_jsonencode(value):
-    """
-    JSON dump a variable. just works if simplejson is installed.
-
-    .. sourcecode:: jinja
-
-        {{ 'Hello World'|jsonencode }}
-            -> "Hello World"
-    """
-    global simplejson
-    try:
-        simplejson
-    except NameError:
-        import simplejson
-    return simplejson.dumps(value)
-
-
 def do_filesizeformat(value):
-    """
-    Format the value like a 'human-readable' file size (i.e. 13 KB,
+    """Format the value like a 'human-readable' file size (i.e. 13 KB,
     4.1 MB, 102 bytes, etc).
     """
     # fail silently
@@ -298,8 +290,7 @@ def do_filesizeformat(value):
 
 
 def do_pprint(value, verbose=False):
-    """
-    Pretty print a variable. Useful for debugging.
+    """Pretty print a variable. Useful for debugging.
 
     With Jinja 1.2 onwards you can pass it a parameter.  If this parameter
     is truthy the output will be more verbose (this requires `pretty`)
@@ -308,8 +299,7 @@ def do_pprint(value, verbose=False):
 
 
 def do_urlize(value, trim_url_limit=None, nofollow=False):
-    """
-    Converts URLs in plain text into clickable links.
+    """Converts URLs in plain text into clickable links.
 
     If you pass the filter an additional integer it will shorten the urls
     to that number. Also a third argument exists that makes the urls
@@ -320,7 +310,7 @@ def do_urlize(value, trim_url_limit=None, nofollow=False):
         {{ mytext|urlize(40, True) }}
             links are shortened to 40 chars and defined with rel="nofollow"
     """
-    return urlize(unicode(value), trim_url_limit, nofollow)
+    return urlize(soft_unicode(value), trim_url_limit, nofollow)
 
 
 def do_indent(s, width=4, indentfirst=False):
@@ -339,7 +329,7 @@ def do_indent(s, width=4, indentfirst=False):
     """
     indention = ' ' * width
     if indentfirst:
-        return u'\n'.join([indention + line for line in s.splitlines()])
+        return u'\n'.join(indention + line for line in s.splitlines())
     return s.replace('\n', '\n' + indention)
 
 
@@ -386,8 +376,10 @@ def do_wordwrap(s, pos=79, hard=False):
     if len(s) < pos:
         return s
     if hard:
-        return u'\n'.join([s[idx:idx + pos] for idx in
-                          xrange(0, len(s), pos)])
+        return u'\n'.join(s[idx:idx + pos] for idx in
+                          xrange(0, len(s), pos))
+
+    # TODO: switch to wordwrap.wrap
     # code from http://aspn.activestate.com/ASPN/Cookbook/Python/Recipe/148061
     return reduce(lambda line, word, pos=pos: u'%s%s%s' %
                   (line, u' \n'[(len(line)-line.rfind('\n') - 1 +
@@ -396,53 +388,12 @@ def do_wordwrap(s, pos=79, hard=False):
 
 
 def do_wordcount(s):
-    """
-    Count the words in that string.
-    """
-    return len([x for x in s.split() if x])
-
-
-def do_textile(s):
-    """
-    Prase the string using textile.
-
-    requires the `PyTextile`_ library.
-
-    .. _PyTextile: http://dealmeida.net/projects/textile/
-    """
-    from textile import textile
-    return textile(s.encode('utf-8')).decode('utf-8')
-
-
-def do_markdown(s):
-    """
-    Parse the string using markdown.
-
-    requires the `Python-markdown`_ library.
-
-    .. _Python-markdown: http://www.freewisdom.org/projects/python-markdown/
-    """
-    from markdown import markdown
-    return markdown(s.encode('utf-8')).decode('utf-8')
-
-
-def do_rst(s):
-    """
-    Parse the string using the reStructuredText parser from the
-    docutils package.
-
-    requires `docutils`_.
-
-    .. _docutils: http://docutils.sourceforge.net/
-    """
-    from docutils.core import publish_parts
-    parts = publish_parts(source=s, writer_name='html4css1')
-    return parts['fragment']
+    """Count the words in that string."""
+    return len(x for x in s.split() if x)
 
 
 def do_int(value, default=0):
-    """
-    Convert the value into an integer. If the
+    """Convert the value into an integer. If the
     conversion doesn't work it will return ``0``. You can
     override this default using the first parameter.
     """
@@ -456,8 +407,7 @@ def do_int(value, default=0):
 
 
 def do_float(value, default=0.0):
-    """
-    Convert the value into a floating point number. If the
+    """Convert the value into a floating point number. If the
     conversion doesn't work it will return ``0.0``. You can
     override this default using the first parameter.
     """
@@ -468,10 +418,8 @@ def do_float(value, default=0.0):
 
 
 def do_string(value):
-    """
-    Convert the value into an string.
-    """
-    return unicode(value)
+    """Convert the value into an string."""
+    return soft_unicode(value)
 
 
 def do_format(value, *args, **kwargs):
@@ -486,28 +434,22 @@ def do_format(value, *args, **kwargs):
     if kwargs:
         kwargs.update(idx, arg in enumerate(args))
         args = kwargs
-    return unicode(value) % args
+    return soft_unicode(value) % args
 
 
 def do_trim(value):
-    """
-    Strip leading and trailing whitespace.
-    """
-    return value.strip()
+    """Strip leading and trailing whitespace."""
+    return soft_unicode(value).strip()
 
 
 def do_striptags(value):
-    """
-    Strip SGML/XML tags and replace adjacent whitespace by one space.
-
-    *new in Jinja 1.1*
+    """Strip SGML/XML tags and replace adjacent whitespace by one space.
     """
     return ' '.join(_striptags_re.sub('', value).split())
 
 
 def do_slice(value, slices, fill_with=None):
-    """
-    Slice an iterator and return a list of lists containing
+    """Slice an iterator and return a list of lists containing
     those items. Useful if you want to create a div containing
     three div tags that represent columns:
 
@@ -525,8 +467,6 @@ def do_slice(value, slices, fill_with=None):
 
     If you pass it a second argument it's used to fill missing
     values on the last iteration.
-
-    *new in Jinja 1.1*
     """
     result = []
     seq = list(value)
@@ -564,8 +504,6 @@ def do_batch(value, linecount, fill_with=None):
           </tr>
         {%- endfor %}
         </table>
-
-    *new in Jinja 1.1*
     """
     result = []
     tmp = []
@@ -581,9 +519,8 @@ def do_batch(value, linecount, fill_with=None):
     return result
 
 
-def do_round(precision=0, method='common'):
-    """
-    Round the number to a given precision. The first
+def do_round(value, precision=0, method='common'):
+    """Round the number to a given precision. The first
     parameter specifies the precision (default is ``0``), the
     second the rounding method:
 
@@ -599,41 +536,31 @@ def do_round(precision=0, method='common'):
             -> 43
         {{ 42.55|round(1, 'floor') }}
             -> 42.5
-
-    *new in Jinja 1.1*
     """
     if not method in ('common', 'ceil', 'floor'):
         raise FilterArgumentError('method must be common, ceil or floor')
     if precision < 0:
         raise FilterArgumentError('precision must be a postive integer '
                                   'or zero.')
-    def wrapped(env, context, value):
-        if method == 'common':
-            return round(value, precision)
-        import math
-        func = getattr(math, method)
-        if precision:
-            return func(value * 10 * precision) / (10 * precision)
-        else:
-            return func(value)
-    return wrapped
+    if method == 'common':
+        return round(value, precision)
+    func = getattr(math, method)
+    if precision:
+        return func(value * 10 * precision) / (10 * precision)
+    else:
+        return func(value)
 
 
-def do_sort(reverse=False):
-    """
-    Sort a sequence. Per default it sorts ascending, if you pass it
+def do_sort(value, reverse=False):
+    """Sort a sequence. Per default it sorts ascending, if you pass it
     `True` as first argument it will reverse the sorting.
-
-    *new in Jinja 1.1*
     """
-    def wrapped(env, context, value):
-        return sorted(value, reverse=reverse)
-    return wrapped
+    return sorted(value, reverse=reverse)
 
 
-def do_groupby(attribute):
-    """
-    Group a sequence of objects by a common attribute.
+@environmentfilter
+def do_groupby(environment, value, attribute):
+    """Group a sequence of objects by a common attribute.
 
     If you for example have a list of dicts or objects that represent persons
     with `gender`, `first_name` and `last_name` attributes and you want to
@@ -654,17 +581,13 @@ def do_groupby(attribute):
     As you can see the item we're grouping by is stored in the `grouper`
     attribute and the `list` contains all the objects that have this grouper
     in common.
-
-    *New in Jinja 1.2*
     """
-    def wrapped(env, context, value):
-        expr = lambda x: env.get_attribute(x, attribute)
-        return sorted([{
-            'grouper':  a,
-            'list':     list(b)
-        } for a, b in groupby(sorted(value, key=expr), expr)],
-            key=itemgetter('grouper'))
-    return wrapped
+    expr = lambda x: environment.subscribe(x, attribute)
+    return sorted([{
+        'grouper':  a,
+        'list':     b
+    } for a, b in groupby(sorted(value, key=expr), expr)],
+        key=itemgetter('grouper'))
 
 
 FILTERS = {
@@ -688,16 +611,12 @@ FILTERS = {
     'first':                do_first,
     'last':                 do_last,
     'random':               do_random,
-    'jsonencode':           do_jsonencode,
     'filesizeformat':       do_filesizeformat,
     'pprint':               do_pprint,
     'indent':               do_indent,
     'truncate':             do_truncate,
     'wordwrap':             do_wordwrap,
     'wordcount':            do_wordcount,
-    'textile':              do_textile,
-    'markdown':             do_markdown,
-    'rst':                  do_rst,
     'int':                  do_int,
     'float':                do_float,
     'string':               do_string,
index 51c30396419d4313bf4eece44cc2f6a4b4d69999..c57939594ae39f3b6d3b14e9702b09aaa628007f 100644 (file)
@@ -419,6 +419,8 @@ class Filter(Expr):
         if obj is None:
             obj = self.node.as_const()
         args = [x.as_const() for x in self.args]
+        if getattr(filter, 'environmentfilter', False):
+            args.insert(0, self.environment)
         kwargs = dict(x.as_const() for x in self.kwargs)
         if self.dyn_args is not None:
             try:
index af1066c74f22f67254772133a330f657bd17c7ae..2e64fe2f762374dc5efbe75e96bc1a37da7ede62 100644 (file)
@@ -26,6 +26,15 @@ def escape(obj, attribute=False):
     )
 
 
+def soft_unicode(s):
+    """Make a string unicode if it isn't already.  That way a markup
+    string is not converted back to unicode.
+    """
+    if not isinstance(s, unicode):
+        s = unicode(s)
+    return s
+
+
 def pformat(obj, verbose=False):
     """
     Prettyprint an object.  Either use the `pretty` library or the