[svn] some changes in jinja. added recursion support for {% for %}, pos -> lineno...
authorArmin Ronacher <armin.ronacher@active-4.com>
Tue, 27 Feb 2007 18:40:14 +0000 (19:40 +0100)
committerArmin Ronacher <armin.ronacher@active-4.com>
Tue, 27 Feb 2007 18:40:14 +0000 (19:40 +0100)
--HG--
branch : trunk

jinja/datastructure.py
jinja/defaults.py
jinja/environment.py
jinja/exceptions.py
jinja/lexer.py
jinja/nodes.py
jinja/parser.py
jinja/tests.py [new file with mode: 0644]
jinja/translators/python.py

index f1db28d6029acc0da0c25dff28d17c4735ea5a19..c5fe44d01d60de4e8fd120aa4ab787685c3d4af5 100644 (file)
@@ -22,6 +22,7 @@ class UndefinedType(object):
     """
     An object that does not exist.
     """
+    __slots__ = ()
 
     def __init__(self):
         try:
@@ -68,8 +69,6 @@ class Context(object):
     Dict like object.
     """
 
-    __slots__ = ('stack')
-
     def __init__(*args, **kwargs):
         try:
             self = args[0]
@@ -80,16 +79,23 @@ class Context(object):
                             'The rest of the arguments are forwarded to '
                             'the default dict constructor.')
         self._stack = [initial, {}]
+        self.globals, self.current = self._stack
 
     def pop(self):
         if len(self._stack) <= 2:
             raise ValueError('cannot pop initial layer')
-        return self._stack.pop()
+        rv = self._stack.pop()
+        self.current = self._stack[-1]
+        return rv
 
     def push(self, data=None):
         self._stack.append(data or {})
+        self.current = self._stack[-1]
 
     def __getitem__(self, name):
+        # don't give access to jinja internal variables
+        if name.startswith('::'):
+            return Undefined
         for d in _reversed(self._stack):
             if name in d:
                 return d[name]
@@ -116,36 +122,66 @@ class LoopContext(object):
     Used by `Environment.iterate`.
     """
 
-    def __init__(self, index, length):
-        self.index = 0
-        self.length = length
-        try:
-            self.length = len(seq)
-        except TypeError:
-            self.seq = list(seq)
-            self.length = len(self.seq)
-        else:
-            self.seq = seq
+    jinja_allowed_attributes = ['index', 'index0', 'length', 'parent',
+                                'even', 'odd']
+
+    def __init__(self, seq, parent, loop_function):
+        self.loop_function = loop_function
+        self.parent = parent
+        self._stack = []
+        if seq is not None:
+            self.push(seq)
+
+    def push(self, seq):
+        self._stack.append({
+            'index':            -1,
+            'seq':              seq,
+            'length':           len(seq)
+        })
 
-    def revindex(self):
-        return self.length - self.index + 1
-    revindex = property(revindex)
+    def pop(self):
+        return self._stack.pop()
 
-    def revindex0(self):
-        return self.length - self.index
-    revindex0 = property(revindex0)
+    iterated = property(lambda s: s._stack[-1]['index'] > -1)
+    index0 = property(lambda s: s._stack[-1]['index'])
+    index = property(lambda s: s._stack[-1]['index'] + 1)
+    length = property(lambda s: s._stack[-1]['length'])
+    even = property(lambda s: s._stack[-1]['index'] % 2 == 0)
+    odd = property(lambda s: s._stack[-1]['index'] % 2 == 1)
 
-    def index0(self):
-        return self.index - 1
-    index0 = property(index0)
+    def __iter__(self):
+        s = self._stack[-1]
+        for idx, item in enumerate(s['seq']):
+            s['index'] = idx
+            yield item
+
+    def __call__(self, seq):
+        if self.loop_function is not None:
+            return self.loop_function(seq)
+        return Undefined
+
+
+class CycleContext(object):
+    """
+    Helper class used for cycling.
+    """
+
+    def __init__(self, seq=None):
+        self.lineno = -1
+        if seq is not None:
+            self.seq = seq
+            self.length = len(seq)
+            self.cycle = self.cycle_static
+        else:
+            self.cycle = self.cycle_dynamic
 
-    def even(self):
-        return self.index % 2 == 0
-    even = property(even)
+    def cycle_static(self):
+        self.lineno = (self.lineno + 1) % self.length
+        return self.seq[self.lineno]
 
-    def odd(self):
-        return self.index % 2 == 1
-    odd = property(odd)
+    def cycle_dynamic(self, seq):
+        self.lineno = (self.lineno + 1) % len(seq)
+        return seq[self.lineno]
 
 
 class TokenStream(object):
@@ -203,6 +239,6 @@ class TokenStream(object):
         except StopIteration:
             raise IndexError('end of stream reached')
 
-    def push(self, pos, token, data):
+    def push(self, lineno, token, data):
         """Push an yielded token back to the stream."""
-        self._pushed.append((pos, token, data))
+        self._pushed.append((lineno, token, data))
index 0d16b2f221cbecaf11a1743c949abede36bb6606..9cb52919fafebf1bbab2b80eb19960a883fad290 100644 (file)
@@ -9,3 +9,4 @@
     :license: BSD, see LICENSE for more details.
 """
 from jinja.filters import FILTERS as DEFAULT_FILTERS
+from jinja.tests import TESTS as DEFAULT_TESTS
index 6b0867508425b7a7ae8f038ea0203ef5090573a3..46cede0819e82c109a484cdfb609cf35fe468d91 100644 (file)
@@ -10,9 +10,9 @@
 """
 from jinja.lexer import Lexer
 from jinja.parser import Parser
