Make Refnanny happy, fix some errors. More testcases.
authorVitja Makarov <vitja.makarov@gmail.com>
Thu, 9 Dec 2010 20:39:40 +0000 (23:39 +0300)
committerVitja Makarov <vitja.makarov@gmail.com>
Thu, 9 Dec 2010 20:39:40 +0000 (23:39 +0300)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
tests/run/generators.pyx

index 8fedd8dbf575707bb012343663604c5bff9e1e55..582e34d22802c485c64f8f9f5b47a01209fc4fd2 100755 (executable)
@@ -4990,8 +4990,6 @@ class YieldExprNode(ExprNode):
             save_cname = self.temp_allocator.allocate_temp(type)
             saved.append((cname, save_cname, type))
             code.putln('%s->%s = %s;' % (Naming.cur_scope_cname, save_cname, cname))
-            if type.is_pyobject:
-                code.put_giveref(cname)
         self.label_name = code.new_label('resume_from_yield')
         code.use_label(self.label_name)
         self.allocate_temp_result(code)
@@ -5008,6 +5006,9 @@ class YieldExprNode(ExprNode):
         else:
             code.put_init_to_py_none(Naming.retval_cname, py_object_type)
 
+        # XXX: safe here as all used temps are handled but not clean
+        self.temp_allocator.put_giveref(code)
+        code.put_xgiveref(Naming.retval_cname)
         code.put_finish_refcount_context()
         code.putln("/* return from function, yielding value */")
         code.putln("%s->%s.resume_label = %d;" % (Naming.cur_scope_cname, Naming.obj_base_cname, self.label_num))
@@ -5018,18 +5019,21 @@ class YieldExprNode(ExprNode):
             code.putln('%s = %s->%s;' % (cname, Naming.cur_scope_cname, save_cname))
             if type.is_pyobject:
                 code.putln('%s->%s = 0;' % (Naming.cur_scope_cname, save_cname))
-                code.put_gotref(cname)
         code.putln('%s = __pyx_send_value;' % self.result())
         code.put_incref(self.result(), py_object_type)
 
-class StopIterationNode(YieldExprNode):
-    subexprs = []
+class StopIterationNode(Node):
+    # XXX: is it okay?
+    child_attrs = []
 
-    def generate_evaluation_code(self, code):
-        self.allocate_temp_result(code)
-        self.label_name = code.new_label('resume_from_yield')
-        code.use_label(self.label_name)
-        code.put_label(self.label_name)
+    def analyse_expressions(self, env):
+        pass
+
+    def generate_function_definitions(self, env, code):
+        pass
+
+    def generate_execution_code(self, code):
+        code.putln('/* Stop iteration */')
         code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
 
 #-------------------------------------------------------------------
@@ -8324,6 +8328,11 @@ static CYTHON_INLINE PyObject *__CyGenerator_SendEx(struct __CyGenerator *self,
         }
     }
 
+    if (self->resume_label == -1) {
+        PyErr_SetNone(PyExc_StopIteration);
+        return NULL;
+    }
+
     self->is_running = 1;
     retval = self->body((PyObject *) self, value, is_exc);
     self->is_running = 0;
index 0bc60d58d0c7d538dd6deb91056c6d8c6e7714ae..2abe6126604d05f5a220cb4691af75e4fbb2be3b 100644 (file)
@@ -1321,6 +1321,8 @@ class FuncDefNode(StatNode, BlockNode):
                 Naming.cur_scope_cname,
                 lenv.scope_class.type.declaration_code(''),
                 Naming.self_cname))
+            gotref_code = code.insertion_point()
+
         elif self.needs_closure:
             code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
                 Naming.cur_scope_cname,
@@ -1480,7 +1482,9 @@ class FuncDefNode(StatNode, BlockNode):
                     code.put_var_decref(entry)
         if self.needs_closure and not self.is_generator:
             code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
-                
+        if self.is_generator:
+            code.putln('%s->%s.resume_label = -1;' % (Naming.cur_scope_cname, Naming.obj_base_cname))
+
         # ----- Return
         # This code is duplicated in ModuleNode.generate_module_init_func
         if not lenv.nogil:
@@ -1535,6 +1539,8 @@ class FuncDefNode(StatNode, BlockNode):
         self.generate_wrapper_functions(code)
 
         if self.is_generator:
+            gotref_code.putln('/* Make refnanny happy */')
+            self.temp_allocator.put_gotref(gotref_code)
             self.generator.generate_function_body(self.local_scope, code)
 
     def declare_argument(self, env, arg):
@@ -1920,18 +1926,22 @@ class GeneratorWrapperNode(object):
         code.put_gotref(Naming.cur_scope_cname)
 
         if self.def_node.needs_outer_scope:
