moved `IncludedTemplate` into the regular template API, fixed more unittests
authorArmin Ronacher <armin.ronacher@active-4.com>
Fri, 25 Apr 2008 09:44:59 +0000 (11:44 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Fri, 25 Apr 2008 09:44:59 +0000 (11:44 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/environment.py
jinja2/filters.py
jinja2/loaders.py
jinja2/runtime.py
jinja2/visitor.py
tests/loaderres/templates/brokenimport.html [deleted file]
tests/test_loaders.py
tests/test_macros.py

index 9bf1e4d61a69984a64b5b7cbfd83f7cf9e5fec7b..8282bc6c58a0e6c4dc89e4cc9413afb7a8f37461 100644 (file)
@@ -441,14 +441,19 @@ class CodeGenerator(NodeVisitor):
             func_frame.identifiers.declared
         )
 
-        func_frame.accesses_arguments = False
+        func_frame.accesses_kwargs = False
+        func_frame.accesses_varargs = False
         func_frame.accesses_caller = False
         func_frame.arguments = args = ['l_' + x.name for x in node.args]
 
-        if 'arguments' in func_frame.identifiers.undeclared:
-            func_frame.accesses_arguments = True
-            func_frame.identifiers.add_special('arguments')
-            args.append('l_arguments')
+        if 'kwargs' in func_frame.identifiers.undeclared:
+            func_frame.accesses_kwargs = True
+            func_frame.identifiers.add_special('kwargs')
+            args.append('l_kwargs')
+        if 'varargs' in func_frame.identifiers.undeclared:
+            func_frame.accesses_varargs = True
+            func_frame.identifiers.add_special('varargs')
+            args.append('l_varargs')
         if 'caller' in func_frame.identifiers.undeclared:
             func_frame.accesses_caller = True
             func_frame.identifiers.add_special('caller')
@@ -598,14 +603,14 @@ class CodeGenerator(NodeVisitor):
             self.writeline('l_%s = ' % node.target, node)
             if frame.toplevel:
                 self.write('context[%r] = ' % node.target)
-            self.write('IncludedTemplate(environment, context, ')
+            self.write('environment.get_template(')
             self.visit(node.template, frame)
-            self.write(')')
+            self.write(', %r).include(context)' % self.name)
             return
 
         self.writeline('included_template = environment.get_template(', node)
         self.visit(node.template, frame)
-        self.write(')')
+        self.write(', %r)' % self.name)
         if frame.toplevel:
             self.writeline('included_context = included_template.new_context('
                            'context.get_root())')
@@ -729,8 +734,9 @@ class CodeGenerator(NodeVisitor):
         for arg in node.defaults:
             self.visit(arg, macro_frame)
             self.write(', ')
-        self.write('), %s, %s)' % (
-            macro_frame.accesses_arguments and '1' or '0',
+        self.write('), %s, %s, %s)' % (
+            macro_frame.accesses_kwargs and '1' or '0',
+            macro_frame.accesses_varargs and '1' or '0',
             macro_frame.accesses_caller and '1' or '0'
         ))
 
@@ -753,7 +759,10 @@ class CodeGenerator(NodeVisitor):
         for arg in node.defaults:
             self.visit(arg)
             self.write(', ')
-        self.write('), %s, 0)' % (call_frame.accesses_arguments and '1' or '0'))
+        self.write('), %s, %s, 0)' % (
+            call_frame.accesses_kwargs and '1' or '0',
+            call_frame.accesses_varargs and '1' or '0'
+        ))
         if frame.buffer is None:
             self.writeline('yield ', node)
         else:
index fa19b1b7e54805abee6268685d168f5faf0eb29a..a982e8e8234a26a284b2c5e98b3564e2e978e455 100644 (file)
@@ -347,6 +347,16 @@ class Template(object):
         return TemplateContext(self.environment, dict(self.globals, **vars),
                                self.name, self.blocks)
 
+    def include(self, context=None):
+        """Include this template."""
+        if context is None:
+            context = self.new_context({})
+        elif isinstance(context, TemplateContext):
+            context = self.new_context(context.get_root())
+        else:
+            context = self.new_context(context)
+        return IncludedTemplate(self, context)
+
     def get_corresponding_lineno(self, lineno):
         """Return the source line number of a line number in the
         generated bytecode as they are not in sync.
@@ -376,6 +386,25 @@ class Template(object):
         )
 
 
+class IncludedTemplate(object):
+    """Represents an included template."""
+
+    def __init__(self, template, context):
+        self._template = template
+        self._name = template.name
+        self._rendered_body = u''.join(template.root_render_func(context))
+        self._context = context.get_exported()
+
+    __getitem__ = lambda x, n: x._context[n]
+    __html__ = __unicode__ = lambda x: x._rendered_body
+
+    def __repr__(self):
+        return '<%s %r>' % (
+            self.__class__.__name__,
+            self._name
+        )
+
+
 class TemplateStream(object):
     """This class wraps a generator returned from `Template.generate` so that
     it's possible to buffer multiple elements so that it's possible to return
index d48ac940dd15ce9a7f240ddf6b394a4b9a858c8e..68f9b5f254b4022e013403bcd7fd0ed9d5892c14 100644 (file)
@@ -378,10 +378,12 @@ def do_wordwrap(s, pos=79, hard=False):
                                 len(word.split('\n', 1)[0]) >= pos)],
                    word), s.split(' '))
 