-from jinja.datastructure import LoopContext, Undefined
-from jinja.exceptions import FilterNotFound
-from jinja.defaults import DEFAULT_FILTERS
+from jinja.datastructure import Undefined
+from jinja.exceptions import FilterNotFound, TestNotFound
+from jinja.defaults import DEFAULT_FILTERS, DEFAULT_TESTS
 
 
 class Environment(object):
@@ -30,7 +30,8 @@ class Environment(object):
                  template_charset='utf-8',
                  charset='utf-8',
                  loader=None,
-                 filters=None):
+                 filters=None,
+                 tests=None):
 
         # lexer / parser information
         self.block_start_string = block_start_string
@@ -45,6 +46,7 @@ class Environment(object):
         self.charset = charset
         self.loader = loader
         self.filters = filters or DEFAULT_FILTERS.copy()
+        self.tests = filters or DEFAULT_TESTS.copy()
 
         # create lexer
         self.lexer = Lexer(self)
@@ -66,21 +68,6 @@ class Environment(object):
             except UnicodeError:
                 return str(value).decode(self.charset, 'ignore')
 
-    def iterate(self, seq):
-        """
-        Helper function used by the python translator runtime code to
-        iterate over a sequence.
-        """
-        try:
-            length = len(seq)
-        except TypeError:
-            seq = list(seq)
-            length = len(seq)
-        loop_data = LoopContext(0, length)
-        for item in seq:
-            loop_data.index += 1
-            yield loop_data, item
-
     def prepare_filter(self, name, *args):
         """
         Prepare a filter.
@@ -98,13 +85,50 @@ class Environment(object):
             value = f(self, context, value)
         return value
 
+    def perform_test(self, value, context, testname):
+        """
+        Perform a test on a variable.
+        """
+        try:
+            test = self.tests[testname]
+        except KeyError:
+            raise TestNotFound(testname)
+        return bool(test(self, context, value))
+
     def get_attribute(self, obj, name):
         """
         Get the attribute name from obj.
         """
         try:
-            return getattr(obj, name)
+            rv = getattr(obj, name)
+            r = getattr(obj, 'jinja_allowed_attributes', None)
+            if r is not None:
+                if name not in r:
+                    raise AttributeError()
+            return rv
         except AttributeError:
             return obj[name]
         except:
             return Undefined
+
+    def call_function(self, f, args, kwargs, dyn_args, dyn_kwargs):
+        """
+        Function call helper
+        """
+        if dyn_args is not None:
+            args += dyn_args
+        elif dyn_kwargs is not None:
+            kwargs.update(dyn_kwargs)
+        return f(*args, **kwargs)
+
+    def finish_var(self, value):
+        """
+        As long as no write_var function is passed to the template
+        evaluator the source generated by the python translator will
+        call this function for all variables. You can use this to
+        enable output escaping etc or just ensure that None and
+        Undefined values are rendered as empty strings.
+        """
+        if value is None or value is Undefined:
+            return u''
+        return unicode(value)
index 265c2da0532ebadfa76f59c1d1d8d9f7c3cdbbd5..ec052ea230d34ad244d8b4606077f5915f249d62 100644 (file)
@@ -23,14 +23,23 @@ class FilterNotFound(KeyError, TemplateError):
         KeyError.__init__(self, message)
 
 
+class TestNotFound(KeyError, TemplateError):
+    """
+    Raised if a test does not exist.
+    """
+
+    def __init__(self, message):
+        KeyError.__init__(self, message)
+
+
 class TemplateSyntaxError(SyntaxError, TemplateError):
     """
     Raised to tell the user that there is a problem with the template.
     """
 
-    def __init__(self, message, pos):
+    def __init__(self, message, lineno):
         SyntaxError.__init__(self, message)
-        self.pos = pos
+        self.lineno = lineno
 
 
 class TemplateRuntimeError(TemplateError):
@@ -39,6 +48,6 @@ class TemplateRuntimeError(TemplateError):
     rendering.
     """
 
-    def __init__(self, message, pos):
+    def __init__(self, message, lineno):
         RuntimeError.__init__(self, message)
-        self.pos = pos
+        self.lineno = lineno
index e8f4ffc9f8fefc100467a23c45f1bb922093c027..55a4497bd53de22e35cfda49d371b41ff5ba18f7 100644 (file)
@@ -23,7 +23,7 @@ operator_re = re.compile('(%s)' % '|'.join(
     '[', ']', '(', ')', '{', '}',
     # attribute access and comparison / logical operators
     '.', ':', ',', '|', '==', '<', '>', '<=', '>=', '!=', '=',
-    ur'or\b', ur'and\b', ur'not\b'
+    ur'or\b', ur'and\b', ur'not\b', ur'in\b', ur'is'
 ]))
 
 
@@ -37,8 +37,8 @@ class Failure(object):
         self.message = message
         self.error_class = cls
 
-    def __call__(self, position):
-        raise self.error_class(self.message, position)
+    def __call__(self, lineno):
+        raise self.error_class(self.message, lineno)
 
 
 class Lexer(object):