-            code.putln("%s->%s = (%s)%s;" % (
-                            Naming.cur_scope_cname,
-                            Naming.outer_scope_cname,
+            outer_scope_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.outer_scope_cname)
+            code.putln("%s = (%s)%s;" % (
+                            outer_scope_cname,
                             cenv.scope_class.type.declaration_code(''),
                             Naming.self_cname))
+            code.put_incref(outer_scope_cname, cenv.scope_class.type)
 
         self.def_node.generate_argument_parsing_code(env, code)
 
         generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname)
 
         code.putln('%s.resume_label = 0;' % generator_cname)
-        code.putln('%s.body = (void *) %s;' % (generator_cname, self.body_cname))
+        code.putln('%s.body = %s;' % (generator_cname, self.body_cname))
+        for entry in lenv.scope_class.type.scope.entries.values():
+            if entry.type.is_pyobject:
+                code.put_xgiveref('%s->%s' % (Naming.cur_scope_cname, entry.cname))
         code.put_giveref(Naming.cur_scope_cname)
         code.put_finish_refcount_context()
         code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname);
index e71edbf129b2e8cfa0eadaa828dba63436861b10..dd746a0fb9c7396342b2820011cd026ff59fa20e 100644 (file)
@@ -1343,21 +1343,35 @@ class ClosureTempAllocator(object):
         self.temps_count = 0
 
     def reset(self):
-        for type, cnames in self.temps_allocated:
+        for type, cnames in self.temps_allocated.items():
             self.temps_free[type] = list(cnames)
 
     def allocate_temp(self, type):
         if not type in self.temps_allocated:
             self.temps_allocated[type] = []
             self.temps_free[type] = []
-        if self.temps_free[type]:
+        elif self.temps_free[type]:
             return self.temps_free[type].pop(0)
-        cname = '%s_%d' % (Naming.codewriter_temp_prefix, self.temps_count)
+        cname = '%s%d' % (Naming.codewriter_temp_prefix, self.temps_count)
         self.klass.declare_var(pos=None, name=cname, cname=cname, type=type, is_cdef=True)
         self.temps_allocated[type].append(cname)
         self.temps_count += 1
         return cname
 
+    def put_gotref(self, code):
+        for entry in self.klass.entries.values():
+            if entry.cname == Naming.outer_scope_cname: # XXX
+                continue
+            if entry.type.is_pyobject:
+                code.put_xgotref('%s->%s' % (Naming.cur_scope_cname, entry.cname))
+
+    def put_giveref(self, code):
+        for entry in self.klass.entries.values():
+            if entry.cname == Naming.outer_scope_cname: # XXX
+                continue
+            if entry.type.is_pyobject:
+                code.put_xgiveref('%s->%s' % (Naming.cur_scope_cname, entry.cname))
+
 class YieldCollector(object):
     def __init__(self, node):
         self.node = node
@@ -1394,12 +1408,11 @@ class MarkGeneratorVisitor(CythonTransform):
         elif collector.yields:
             allocator = ClosureTempAllocator()
             stop_node = ExprNodes.StopIterationNode(node.pos, arg=None)
-            collector.yields.append(stop_node)
-            for y in collector.yields: # XXX: find a better way
+            # XXX: move allocator inside local scope
+            for y in collector.yields:
                 y.temp_allocator = allocator
             node.temp_allocator = allocator
-            stop_node.label_num = len(collector.yields)
-            node.body.stats.append(Nodes.ExprStatNode(node.pos, expr=stop_node))
+            node.body.stats.append(stop_node)
             node.is_generator = True
             node.needs_closure = True
             node.yields = collector.yields
index 5cdc40f0feed12fffbefa123e8673610c8396714..b6176f007b45f9114c1877ab84c9376a28dc6380 100644 (file)
@@ -1,3 +1,17 @@
+def very_simple():
+    """
+    >>> x = very_simple()
+    >>> next(x)
+    1
+    >>> next(x)
+    Traceback (most recent call last):
+    StopIteration
+    >>> next(x)
+    Traceback (most recent call last):
+    StopIteration
+    """
+    yield 1
+
 def simple():
     """
     >>> x = simple()
@@ -32,6 +46,18 @@ def simple_send():
     while True:
         i = yield i
 
+def raising():
+    """
+    >>> x = raising()
+    >>> next(x)
+    Traceback (most recent call last):
+    KeyError: 'foo'
+    >>> next(x)
+    Traceback (most recent call last):
+    StopIteration
+    """
+    yield {}['foo']
+
 def with_outer(*args):
     """
     >>> x = with_outer(1, 2, 3)
@@ -42,3 +68,15 @@ def with_outer(*args):
         for i in args:
             yield i
     return generator
+
+def with_outer_raising(*args):
+    """
+    >>> x = with_outer_raising(1, 2, 3)
+    >>> list(x())
+    [1, 2, 3]
+    """
+    def generator():
+        for i in args:
+            yield i
+        raise StopIteration
+    return generator