+
 def do_wordcount(s):
     """Count the words in that string."""
     return len(s.split())
 
+
 def do_int(value, default=0):
     """Convert the value into an integer. If the
     conversion doesn't work it will return ``0``. You can
@@ -563,16 +565,30 @@ def do_groupby(environment, value, attribute):
         {% endfor %}
         </ul>
 
+    Additionally it's possible to use tuple unpacking for the grouper and
+    list:
+
+    .. sourcecode:: html+jinja
+
+        <ul>
+        {% for grouper, list in persons|groupby('gender') %}
+            ...
+        {% endfor %}
+        </ul>
+
     As you can see the item we're grouping by is stored in the `grouper`
     attribute and the `list` contains all the objects that have this grouper
     in common.
     """
     expr = lambda x: environment.subscribe(x, attribute)
-    return sorted([{
-        'grouper':  a,
-        'list':     b
-    } for a, b in groupby(sorted(value, key=expr), expr)],
-        key=itemgetter('grouper'))
+    return sorted(map(_GroupTuple, groupby(sorted(value, key=expr), expr)),
+                  key=itemgetter('grouper'))
+
+
+class _GroupTuple(tuple):
+    __slots__ = ()
+    grouper = property(itemgetter(0))
+    list = property(itemgetter(1))
 
 
 FILTERS = {
index 395816920b90ae1d70284d636961d4d023697126..2bcb30d0872797e3f1d6d1123a6a471ee3a3112f 100644 (file)
@@ -126,7 +126,8 @@ class PackageLoader(BaseLoader):
         self.package_path = package_path
 
     def get_source(self, environment, template):
-        path = '/'.join(split_template_path(template))
+        pieces = split_template_path(template)
+        path = '/'.join((self.package_path,) + tuple(pieces))
         if not self._pkg.resource_exists(self.package_name, path):
             raise TemplateNotFound(template)
         return self._pkg.resource_string(self.package_name, path), None, None
@@ -147,9 +148,9 @@ class DictLoader(BaseLoader):
 
 class FunctionLoader(BaseLoader):
     """A loader that is passed a function which does the loading.  The
-    function has to work like a `get_source` method but the return value for
-    not existing templates may be `None` instead of a `TemplateNotFound`
-    exception.
+    function becomes the name of the template passed and has to return either
+    an unicode string with the template source, a tuple in the form ``(source,
+    filename, uptodatefunc)`` or `None` if the template does not exist.
     """
 
     def __init__(self, load_func, cache_size=50, auto_reload=True):
@@ -157,9 +158,11 @@ class FunctionLoader(BaseLoader):
         self.load_func = load_func
 
     def get_source(self, environment, template):
-        rv = self.load_func(environment, template)
+        rv = self.load_func(template)
         if rv is None:
             raise TemplateNotFound(template)
+        elif isinstance(rv, basestring):
+            return rv, None, None
         return rv
 
 
index 0c0458c5f6d9d7f365d3724a90db87ac1a0f56fc..6b9abbd6b3805b15e2564e27bd021b8200a68dca 100644 (file)
@@ -14,7 +14,7 @@ from jinja2.exceptions import UndefinedError
 
 
 __all__ = ['LoopContext', 'StaticLoopContext', 'TemplateContext',
-           'Macro', 'IncludedTemplate', 'Markup']
+           'Macro', 'Markup']
 
 
 class TemplateContext(object):
@@ -97,7 +97,7 @@ class TemplateContext(object):
     def __repr__(self):
         return '<%s %s of %r>' % (
             self.__class__.__name__,
-            dict.__repr__(self),
+            repr(self.get_all()),
             self.name
         )
 
