revamped jinja2 import system. the behavior is less confusing now, but it's not...
authorArmin Ronacher <armin.ronacher@active-4.com>
Fri, 25 Apr 2008 21:44:14 +0000 (23:44 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Fri, 25 Apr 2008 21:44:14 +0000 (23:44 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/filters.py
jinja2/lexer.py
jinja2/nodes.py
jinja2/optimizer.py
jinja2/parser.py
jinja2/runtime.py
tests/test_syntax.py

index 542beed1c1a713ecb08b501b6529b45299196421..958b2c3e6e2243bb826fc647ef222e376e6a1eb2 100644 (file)
@@ -210,14 +210,15 @@ class FrameIdentifierVisitor(NodeVisitor):
         self.identifiers.tests.add(node.name)
 
     def visit_Macro(self, node):
-        """Macros set local."""
         self.identifiers.declared_locally.add(node.name)
 
-    def visit_Include(self, node):
-        """Some includes set local."""
+    def visit_Import(self, node):
         self.generic_visit(node)
-        if node.target is not None:
-            self.identifiers.declared_locally.add(node.target)
+        self.identifiers.declared_locally.add(node.target)
+
+    def visit_FromImport(self, node):
+        self.generic_visit(node)
+        self.identifiers.declared_locally.update(node.names)
 
     def visit_Assign(self, node):
         """Visit assignments in the correct order."""
@@ -232,7 +233,8 @@ class FrameIdentifierVisitor(NodeVisitor):
 class CompilerExit(Exception):
     """Raised if the compiler encountered a situation where it just
     doesn't make sense to further process the code.  Any block that
-    raises such an exception is not further processed."""
+    raises such an exception is not further processed.
+    """
 
 
 class CodeGenerator(NodeVisitor):
@@ -597,28 +599,11 @@ class CodeGenerator(NodeVisitor):
 
     def visit_Include(self, node, frame):
         """Handles includes."""
-        # simpled include is include into a variable.  This kind of
-        # include works the same on every level, so we handle it first.
-        if node.target is not None:
-            self.writeline('l_%s = ' % node.target, node)
-            if frame.toplevel:
-                self.write('context[%r] = ' % node.target)
-            self.write('environment.get_template(')
-            self.visit(node.template, frame)
-            self.write(', %r).include(context)' % self.name)
-            return
-
         self.writeline('included_template = environment.get_template(', node)
         self.visit(node.template, frame)
         self.write(', %r)' % self.name)
-        if frame.toplevel:
-            self.writeline('included_context = included_template.new_context('
-                           'context.get_root())')
-            self.writeline('for event in included_template.root_render_func('
-                           'included_context):')
-        else:
-            self.writeline('for event in included_template.root_render_func('
-                           'included_template.new_context(context.get_root())):')
+        self.writeline('for event in included_template.root_render_func('
+                       'included_template.new_context(context.get_root())):')
         self.indent()
         if frame.buffer is None:
             self.writeline('yield event')
@@ -626,11 +611,33 @@ class CodeGenerator(NodeVisitor):
             self.writeline('%s.append(event)' % frame.buffer)
         self.outdent()
 
-        # if we have a toplevel include the exported variables are copied
-        # into the current context without exporting them.  context.udpate
-        # does *not* mark the variables as exported
+    def visit_Import(self, node, frame):
+        """Visit regular imports."""
+        self.writeline('l_%s = ' % node.target, node)
         if frame.toplevel:
-            self.writeline('context.update(included_context.get_exported())')
+            self.write('context[%r] = ' % node.target)
+        self.write('environment.get_template(')
+        self.visit(node.template, frame)
+        self.write(', %r).include(context)' % self.name)
+
+    def visit_FromImport(self, node, frame):
+        """Visit named imports."""
+        self.newline(node)
+        self.write('included_template = environment.get_template(')
+        self.visit(node.template, frame)
+        self.write(', %r).include(context)' % self.name)
+        for name in node.names:
+            self.writeline('l_%s = getattr(included_template, '
+                           '%r, missing)' % (name, name))
+            self.writeline('if l_%s is missing:' % name)
+            self.indent()
+            self.writeline('l_%s = environment.undefined(%r %% '
+                           'included_template.name)' %
+                           (name, 'the template %r does not export '
+                            'the requested name ' + repr(name)))
+            self.outdent()
+            if frame.toplevel:
+                self.writeline('context[%r] = l_%s' % (name, name))
 
     def visit_For(self, node, frame):
         loop_frame = frame.inner()
@@ -1022,6 +1029,9 @@ class CodeGenerator(NodeVisitor):
     def visit_Filter(self, node, frame, initial=None):
         self.write('f_%s(' % node.name)
         func = self.environment.filters.get(node.name)
+        if func is None:
+            raise TemplateAssertionError('no filter named %r' % node.name,
+                                         node.lineno, self.filename)
         if getattr(func, 'contextfilter', False):
             self.write('context, ')
         elif getattr(func, 'environmentfilter', False):
@@ -1037,9 +1047,9 @@ class CodeGenerator(NodeVisitor):
 
     def visit_Test(self, node, frame):
         self.write('t_%s(' % node.name)
-        func = self.environment.tests.get(node.name)
-        if getattr(func, 'contexttest', False):
-            self.write('context, ')
+        if node.name not in self.environment.tests:
+            raise TemplateAssertionError('no test named %r' % node.name,
+                                         node.lineno, self.filename)
         self.visit(node.node, frame)
         self.signature(node, frame)
         self.write(')')
index 68f9b5f254b4022e013403bcd7fd0ed9d5892c14..801b3502658b675597dc3093ca929a0274f0ca88 100644 (file)
@@ -633,5 +633,6 @@ FILTERS = {
     'round':                do_round,
     'sort':                 do_sort,
     'groupby':              do_groupby,
-    'safe':                 Markup
+    'safe':                 Markup,
+    'xmlattr':              do_xmlattr
 }
index 5217c7d876de30bd1cf173f1b699b002cd991dd4..772dee27b726c072faf787da2d81c104ff37693b 100644 (file)
@@ -37,7 +37,8 @@ float_re = re.compile(r'\d+\.\d+')
 keywords = set(['and', 'block', 'elif', 'else', 'endblock', 'print',
                 'endfilter', 'endfor', 'endif', 'endmacro', 'endraw',
                 'extends', 'filter', 'for', 'if', 'in', 'include',
-                'is', 'macro', 'not', 'or', 'raw', 'call', 'endcall'])
+                'is', 'macro', 'not', 'or', 'raw', 'call', 'endcall',
+                'from', 'import'])
 
 # bind operators to token types
 operators = {
index 5f3aabb0235afe8f9196ec77f0ac321a82c2165c..3aed350ee061c2cb8f1f166b46033cd97fb860d6 100644 (file)
@@ -263,9 +263,26 @@ class Block(Stmt):
 
 class Include(Stmt):
     """A node that represents the include tag."""
+    fields = ('template',)
+
+
+class Import(Stmt):
+    """A node that represents the import tag."""
     fields = ('template', 'target')
 
 
+class FromImport(Stmt):
+    """A node that represents the from import tag.  It's important to not
+    pass unsafe names to the name attribute.  The compiler translates the
+    attribute lookups directly into getattr calls and does *not* use the
+    subscribe callback of the interface.  As exported variables may not
+    start with double underscores (which the parser asserts) this is not a
+    problem for regular Jinja code, but if this node is used in an extension
+    extra care must be taken.
+    """
+    fields = ('template', 'names')
+
+
 class Trans(Stmt):
     """A node for translatable sections."""
     fields = ('singular', 'plural', 'indicator', 'replacements')
index 4d7b9f5387f0cf632dd2f9c140ef76944b94b1d5..c432b3b1f895ff9bec69d1261a00458682375e48 100644 (file)
@@ -173,6 +173,17 @@ class Optimizer(NodeTransformer):
             return node
         return result
 
+    def visit_Import(self, node, context):
+        rv = self.generic_visit(node, context)
+        context.undef(node.target)
+        return rv
+
+    def visit_FromImport(self, node, context):
+        rv = self.generic_visit(node, context)
+        for name in node.names:
+            context.undef(name)
+        return rv
+
     def fold(self, node, context):
         """Do constant folding."""
         node = self.generic_visit(node, context)
index 8db62deaf0ea1d0479138d1c2e63d810ff387bfb..daa7a0d3ec2436a2bd586594156150ab789d56a4 100644 (file)
@@ -14,7 +14,7 @@ from jinja2.exceptions import TemplateSyntaxError
 
 
 _statement_keywords = frozenset(['for', 'if', 'block', 'extends', 'print',
-                                 'macro', 'include'])
+                                 'macro', 'include', 'from', 'import'])
 _compare_operators = frozenset(['eq', 'ne', 'lt', 'lteq', 'gt', 'gteq', 'in'])
 statement_end_tokens = set(['variable_end', 'block_end', 'in'])
 _tuple_edge_tokens = set(['rparen']) | statement_end_tokens
@@ -145,22 +145,47 @@ class Parser(object):
 
     def parse_include(self):
         node = nodes.Include(lineno=self.stream.expect('include').lineno)
-        expr = self.parse_expression()
-        if self.stream.current.type is 'assign':
+        node.template = self.parse_expression()
+        return node
+
+    def parse_import(self):
+        node = nodes.Import(lineno=self.stream.expect('import').lineno)
+        node.template = self.parse_expression()
+        self.stream.expect('name:as')
+        node.target = self.stream.expect('name').value
+        if not nodes.Name(node.target, 'store').can_assign():
+            raise TemplateSyntaxError('can\'t assign imported template '
+                                      'to %r' % node.target, node.lineno,
+                                      self.filename)
+        return node
+
+    def parse_from(self):
+        node = nodes.FromImport(lineno=self.stream.expect('from').lineno)
+        node.template = self.parse_expression()
+        self.stream.expect('import')
+        node.names = []
+        while 1:
+            if node.names:
+                self.stream.expect('comma')
+            if self.stream.current.type is 'name':
+                target = nodes.Name(self.stream.current.value, 'store')
+                if not target.can_assign():
+                    raise TemplateSyntaxError('can\'t import object named %r'
+                                              % target.name, target.lineno,
+                                              self.filename)
+                elif target.name.startswith('__'):
+                    raise TemplateAssertionError('names starting with two '
+                                                 'underscores can not be '
+                                                 'imported', target.lineno,
+                                                 self.filename)
+                node.names.append(target.name)
+                self.stream.next()
+                if self.stream.current.type is not 'comma':
+                    break
+            else:
+                break
+        if self.stream.current.type is 'comma':
             self.stream.next()
-            if not isinstance(expr, nodes.Name):
-                raise TemplateSyntaxError('must assign imported template to '
-                                          'variable or current scope',
-                                          expr.lineno, self.filename)
-            if not expr.can_assign():
-                raise TemplateSyntaxError('can\'t assign imported template '
-                                          'to %r' % expr, expr.lineno,
-                                          self.filename)
-            node.target = expr.name
-            node.template = self.parse_expression()
-        else:
-            node.target = None
-            node.template = expr
         return node
 
     def parse_signature(self, node):
@@ -568,8 +593,9 @@ class Parser(object):
                     self.stream.look().type is 'assign':
                     key = self.stream.current.value
                     self.stream.skip(2)
-                    kwargs.append(nodes.Keyword(key, self.parse_expression(),
-                                                lineno=key.lineno))
+                    value = self.parse_expression()
+                    kwargs.append(nodes.Keyword(key, value,
+                                                lineno=value.lineno))
                 else:
                     ensure(not kwargs)
                     args.append(self.parse_expression())
index 6b9abbd6b3805b15e2564e27bd021b8200a68dca..7860dcc0d83a396ede390f55d751a7488e05431e 100644 (file)
@@ -14,7 +14,11 @@ from jinja2.exceptions import UndefinedError
 
 
 __all__ = ['LoopContext', 'StaticLoopContext', 'TemplateContext',
-           'Macro', 'Markup']
+           'Macro', 'Markup', 'missing']
+
+
+# special singleton representing missing values for the runtime
+missing = object()
 
 
 class TemplateContext(object):
@@ -69,7 +73,8 @@ class TemplateContext(object):
 
     def get_exported(self):
         """Get a new dict with the exported variables."""
-        return dict((k, self.vars[k]) for k in self.exported_vars)
+        return dict((k, self.vars[k]) for k in self.exported_vars
+                    if not k.startswith('__'))
 
     def get_root(self):
         """Return a new dict with all the non local variables."""
index dc91990e1c6c26baf78567eeb751f08f0dcabb16..af2b0f59282d9ab3f7193a936c0b05f94716789d 100644 (file)
@@ -14,7 +14,6 @@ CALL = '''{{ foo('a', c='d', e='f', *['b'], **{'g': 'h'}) }}'''
 SLICING = '''{{ [1, 2, 3][:] }}|{{ [1, 2, 3][::-1] }}'''
 ATTR = '''{{ foo.bar }}|{{ foo['bar'] }}'''
 SUBSCRIPT = '''{{ foo[0] }}|{{ foo[-1] }}'''
-KEYATTR = '''{{ {'items': 'foo'}.items }}|{{ {}.items() }}'''
 TUPLE = '''{{ () }}|{{ (1,) }}|{{ (1, 2) }}'''
 MATH = '''{{ (1 + 1 * 2) - 3 / 2 }}|{{ 2**3 }}'''
 DIV = '''{{ 3 // 2 }}|{{ 3 / 2 }}|{{ 3 % 2 }}'''
@@ -64,11 +63,6 @@ def test_subscript(env):
     assert tmpl.render(foo=[0, 1, 2]) == '0|2'
 
 
-def test_keyattr(env):
-    tmpl = env.from_string(KEYATTR)
-    assert tmpl.render() == 'foo|[]'
-
-
 def test_tuple(env):
     tmpl = env.from_string(TUPLE)
     assert tmpl.render() == '()|(1,)|(1, 2)'