@@ -103,7 +103,8 @@ class Lexer(object):
         returns a `TokenStream` but in some situations it can be useful
         to use this function since it can be marginally faster.
         """
-        pos = 0
+        source = type(source)('\n').join(source.splitlines())
+        pos = lineno = 0
         stack = ['root']
         statetokens = self.rules['root']
         source_length = len(source)
@@ -118,6 +119,9 @@ class Lexer(object):
                         for idx, token in enumerate(tokens):
                             # hidden group
                             if token is None:
+                                g += m.group(idx)
+                                if g:
+                                    lineno += g.count('\n')
                                 continue
                             # failure group
                             elif isinstance(token, Failure):
@@ -128,7 +132,8 @@ class Lexer(object):
                             elif token == '#bygroup':
                                 for key, value in m.groupdict().iteritems():
                                     if value is not None:
-                                        yield m.start(key), key, value
+                                        yield lineno, key, value
+                                        lineno += value.count('\n')
                                         break
                                 else:
                                     raise RuntimeError('%r wanted to resolve '
@@ -139,14 +144,16 @@ class Lexer(object):
                             else:
                                 data = m.group(idx + 1)
                                 if data:
-                                    yield m.start(idx + 1), token, data
+                                    yield lineno, token, data
+                                lineno += data.count('\n')
                     # strings as token just are yielded as it, but just
                     # if the data is not empty
                     else:
                         data = m.group()
                         if tokens is not None:
                             if data:
-                                yield pos, tokens, data
+                                yield lineno, tokens, data
+                        lineno += data.count('\n')
                     # fetch new position into new variable so that we can check
                     # if there is a internal parsing error which would result
                     # in an infinite loop
@@ -188,4 +195,4 @@ class Lexer(object):
                     return
                 # something went wrong
                 raise TemplateSyntaxError('unexpected char %r at %d' %
-                                          (source[pos], pos), pos)
+                                          (source[pos], pos), lineno)
index c10479bde2c7796f197341dd8313aa2628732d08..7d75c58cb0b596aa0edc3cc6726fc3b9fa6c7dc7 100644 (file)
@@ -8,7 +8,31 @@
     :copyright: 2006 by Armin Ronacher.
     :license: BSD, see LICENSE for more details.
 """
-from compiler.ast import Node
+from compiler import ast
+from compiler.misc import set_filename
+
+
+def inc_lineno(offset, tree):
+    """
+    Increment the linenumbers of all nodes in tree with offset.
+    """
+    todo = [tree]
+    while todo:
+        node = todo.pop()
+        node.lineno = (node.lineno or 0) + offset
+        todo.extend(node.getChildNodes())
+
+
+class Node(ast.Node):
+    """
+    jinja node.
+    """
+
+    def getChildren(self):
+        return self.get_items()
+
+    def getChildNodes(self):
+        return [x for x in self.get_items() if isinstance(x, ast.Node)]
 
 
 class Text(Node):
@@ -16,10 +40,13 @@ class Text(Node):
     Node that represents normal text.
     """
 
-    def __init__(self, pos, text):
-        self.pos = pos
+    def __init__(self, lineno, text):
+        self.lineno = lineno
         self.text = text
 
+    def get_items(self):
+        return [self.text]
+
     def __repr__(self):
         return 'Text(%r)' % (self.text,)
 
@@ -29,12 +56,29 @@ class NodeList(list, Node):
     A node that stores multiple childnodes.
     """
 
-    def __init__(self, pos, data=None):
-        self.pos = pos
+    def __init__(self, lineno, data=None):
+        self.lineno = lineno
         list.__init__(self, data or ())
 
+    getChildren = getChildNodes = lambda s: list(s)
+
     def __repr__(self):
-        return 'NodeList(%s)' % list.__repr__(self)
+        return '%s(%s)' % (
+            self.__class__.__name__,
+            list.__repr__(self)
+        )
+
+
+class Template(NodeList):
+    """
+    A template.
+    """
+
+    def __init__(self, filename, node):
+        if node.__class__ is not NodeList:
+            node = (node,)
+        NodeList.__init__(self, 0, node)
+        set_filename(filename, self)
 
 
 class ForLoop(Node):
@@ -42,19 +86,24 @@ class ForLoop(Node):
     A node that represents a for loop
     """
 
-    def __init__(self, pos, item, seq, body, else_):
-        self.pos = pos
+    def __init__(self, lineno, item, seq, body, else_, recursive):
+        self.lineno = lineno
         self.item = item
         self.seq = seq
         self.body = body
         self.else_ = else_
+        self.recursive = recursive
+
+    def get_items(self):
+        return [self.item, self.seq, self.body, self.else_, self.recursive]
 
     def __repr__(self):
-        return 'ForLoop(%r, %r, %r, %r)' % (
+        return 'ForLoop(%r, %r, %r, %r, %r)' % (
             self.item,
             self.seq,
             self.body,
-            self.else_
+            self.else_,
+            self.recursive
         )
 
 
@@ -63,16 +112,21 @@ class IfCondition(Node):
     A node that represents an if condition.
     """
 
-    def __init__(self, pos, test, body, else_):
-        self.pos = pos
-        self.test = test
-        self.body = body
+    def __init__(self, lineno, tests, else_):
+        self.lineno = lineno
+        self.tests = tests
         self.else_ = else_
 
+    def get_items(self):
+        result = []
+        for test in tests:
+            result.extend(test)
+        result.append(self._else)
+        return result
+
     def __repr__(self):
