Turn YieldNodeCollector into TreeVisitor
authorVitja Makarov <vitja.makarov@gmail.com>
Sun, 12 Dec 2010 13:55:45 +0000 (16:55 +0300)
committerVitja Makarov <vitja.makarov@gmail.com>
Sun, 12 Dec 2010 13:55:45 +0000 (16:55 +0300)
Cython/Compiler/ParseTreeTransforms.py

index f43e8aff718d04e79d33799af7eb0d442e42fe15..d21858852a57249c5c81649203fd0fff5e77b182 100644 (file)
@@ -1372,40 +1372,53 @@ class ClosureTempAllocator(object):
             if entry.type.is_pyobject:
                 code.put_xgiveref('%s->%s' % (Naming.cur_scope_cname, entry.cname))
 
-class YieldCollector(object):
-    def __init__(self, node):
-        self.node = node
+class YieldNodeCollector(TreeVisitor):
+
+    def __init__(self):
+        super(YieldNodeCollector, self).__init__()
         self.yields = []
-        self.returns = []
+        self.has_return = False
+
+    visit_Node = TreeVisitor.visitchildren
+
+    def visit_YieldExprNode(self, node):
+        if self.has_return:
+            error(node.pos, "'yield' outside function")
+        else:
+            self.yields.append(node)
+            node.label_num = len(self.yields)
+
+    def visit_ReturnStatNode(self, node):
+        if self.yields:
+            error(collector.returns[0].pos, "'return' with argument inside generator")
+        else:
+            self.has_return = True
+
+    def visit_ClassDefNode(self, node):
+        pass
+
+    def visit_DefNode(self, node):
+        pass
 
 class MarkGeneratorVisitor(CythonTransform):
     """XXX: merge me with MarkClosureVisitor"""
     def __init__(self, context):
         super(MarkGeneratorVisitor, self).__init__(context)
-        self.allow_yield = False
-        self.path = []
 
     def visit_ModuleNode(self, node):
         self.visitchildren(node)
         return node
 
     def visit_ClassDefNode(self, node):
-        saved = self.allow_yield
-        self.allow_yield = False
         self.visitchildren(node)
-        self.allow_yield = saved
         return node
 
     def visit_FuncDefNode(self, node):
-        saved = self.allow_yield
-        self.allow_yield = True
-        self.path.append(YieldCollector(node))
+        collector = YieldNodeCollector()
+        collector.visitchildren(node)
         self.visitchildren(node)
-        self.allow_yield = saved
-        collector = self.path.pop()
-        if collector.yields and collector.returns:
-            error(collector.returns[0].pos, "'return' with argument inside generator")
-        elif collector.yields:
+
+        if collector.yields:
             allocator = ClosureTempAllocator()
             # XXX: move allocator inside local scope
             for y in collector.yields:
@@ -1416,19 +1429,6 @@ class MarkGeneratorVisitor(CythonTransform):
             node.yields = collector.yields
         return node
 
-    def visit_YieldExprNode(self, node):
-        if not self.allow_yield:
-            error(node.pos, "'yield' outside function")
-            return node
-        collector = self.path[-1]
-        collector.yields.append(node)
-        node.label_num = len(collector.yields)
-        return node
-
-    def visit_ReturnStatNode(self, node):
-        if self.path:
-            self.path[-1].returns.append(node)
-        return node
 
 class CreateClosureClasses(CythonTransform):
     # Output closure classes in module scope for all functions