work on the macro stuff
authorArmin Ronacher <armin.ronacher@active-4.com>
Tue, 8 Apr 2008 16:09:13 +0000 (18:09 +0200)
committerArmin Ronacher <armin.ronacher@active-4.com>
Tue, 8 Apr 2008 16:09:13 +0000 (18:09 +0200)
--HG--
branch : trunk

jinja2/compiler.py
jinja2/runtime.py
test.py

index 41577c8c6bbcbe14bf82488ccb91212513c4360b..5c64699d3b8add67806da9f5b2449aadfd1699cb 100644 (file)
@@ -61,10 +61,13 @@ class Identifiers(object):
         self.undeclared.discard(name)
         self.declared.add(name)
 
-    def is_declared(self, name):
+    def is_declared(self, name, local_only=False):
         """Check if a name is declared in this or an outer scope."""
-        return name in self.declared or name in self.declared_locally or \
-               name in self.declared_parameter
+        if name in self.declared_locally or name in self.declared_parameter:
+            return True
+        if local_only:
+            return False
+        return name in self.declared
 
     def find_shadowed(self):
         """Find all the shadowed names."""
@@ -81,7 +84,6 @@ class Frame(object):
         if parent is not None:
             self.identifiers.declared.update(
                 parent.identifiers.declared |
-                parent.identifiers.undeclared |
                 parent.identifiers.declared_locally |
                 parent.identifiers.declared_parameter
             )
@@ -92,9 +94,12 @@ class Frame(object):
         rv.identifiers = copy(self)
         return rv
 
-    def inspect(self, nodes):
-        """Walk the node and check for identifiers."""
-        visitor = FrameIdentifierVisitor(self.identifiers)
+    def inspect(self, nodes, hard_scope=False):
+        """Walk the node and check for identifiers.  If the scope
+        is hard (eg: enforce on a python level) overrides from outer
+        scopes are tracked differently.
+        """
+        visitor = FrameIdentifierVisitor(self.identifiers, hard_scope)
         for node in nodes:
             visitor.visit(node)
 
@@ -106,15 +111,16 @@ class Frame(object):
 class FrameIdentifierVisitor(NodeVisitor):
     """A visitor for `Frame.inspect`."""
 
-    def __init__(self, identifiers):
+    def __init__(self, identifiers, hard_scope):
         self.identifiers = identifiers
+        self.hard_scope = hard_scope
 
     def visit_Name(self, node):
         """All assignments to names go through this function."""
         if node.ctx in ('store', 'param'):
             self.identifiers.declared_locally.add(node.name)
         elif node.ctx == 'load':
-            if not self.identifiers.is_declared(node.name):
+            if not self.identifiers.is_declared(node.name, self.hard_scope):
                 self.identifiers.undeclared.add(node.name)
 
     def visit_Macro(self, node):
@@ -326,7 +332,32 @@ class CodeGenerator(NodeVisitor):
 
     def visit_Macro(self, node, frame):
         macro_frame = frame.inner()
-        macro_frame.inspect(node.body)
+        macro_frame.inspect(node.iter_child_nodes(), hard_scope=True)
+
+        # variables that are undeclared (accessed before declaration) and
+        # declared locally *and* part of an outside scope raise a template
+        # assertion error. Reason: we can't generate reasonable code from
+        # it without aliasing all the variables.  XXX: alias them ^^
+        overriden_closure_vars = (
+            macro_frame.identifiers.undeclared &
+            macro_frame.identifiers.declared &
+            (macro_frame.identifiers.declared_locally |
+             macro_frame.identifiers.declared_parameter)
+        )
+        if overriden_closure_vars:
+            vars = ', '.join(sorted(overriden_closure_vars))
+            raise TemplateAssertionError('It\'s not possible to set and '
+                                         'access variables derived from '
+                                         'an outer scope! (affects: %s' %
+                                         vars, node.lineno, self.filename)
+
+        # remove variables from a closure from the frame's undeclared
+        # identifiers.
+        macro_frame.identifiers.undeclared -= (
+            macro_frame.identifiers.undeclared &
+            macro_frame.identifiers.declared
+        )
+
         args = ['l_' + x.name for x in node.args]
         if 'arguments' in macro_frame.identifiers.undeclared:
             accesses_arguments = True
@@ -337,17 +368,20 @@ class CodeGenerator(NodeVisitor):
         self.indent()
         self.writeline('if 0: yield None')
         self.outdent()
-        self.blockvisit(node.body, frame)
+        self.pull_locals(macro_frame)
+        self.blockvisit(node.body, macro_frame)
         self.newline()
         if frame.toplevel:
             self.write('context[%r] = ' % node.name)
         arg_tuple = ', '.join(repr(x.name) for x in node.args)
         if len(node.args) == 1:
             arg_tuple += ','
-        self.write('l_%s = Macro(macro, %r, (%s), %s)' % (
-            node.name, node.name,
-            arg_tuple, accesses_arguments
-        ))
+        self.write('l_%s = Macro(macro, %r, (%s), (' % (node.name, node.name,
+                                                       arg_tuple))
+        for arg in node.defaults:
+            self.visit(arg)
+            self.write(', ')
+        self.write('), %r)' % accesses_arguments)
 
     def visit_ExprStmt(self, node, frame):
         self.newline(node)
index 8a1bdbacc14d491ae5115fa879e89404f1429962..7af5c4abbe31b989b3d66fcbe94e11bf8dda4033 100644 (file)
@@ -56,3 +56,21 @@ class TemplateContext(dict):
                 return self.globals[key]
             except:
                 return self.undefined_factory(key)
+
+
+class Macro(object):
+
+    def __init__(self, func, name, arguments, defaults, catch_all):
+        self.func = func
+        self.name = name
+        self.arguments = arguments
+        self.defaults = defaults
+        self.catch_all = catch_all
+
+    def __call__(self, *args, **kwargs):
+        if len(args) > len(self.arguments):
+            raise TypeError('macro %r takes not more than %d argument(s).' %
+                            (self.name, len(self.arguments)))
+        arguments = {}
+        # XXX: assemble arguments
+        return u''.join(self.func(*args, **kwargs))
diff --git a/test.py b/test.py
index 7b5d78ea73e1fe354cc7f288fab2c6cc1481c2a7..3aa44e120d28f310f347318f6efdecdd09b88e6b 100644 (file)
--- a/test.py
+++ b/test.py
@@ -4,12 +4,12 @@ from jinja2.compiler import generate
 
 env = Environment()
 ast = env.parse("""
-{% (a, b), c = foo() %}
-{% macro foo(a, b, c=42) %}
-  42 {{ arguments }}
-{% endmacro %}
 {% block body %}
-    {% bar = 23 %}
+    {% b = 23 %}
+    {% macro foo(a) %}[{{ a }}|{{ b }}|{{ c }}]{% endmacro %}
+    {% for item in seq %}
+      {{ foo(item) }}
+    {% endfor %}
 {% endblock %}
 """)
 print ast