-        return 'IfCondition(%r, %r, %r)' % (
-            self.test,
-            self.body,
+        return 'IfCondition(%r, %r)' % (
+            self.tests,
             self.else_
         )
 
@@ -82,22 +136,54 @@ class Cycle(Node):
     A node that represents the cycle statement.
     """
 
-    def __init__(self, pos, seq):
-        self.pos = pos
+    def __init__(self, lineno, seq):
+        self.lineno = lineno
         self.seq = seq
 
+    def get_items(self):
+        return [self.seq]
+
     def __repr__(self):
         return 'Cycle(%r)' % (self.seq,)
 
 
 class Print(Node):
     """
-    A node that represents variable tags and print calls
+    A node that represents variable tags and print calls.
     """
 
-    def __init__(self, pos, variable):
-        self.pos = pos
+    def __init__(self, lineno, variable):
+        self.lineno = lineno
         self.variable = variable
 
+    def get_items(self):
+        return [self.variable]
+
     def __repr__(self):
         return 'Print(%r)' % (self.variable,)
+
+
+class Macro(Node):
+    """
+    A node that represents a macro.
+    """
+
+    def __init__(self, lineno, name, arguments, body):
+        self.lineno = lineno
+        self.name = name
+        self.arguments = arguments
+        self.body = body
+
+    def get_items(self):
+        result = [self.name]
+        for item in self.arguments:
+            result.extend(item)
+        result.append(self.body)
+        return result
+
+    def __repr__(self):
+        return 'Macro(%r, %r, %r)' % (
+            self.name,
+            self.arguments,
+            self.body
+        )
index e2455542f0cb5e6e6cc1132621a1fceb1dd5f474..34add4f5361ca5add023713e17ff63fdeb874d95 100644 (file)
@@ -10,6 +10,7 @@
 """
 import re
 from compiler import ast, parse
+from compiler.misc import set_filename
 from jinja import nodes
 from jinja.datastructure import TokenStream
 from jinja.exceptions import TemplateSyntaxError
@@ -20,8 +21,9 @@ end_of_block = lambda p, t, d: t == 'block_end'
 end_of_variable = lambda p, t, d: t == 'variable_end'
 switch_for = lambda p, t, d: t == 'name' and d in ('else', 'endfor')
 end_of_for = lambda p, t, d: t == 'name' and d == 'endfor'
-switch_if = lambda p, t, d: t == 'name' and d in ('else', 'endif')
+switch_if = lambda p, t, d: t == 'name' and d in ('else', 'elif', 'endif')
 end_of_if = lambda p, t, d: t == 'name' and d == 'endif'
+end_of_macro = lambda p, t, d: t == 'name' and d == 'endmacro'
 
 
 class Parser(object):
@@ -54,14 +56,31 @@ class Parser(object):
             'for':          self.handle_for_directive,
             'if':           self.handle_if_directive,
             'cycle':        self.handle_cycle_directive,
-            'print':        self.handle_print_directive
+            'print':        self.handle_print_directive,
+            'macro':        self.handle_macro_directive
         }
 
-    def handle_for_directive(self, pos, gen):
+    def handle_for_directive(self, lineno, gen):
         """
         Handle a for directive and return a ForLoop node
         """
-        ast = self.parse_python(pos, gen, 'for %s:pass\nelse:pass')
+        #XXX: maybe we could make the "recurse" part optional by using
+        #     a static analysis later.
+        recursive = []
+        def wrapgen():
+            """Wrap the generator to check if we have a recursive for loop."""
+            for token in gen:
+                if token[1:] == ('name', 'recursive'):
+                    try:
+                        item = gen.next()
+                    except StopIteration:
+                        recursive.append(True)
+                        return
+                    yield token
+                    yield item
+                else:
+                    yield token
+        ast = self.parse_python(lineno, wrapgen(), 'for %s:pass')
         body = self.subparse(switch_for)
 
         # do we have an else section?
@@ -72,39 +91,47 @@ class Parser(object):
             else_ = None
         self.close_remaining_block()
 
-        return nodes.ForLoop(pos, ast.assign, ast.list, body, else_)
+        return nodes.ForLoop(lineno, ast.assign, ast.list, body, else_, bool(recursive))
 
-    def handle_if_directive(self, pos, gen):
+    def handle_if_directive(self, lineno, gen):
         """
-        Handle if/else blocks. elif is not supported by now.
+        Handle if/else blocks.
         """
-        ast = self.parse_python(pos, gen, 'if %s:pass\nelse:pass')
-        body = self.subparse(switch_if)
+        ast = self.parse_python(lineno, gen, 'if %s:pass')
+        tests = [(ast.tests[0][0], self.subparse(switch_if))]
 
         # do we have an else section?
-        if self.tokenstream.next()[2] == 'else':
-            self.close_remaining_block()
-            else_ = self.subparse(end_of_if, True)
-        else:
-            else_ = None
+        while True:
+            lineno, token, needle = self.tokenstream.next()
+            if needle == 'else':
+                self.close_remaining_block()
+                else_ = self.subparse(end_of_if, True)
+                break
+            elif needle == 'elif':
+                gen = self.tokenstream.fetch_until(end_of_block, True)
+                ast = self.parse_python(lineno, gen, 'if %s:pass')
+                tests.append((ast.tests[0][0], self.subparse(switch_if)))
+            else:
+                else_ = None
+                break
         self.close_remaining_block()
 