@@ -120,26 +120,6 @@ class SuperBlock(object):
         )
 
 
-class IncludedTemplate(object):
-    """Represents an included template."""
-
-    def __init__(self, environment, context, template):
-        template = environment.get_template(template)
-        context = template.new_context(context.get_root())
-        self._name = template.name
-        self._rendered_body = u''.join(template.root_render_func(context))
-        self._context = context.get_exported()
-
-    __getitem__ = lambda x, n: x._context[n]
-    __html__ = __unicode__ = lambda x: x._rendered_body
-
-    def __repr__(self):
-        return '<%s %r>' % (
-            self.__class__.__name__,
-            self._name
-        )
-
-
 class LoopContextBase(object):
     """Helper for extended iteration."""
 
@@ -228,19 +208,21 @@ class StaticLoopContext(LoopContextBase):
 class Macro(object):
     """Wraps a macro."""
 
-    def __init__(self, environment, func, name, arguments, defaults, catch_all, caller):
+    def __init__(self, environment, func, name, arguments, defaults,
+                 catch_kwargs, catch_varargs, caller):
         self._environment = environment
         self._func = func
         self.name = name
         self.arguments = arguments
         self.defaults = defaults
-        self.catch_all = catch_all
+        self.catch_kwargs = catch_kwargs
+        self.catch_varargs = catch_varargs
         self.caller = caller
 
     def __call__(self, *args, **kwargs):
         arg_count = len(self.arguments)