-        return nodes.IfCondition(pos, ast.tests[0][0], body, else_)
+        return nodes.IfCondition(lineno, tests, else_)
 
-    def handle_cycle_directive(self, pos, gen):
+    def handle_cycle_directive(self, lineno, gen):
         """
         Handle {% cycle foo, bar, baz %}.
         """
-        ast = self.parse_python(pos, gen, '_cycle((%s))')
+        ast = self.parse_python(lineno, gen, '_cycle((%s))')
         # ast is something like Discard(CallFunc(Name('_cycle'), ...))
         # skip that.
-        return nodes.Cycle(pos, ast.expr.args[0])
+        return nodes.Cycle(lineno, ast.expr.args[0])
 
-    def handle_print_directive(self, pos, gen):
+    def handle_print_directive(self, lineno, gen):
         """
         Handle {{ foo }} and {% print foo %}.
         """
-        ast = self.parse_python(pos, gen, 'print_(%s)')
+        ast = self.parse_python(lineno, gen, 'print_(%s)')
         # ast is something like Discard(CallFunc(Name('print_'), ...))
         # so just use the args
         arguments = ast.expr.args
@@ -112,17 +139,38 @@ class Parser(object):
         if len(arguments) != 1:
             raise TemplateSyntaxError('invalid argument count for print; '
                                       'print requires exactly one argument, '
-                                      'got %d.' % len(arguments), pos)
-        return nodes.Print(pos, arguments[0])
+                                      'got %d.' % len(arguments), lineno)
+        return nodes.Print(lineno, arguments[0])
+
+    def handle_macro_directive(self, lineno, gen):
+        """
+        Handle {% macro foo(bar, baz) %}.
+        """
+        try:
+            macro_name = gen.next()
+        except StopIteration:
+            raise TemplateSyntaxError('macro requires a name', lineno)
+        if macro_name[1] != 'name':
+            raise TemplateSyntaxError('expected \'name\', got %r' %
+                                      macro_name[1], lineno)
+        ast = self.parse_python(lineno, gen, 'def %s(%%s):pass' % str(macro_name[2]))
+        body = self.subparse(end_of_macro, True)
+        self.close_remaining_block()
+
+        if ast.varargs or ast.kwargs:
+            raise TemplateSyntaxError('variable length macro signature '
+                                      'not allowed.', lineno)
+        defaults = [None] * (len(ast.argnames) - len(ast.defaults)) + ast.defaults
+        return nodes.Macro(lineno, ast.name, zip(ast.argnames, defaults), body)
 
-    def parse_python(self, pos, gen, template='%s'):
+    def parse_python(self, lineno, gen, template='%s'):
         """
         Convert the passed generator into a flat string representing
         python sourcecode and return an ast node or raise a
         TemplateSyntaxError.
         """
         tokens = []
-        for t_pos, t_token, t_data in gen:
+        for t_lineno, t_token, t_data in gen:
             if t_token == 'string':
                 tokens.append('u' + t_data)
             else:
@@ -131,20 +179,22 @@ class Parser(object):
         try:
             ast = parse(source, 'exec')
         except SyntaxError, e:
-            raise TemplateSyntaxError(str(e), pos + e.offset - 1)
+            raise TemplateSyntaxError(str(e), lineno + e.lineno - 1)
         assert len(ast.node.nodes) == 1, 'get %d nodes, 1 expected' % len(ast.node.nodes)
-        return ast.node.nodes[0]
+        result = ast.node.nodes[0]
+        nodes.inc_lineno(lineno, result)
+        return result
 
     def parse(self):
         """
-        Parse the template and return a nodelist.
+        Parse the template and return a Template.
         """
-        return self.subparse(None)
+        return nodes.Template(self.filename, self.subparse(None))
 
     def subparse(self, test, drop_needle=False):
         """
         Helper function used to parse the sourcecode until the test
-        function which is passed a tuple in the form (pos, token, data)
+        function which is passed a tuple in the form (lineno, token, data)
         returns True. In that case the current token is pushed back to
         the tokenstream and the generator ends.
 
@@ -160,57 +210,57 @@ class Parser(object):
                 return result[0]
             return result
 
-        pos = self.tokenstream.last[0]
-        result = nodes.NodeList(pos)
-        for pos, token, data in self.tokenstream:
+        lineno = self.tokenstream.last[0]
+        result = nodes.NodeList(lineno)
+        for lineno, token, data in self.tokenstream:
             # this token marks the begin or a variable section.
             # parse everything till the end of it.
             if token == 'variable_begin':
                 gen = self.tokenstream.fetch_until(end_of_variable, True)
-                result.append(self.directives['print'](pos, gen))
+                result.append(self.directives['print'](lineno, gen))
 
             # this token marks the start of a block. like for variables
             # just parse everything until the end of the block
             elif token == 'block_begin':
                 gen = self.tokenstream.fetch_until(end_of_block, True)
                 try:
-                    pos, token, data = gen.next()
+                    lineno, token, data = gen.next()
                 except StopIteration:
-                    raise TemplateSyntaxError('unexpected end of block', pos)
+                    raise TemplateSyntaxError('unexpected end of block', lineno)
 
                 # first token *must* be a name token
                 if token != 'name':
-                    raise TemplateSyntaxError('unexpected %r token' % token, pos)
+                    raise TemplateSyntaxError('unexpected %r token' % token, lineno)
 
                 # if a test function is passed to subparse we check if we
                 # reached the end of such a requested block.
-                if test is not None and test(pos, token, data):
+                if test is not None and test(lineno, token, data):
                     if not drop_needle:
-                        self.tokenstream.push(pos, token, data)
+                        self.tokenstream.push(lineno, token, data)
                     return finish()
 
                 # the first token tells us which directive we want to call.
                 # if if doesn't match any existing directive it's like a
                 # template syntax error.
                 if data in self.directives:
-                    node = self.directives[data](pos, gen)
+                    node = self.directives[data](lineno, gen)
                 else:
-                    raise TemplateSyntaxError('unknown directive %r' % data, pos)
+                    raise TemplateSyntaxError('unknown directive %r' % data, lineno)
                 result.append(node)
 
             # here the only token we should get is "data". all other
             # tokens just exist in block or variable sections. (if the
             # tokenizer is not brocken)
             elif token == 'data':
-                result.append(nodes.Text(pos, data))
+                result.append(nodes.Text(lineno, data))
 
             # so this should be unreachable code
             else:
-                raise AssertionError('unexpected token %r' % token)
+                raise AssertionError('unexpected token %r(%r)' % (token, data))
 
         # still here and a test function is provided? raise and error
         if test is not None:
-            raise TemplateSyntaxError('unexpected end of template', pos)
+            raise TemplateSyntaxError('unexpected end of template', lineno)
         return finish()
 
     def close_remaining_block(self):
@@ -220,10 +270,10 @@ class Parser(object):
         the stream. If the next token isn't the block end we throw an
         error.
         """
-        pos = self.tokenstream.last[0]
+        lineno = self.tokenstream.last[0]
         try:
-            pos, token, data = self.tokenstream.next()
+            lineno, token, data = self.tokenstream.next()
         except StopIteration:
-            raise TemplateSyntaxError('missing closing tag', pos)
+            raise TemplateSyntaxError('missing closing tag', lineno)
         if token != 'block_end':
-            raise TemplateSyntaxError('expected close tag, found %r' % token, pos)
+            raise TemplateSyntaxError('expected close tag, found %r' % token, lineno)
diff --git a/jinja/tests.py b/jinja/tests.py
new file mode 100644 (file)
index 0000000..a8692eb
--- /dev/null
@@ -0,0 +1,123 @@
+# -*- coding: utf-8 -*-
+"""
+    jinja.tests
+    ~~~~~~~~~~~
+
+    Jinja test functions. Used with the "is" operator.
+
+    :copyright: 2006 by Armin Ronacher.
+    :license: BSD, see LICENSE for more details.
+"""
+import re
+from jinja.datastructure import Undefined
+
+
+number_re = re.compile(r'^-?\d+(\.\d+)$')
+
+regex_type = type(number_re)
+
+
+def test_odd():
+    """
+    {{ var is odd }}
+
+    Return True if the variable is odd.
+    """
+    return lambda e, c, v: v % 2 == 1
+
+
+def test_even():
+    """
+    {{ var is even }}
+
+    Return True of the variable is even.
+    """
+    return lambda e, c, v: v % 2 == 0
+
+
+def test_defined():
+    """
+    {{ var is defined }}
+
+    Return True if the variable is defined.
+    """
+    return lambda e, c, v: v is not Undefined
+
+
+def test_lower():
+    """
+    {{ var is lower }}
+
+    Return True if the variable is lowercase.
+    """
+    return lambda e, c, v: isinstance(v, basestring) and v.islower()
+
+
+def test_upper():
+    """
+    {{ var is upper }}
+
+    Return True if the variable is uppercase.
+    """
+    return lambda e, c, v: isinstance(v, basestring) and v.isupper()
+
+
+def test_numeric():
+    """
+    {{ var is numeric }}
+
+    Return True if the variable is numeric.
+    """
+    return lambda e, c, v: isinstance(v, (int, long, float)) or (
+                           isinstance(v, basestring) and
+                               number_re.match(v) is not None)
+
+
+def test_sequence():
+    """
+    {{ var is sequence }}
+
+    Return True if the variable is a sequence.
+    """
+    def wrapped(environment, context, value):
+        try:
+            len(value)
+            value.__getitem__
+        except:
+            return False
+        return True
+    return wrapped
+
+
+def test_matching(regex):
+    """
+    {{ var is matching('\d+$') }}
+
+    Test if the variable matches the regular expression
+    given. If the regular expression is a string additional
+    slashes are automatically added, if it's a compiled regex
+    it's used without any modifications.
+    """
+    if isinstance(regex, unicode):
+        regex = re.compile(regex.encode('unicode-escape'), re.U)
+    elif isinstance(regex, unicode):
+        regex = re.compile(regex.encode('string-escape'))
+    elif type(regex) is not regex_type:
+        regex = None
+    def wrapped(environment, context, value):
+        if regex is None:
+            return False
+        else:
+            return regex.match(value)
+    return wrapped
+
+TESTS = {
+    'odd':              test_odd,
+    'even':             test_even,
+    'defined':          test_defined,
+    'lower':            test_lower,
+    'upper':            test_upper,
+    'numeric':          test_numeric,
+    'sequence':         test_sequence,
+    'matching':         test_matching
+}
index 127a7835fd9f810dc319f04e3dc1c43887852d24..b8f52eb496195cd9ff941a4567218baaf0c6ff2b 100644 (file)
@@ -21,8 +21,6 @@ class PythonTranslator(object):
     def __init__(self, environment, node):
         self.environment = environment
         self.node = node