-        if len(args) > arg_count:
-            raise TypeError('macro %r takes not more than %d argument(s).' %
+        if not self.catch_varargs and len(args) > arg_count:
+            raise TypeError('macro %r takes not more than %d argument(s)' %
                             (self.name, len(self.arguments)))
         arguments = {}
         for idx, name in enumerate(self.arguments):
@@ -261,8 +243,13 @@ class Macro(object):
             if caller is None:
                 caller = self._environment.undefined('No caller defined')
             arguments['l_caller'] = caller
-        if self.catch_all:
-            arguments['l_arguments'] = kwargs
+        if self.catch_kwargs:
+            arguments['l_kwargs'] = kwargs
+        elif kwargs:
+            raise TypeError('macro %r takes no keyword argument %r' %
+                            (self.name, iter(kwargs).next()))
+        if self.catch_varargs:
+            arguments['l_varargs'] = args[arg_count:]
         return self._func(**arguments)
 
     def __repr__(self):
index 895aa758446a81693c3b87a6fcc550d487b47588..8c94803f7a5bbb457baf845fa3c0e63f526f3485 100644 (file)
@@ -24,8 +24,7 @@ class NodeVisitor(object):
     """
 
     def get_visitor(self, node):
-        """
-        Return the visitor function for this node or `None` if no visitor
+        """Return the visitor function for this node or `None` if no visitor
         exists for this node.  In that case the generic visit function is
         used instead.
         """
diff --git a/tests/loaderres/templates/brokenimport.html b/tests/loaderres/templates/brokenimport.html
deleted file mode 100644 (file)
index e3c106e..0000000
+++ /dev/null
@@ -1 +0,0 @@
-{% extends "missing.html" %}
index 05698393b03b8146f237c9acc3a6a92a086a4e9f..fb5ca733f92f61fbff4cac274210c9e9fe55ffa0 100644 (file)
@@ -7,6 +7,7 @@
     :license: BSD, see LICENSE for more details.
 """
 
+from py.test import raises
 import time
 import tempfile
 from jinja2 import Environment, loaders
@@ -21,17 +22,11 @@ package_loader = loaders.PackageLoader('loaderres', 'templates')
 
 filesystem_loader = loaders.FileSystemLoader('loaderres/templates')
 
-memcached_loader = loaders.MemcachedFileSystemLoader('loaderres/templates')
-
 function_loader = loaders.FunctionLoader({'justfunction.html': 'FOO'}.get)
 
 choice_loader = loaders.ChoiceLoader([dict_loader, package_loader])
 
 
-class FakeLoader(loaders.BaseLoader):
-    local_attr = 42
-
-
 def test_dict_loader():
     env = Environment(loader=dict_loader)
     tmpl = env.get_template('justdict.html')
@@ -46,47 +41,18 @@ def test_dict_loader():
 
 def test_package_loader():
     env = Environment(loader=package_loader)
-    for x in xrange(2):
-        tmpl = env.get_template('test.html')
-        assert tmpl.render().strip() == 'BAR'
-        try:
-            env.get_template('missing.html')
-        except TemplateNotFound:
-            pass
-        else:
-            raise AssertionError('expected template exception')
-
-        # second run in native mode (no pkg_resources)
-        package_loader.force_native = True
-        del package_loader._load_func
-
-
-def test_filesystem_loader():
-    env = Environment(loader=filesystem_loader)
     tmpl = env.get_template('test.html')
     assert tmpl.render().strip() == 'BAR'
-    tmpl = env.get_template('foo/test.html')
-    assert tmpl.render().strip() == 'FOO'
-    try:
-        env.get_template('missing.html')
-    except TemplateNotFound:
-        pass
-    else:
-        raise AssertionError('expected template exception')
+    raises(TemplateNotFound, lambda: env.get_template('missing.html'))
 
 
-def test_memcached_loader():
-    env = Environment(loader=memcached_loader)
+def test_filesystem_loader():
+    env = Environment(loader=filesystem_loader)
     tmpl = env.get_template('test.html')
     assert tmpl.render().strip() == 'BAR'
     tmpl = env.get_template('foo/test.html')
     assert tmpl.render().strip() == 'FOO'
-    try:
-        env.get_template('missing.html')
-    except TemplateNotFound:
-        pass
-    else:
-        raise AssertionError('expected template exception')
+    raises(TemplateNotFound, lambda: env.get_template('missing.html'))
 
 
 def test_choice_loader():
@@ -102,15 +68,6 @@ def test_choice_loader():
     else:
         raise AssertionError('expected template exception')
 
-    # this should raise an TemplateNotFound error with the
-    # correct name
-    try:
-        env.get_template('brokenimport.html')
-    except TemplateNotFound, e:
-        assert e.name == 'missing.html'
-    else:
-        raise AssertionError('expected exception')
-
 
 def test_function_loader():
     env = Environment(loader=function_loader)
@@ -122,57 +79,3 @@ def test_function_loader():
         pass
     else:
         raise AssertionError('expected template exception')
-
-
-def test_loader_redirect():
-    env = Environment(loader=FakeLoader())
-    assert env.loader.local_attr == 42
-    assert env.loader.get_source
-    assert env.loader.load
-
-
-class MemcacheTestingLoader(loaders.CachedLoaderMixin, loaders.BaseLoader):
-
-    def __init__(self, enable):
-        loaders.CachedLoaderMixin.__init__(self, enable, 40, None, True, 'foo')
-        self.times = {}
-        self.idx = 0
-
-    def touch(self, name):
-        self.times[name] = time.time()
-
-    def get_source(self, environment, name, parent):
-        self.touch(name)
-        self.idx += 1
-        return 'Template %s (%d)' % (name, self.idx)
-
-    def check_source_changed(self, environment, name):
-        if name in self.times:
-            return self.times[name]
-        return -1
-
-
-memcache_env = Environment(loader=MemcacheTestingLoader(True))
-no_memcache_env = Environment(loader=MemcacheTestingLoader(False))
-
-
-test_memcaching = r'''
->>> not_caching = MODULE.no_memcache_env.loader
->>> caching = MODULE.memcache_env.loader
->>> touch = caching.touch
-
->>> tmpl1 = not_caching.load('test.html')
->>> tmpl2 = not_caching.load('test.html')
->>> tmpl1 == tmpl2
-False
-
->>> tmpl1 = caching.load('test.html')
->>> tmpl2 = caching.load('test.html')
->>> tmpl1 == tmpl2
-True
-
->>> touch('test.html')
->>> tmpl2 = caching.load('test.html')
->>> tmpl1 == tmpl2
-False
-'''
index f2277888e46605af111f928e88f8baf6b429a73d..aa3546c4533d873f7b87f3c98ba01dd61bac59e4 100644 (file)
@@ -16,7 +16,7 @@ SCOPING = '''\
 {% macro level1(data1) %}
 {% macro level2(data2) %}{{ data1 }}|{{ data2 }}{% endmacro %}
 {{ level2('bar') }}{% endmacro %}
-{{ level1('foo') }}|{{ level2('bar') }}\
+{{ level1('foo') }}\
 '''
 
 ARGUMENTS = '''\
@@ -55,7 +55,7 @@ def test_simple(env):
 
 def test_scoping(env):
     tmpl = env.from_string(SCOPING)
-    assert tmpl.render() == 'foo|bar|'
+    assert tmpl.render() == 'foo|bar'
 
 
 def test_arguments(env):