-        self.indention = 0
-        self.last_pos = 0
 
         self.constants = {
             'true':                 'True',
@@ -33,12 +31,14 @@ class PythonTranslator(object):
 
         self.handlers = {
             # jinja nodes
+            nodes.Template:         self.handle_template,
             nodes.Text:             self.handle_template_text,
             nodes.NodeList:         self.handle_node_list,
             nodes.ForLoop:          self.handle_for_loop,
             nodes.IfCondition:      self.handle_if_condition,
             nodes.Cycle:            self.handle_cycle,
             nodes.Print:            self.handle_print,
+            nodes.Macro:            self.handle_macro,
             # used python nodes
             ast.Name:               self.handle_name,
             ast.AssName:            self.handle_name,
@@ -77,6 +77,8 @@ class PythonTranslator(object):
                 ast.GenExpr:        'generator expressions'
             })
 
+        self.reset()
+
     def indent(self, text):
         """
         Indent the current text.
@@ -92,19 +94,31 @@ class PythonTranslator(object):
         elif node.__class__ in self.unsupported:
             raise TemplateSyntaxError('unsupported syntax element %r found.'
                                       % self.unsupported[node.__class__],
-                                      self.last_pos)
+                                      node.lineno)
         else:
             raise AssertionError('unhandled node %r' % node.__class__)
         return out
 
     # -- jinja nodes
 
+    def handle_template(self, node):
+        """
+        Handle a template node. Basically do nothing but calling the
+        handle_node_list function.
+        """
+        return self.handle_node_list(node)
+
+    def handle_template_text(self, node):
+        """
+        Handle data around nodes.
+        """
+        return self.indent('write(%r)' % node.text)
+
     def handle_node_list(self, node):
         """
         In some situations we might have a node list. It's just
         a collection of multiple statements.
         """
-        self.last_pos = node.pos
         buf = []
         for n in node:
             buf.append(self.handle_node(n))
@@ -115,26 +129,46 @@ class PythonTranslator(object):
         Handle a for loop. Pretty basic, just that we give the else
         clause a different behavior.
         """
-        self.last_pos = node.pos
         buf = []
         write = lambda x: buf.append(self.indent(x))
         write('context.push()')
-        write('parent_loop = context[\'loop\']')
-        write('loop_data = None')
-        write('for (loop_data, %s) in environment.iterate(%s):' % (
-            self.handle_node(node.item),
-            self.handle_node(node.seq)
-        ))
+
+        # recursive loops
+        if node.recursive:
+            write('def forloop(seq):')
+            self.indention += 1
+            write('context[\'loop\'].push(seq)')
+            write('for %s in context[\'loop\']:' %
+                self.handle_node(node.item),
+            )
+
+        # simple loops
+        else:
+            write('context[\'loop\'] = LoopContext(%s, context[\'loop\'], None)' %
+                  self.handle_node(node.seq))
+            write('for %s in context[\'loop\']:' %
+                self.handle_node(node.item)
+            )
+
+        # handle real loop code
         self.indention += 1
-        write('loop_data.parent = parent_loop')
-        write('context[\'loop\'] = loop_data')
         buf.append(self.handle_node(node.body))
         self.indention -= 1
+
+        # else part of loop
         if node.else_ is not None:
-            write('if loop_data is None:')
+            write('if not context[\'loop\'].iterated:')
             self.indention += 1
             buf.append(self.handle_node(node.else_))
             self.indention -= 1
+
+        # call recursive for loop!
+        if node.recursive:
+            write('context[\'loop\'].pop()')
+            self.indention -= 1
+            write('context[\'loop\'] = LoopContext(None, context[\'loop\'], forloop)')
+            write('forloop(%s)' % self.handle_node(node.seq))
+
         write('context.pop()')
         return '\n'.join(buf)
 
@@ -142,13 +176,16 @@ class PythonTranslator(object):
         """
         Handle an if condition node.
         """
-        self.last_pos = node.pos
         buf = []
         write = lambda x: buf.append(self.indent(x))
-        write('if %s:' % self.handle_node(node.test))
-        self.indention += 1
-        buf.append(self.handle_node(node.body))
-        self.indention -= 1
+        for idx, (test, body) in enumerate(node.tests):
+            write('%sif %s:' % (
+                idx and 'el' or '',
+                self.handle_node(test)
+            ))
+            self.indention += 1
+            buf.append(self.handle_node(body))
+            self.indention -= 1
         if node.else_ is not None:
             write('else:')
             self.indention += 1
@@ -160,25 +197,59 @@ class PythonTranslator(object):
         """
         Handle the cycle tag.
         """
+        name = '::cycle_%x' % self.last_cycle_id
+        self.last_cycle_id += 1
         buf = []
         write = lambda x: buf.append(self.indent(x))
-        write('# XXX: add some code here')
-        self.last_pos = node.pos
+
+        write('if not %r in context.current:' % name)
+        self.indention += 1
+        if node.seq.__class__ in (ast.Tuple, ast.List):
+            write('context.current[%r] = CycleContext([%s])' % (
+                name,
+                ', '.join([self.handle_node(n) for n in node.seq.nodes])
+            ))
+            hardcoded = True
+        else:
+            write('context.current[%r] = CycleContext()' % name)
+            hardcoded = False
+        self.indention -= 1
+
+        if hardcoded:
+            write('write_var(context.current[%r].cycle())' % name)
+        else:
+            write('write_var(context.current[%r].cycle(%s))' % (
+                name,
+                self.handle_node(node.seq)
+            ))
+
         return '\n'.join(buf)
 
     def handle_print(self, node):
         """
         Handle a print statement.
         """
-        self.last_pos = node.pos
         return self.indent('write_var(%s)' % self.handle_node(node.variable))
 
-    def handle_template_text(self, node):
+    def handle_macro(self, node):
         """
-        Handle data around nodes.
+        Handle macro declarations.
         """
-        self.last_pos = node.pos
-        return self.indent('write(%r)' % node.text)
+        buf = []
+
+        args = []
+        for name, n in node.arguments:
+            if n is None:
+                args.append('%s=Undefined' % name)
+            else:
+                args.append('%s=%s' % (name, self.handle_node(n)))
+        buf.append(self.indent('def macro(%s):' % ', '.join(args)))
+        self.indention += 1
+        buf.append(self.handle_node(node.body))
+        self.indention -= 1
+        buf.append(self.indent('context[%r] = macro' % node.name))
+
+        return '\n'.join(buf)
 
     # -- python nodes
 
@@ -194,9 +265,28 @@ class PythonTranslator(object):
         """
         Any sort of comparison
         """
+        # the semantic for the is operator is different.
+        # for jinja the is operator performs tests and must
+        # be the only operator
+        if node.ops[0][0] == 'is':
+            if len(node.ops) > 1:
+                raise TemplateSyntaxError('is operator must not be chained',
+                                          node.lineno)
+            elif node.ops[0][1].__class__ is not ast.Name:
+                raise TemplateSyntaxError('is operator requires a test name',
+                                          ' as operand', node.lineno)
+            return 'environment.perform_test(%s, context, %r)' % (
+                self.handle_node(node.expr),
+                node.ops[0][1].name
+            )
+
+        # normal operators
         buf = []
         buf.append(self.handle_node(node.expr))
         for op, n in node.ops:
+            if op == 'is':
+                raise TemplateSyntaxError('is operator must not be chained',
+                                          node.lineno)
             buf.append(op)
             buf.append(self.handle_node(n))
         return ' '.join(buf)
@@ -213,7 +303,7 @@ class PythonTranslator(object):
         """
         if len(node.subs) != 1:
             raise TemplateSyntaxError('attribute access requires one argument',
-                                      self.last_pos)
+                                      node.lineno)
         assert node.flags != 'OP_DELETE', 'wtf? do we support that?'
         if node.subs[0].__class__ is ast.Sliceobj:
             return '%s[%s]' % (
@@ -247,22 +337,26 @@ class PythonTranslator(object):
         filters = []
         for n in node.nodes[1:]:
             if n.__class__ is ast.CallFunc:
+                if n.node.__class__ is not ast.Name:
+                    raise TemplateSyntaxError('invalid filter. filter must '
+                                              'be a hardcoded function name '
+                                              'from the filter namespace',
+                                              n.lineno)
                 args = []
                 for arg in n.args:
                     if arg.__class__ is ast.Keyword:
                         raise TemplateSyntaxError('keyword arguments for '
                                                   'filters are not supported.',
-                                                  self.last_pos)
+                                                  n.lineno)
                     args.append(self.handle_node(arg))
                 if n.star_args is not None or n.dstar_args is not None:
                     raise TemplateSynaxError('*args / **kwargs is not supported '
-                                             'for filters', self.last_pos)
-                args = ', '.join(args)
+                                             'for filters', n.lineno)
                 if args:
-                    args = ', ' + args
-                filters.append('environment.prepare_filter(%s%s)' % (
-                    self.handle_node(n.node),
-                    args
+                    args = ', ' + ', '.join(args)
+                filters.append('environment.prepare_filter(%r%s)' % (
+                    n.node.name,
+                    args or ''
                 ))
             elif n.__class__ is ast.Name:
                 filters.append('environment.prepare_filter(%s)' %
@@ -271,7 +365,7 @@ class PythonTranslator(object):
                 raise TemplateSyntaxError('invalid filter. filter must be a '
                                           'hardcoded function name from the '
                                           'filter namespace',
-                                          self.last_pos)
+                                          n.lineno)
         return 'environment.apply_filters(%s, context, [%s])' % (
             self.handle_node(node.nodes[0]),
             ', '.join(filters)
@@ -436,19 +530,18 @@ class PythonTranslator(object):
             args.append(self.handle_node(n))
         return '[%s]' % ':'.join(args)
 
-    def translate(self):
+    def reset(self):
         self.indention = 1
-        self.last_pos = 0
-        lines = [
-            'from jinja.datastructures import Undefined',
-            'def generate(environment, context, write, write_var=None):',
-            '    """This function was automatically generated by',
-            '    the jinja python translator. do not edit."""',
-            '    if write_var is None:',
-            '        write_var = write'
-        ]
-        lines.append(self.handle_node(self.node))
-        return '\n'.join(lines)
+        self.last_cycle_id = 0
+
+    def translate(self):
+        return (
+            'from jinja.datastructures import Undefined, LoopContext, CycleContext\n'
+            'def generate(context, write, write_var=None):\n'
+            '    environment = context.environment\n'
+            '    if write_var is None:\n'
+            '        write_var = lambda x: write(environment.finish_var(x))\n'
+        ) + self.handle_node(self.node)
 
 
 def translate(environment, node):