Merge remote branch 'upstream/master'
[cython.git] / Cython / Compiler / Nodes.py
index ff130a487d178785f7fd4e0acf850434eb58ae1f..c09a7d47b9c3fd316d7224407bd6c201cc847c67 100644 (file)
@@ -23,7 +23,7 @@ from PyrexTypes import py_object_type, error_type, CFuncType
 from Symtab import ModuleScope, LocalScope, ClosureScope, \
     StructOrUnionScope, PyClassScope, CClassScope, CppClassScope
 from Cython.Utils import open_new_file, replace_suffix
-from Code import UtilityCode
+from Code import UtilityCode, ClosureTempAllocator
 from StringEncoding import EncodedString, escape_byte_string, split_string_literal
 import Options
 import ControlFlow
@@ -1166,6 +1166,7 @@ class FuncDefNode(StatNode, BlockNode):
     assmt = None
     needs_closure = False
     needs_outer_scope = False
+    is_generator = False
     modifiers = []
 
     def analyse_default_values(self, env):
@@ -1295,7 +1296,9 @@ class FuncDefNode(StatNode, BlockNode):
                     (self.return_type.declaration_code(Naming.retval_cname),
                      init))
         tempvardecl_code = code.insertion_point()
-        self.generate_keyword_list(code)
+        code.put_declare_refcount_context()
+        if not self.is_generator:
+            self.generate_keyword_list(code)
         if profile:
             code.put_trace_declarations()
         # ----- Extern library function declarations
@@ -1314,7 +1317,14 @@ class FuncDefNode(StatNode, BlockNode):
         if is_getbuffer_slot:
             self.getbuffer_init(code)
         # ----- Create closure scope object
-        if self.needs_closure:
+        if self.is_generator:
+            code.putln("%s = (%s) %s;" % (
+                Naming.cur_scope_cname,
+                lenv.scope_class.type.declaration_code(''),
+                Naming.self_cname))
+            gotref_code = code.insertion_point()
+
+        elif self.needs_closure:
             code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
                 Naming.cur_scope_cname,
                 lenv.scope_class.type.declaration_code(''),
@@ -1331,7 +1341,7 @@ class FuncDefNode(StatNode, BlockNode):
             code.putln("}")
             code.put_gotref(Naming.cur_scope_cname)
             # Note that it is unsafe to decref the scope at this point.
-        if self.needs_outer_scope:
+        if self.needs_outer_scope and not self.is_generator:
             code.putln("%s = (%s)%s;" % (
                             outer_scope_cname,
                             cenv.scope_class.type.declaration_code(''),
@@ -1348,34 +1358,24 @@ class FuncDefNode(StatNode, BlockNode):
             # fatal error before hand, it's not really worth tracing
             code.put_trace_call(self.entry.name, self.pos)
         # ----- Fetch arguments
-        self.generate_argument_parsing_code(env, code)
-        # If an argument is assigned to in the body, we must
-        # incref it to properly keep track of refcounts.
-        for entry in lenv.arg_entries:
-            if entry.type.is_pyobject:
-                if entry.assignments and not entry.in_closure:
-                    code.put_var_incref(entry)
-        # ----- Initialise local variables
-        for entry in lenv.var_entries:
-            if entry.type.is_pyobject and entry.init_to_none and entry.used:
-                code.put_init_var_to_py_none(entry)
-        # ----- Initialise local buffer auxiliary variables
-        for entry in lenv.var_entries + lenv.arg_entries:
-            if entry.type.is_buffer and entry.buffer_aux.buffer_info_var.used:
-                code.putln("%s.buf = NULL;" %
-                           entry.buffer_aux.buffer_info_var.cname)
-        # ----- Check and convert arguments
-        self.generate_argument_type_tests(code)
-        # ----- Acquire buffer arguments
-        for entry in lenv.arg_entries:
-            if entry.type.is_buffer:
-                Buffer.put_acquire_arg_buffer(entry, code, self.pos)
-
+        if not self.is_generator:
+            self.generate_preamble(env, code)
+        if self.is_generator:
+            code.funcstate.init_closure_temps(lenv.scope_class.type.scope)
+            resume_code = code.insertion_point()
+            first_run_label = code.new_label('first_run')
+            code.use_label(first_run_label)
+            code.put_label(first_run_label)
+            code.putln('%s' %
+                       (code.error_goto_if_null(Naming.sent_value_cname, self.pos)))
         # -------------------------
         # ----- Function body -----
         # -------------------------
         self.body.generate_execution_code(code)
 
+        if self.is_generator:
+            code.putln('PyErr_SetNone(PyExc_StopIteration); %s' % code.error_goto(self.pos))
+
         # ----- Default return value
         code.putln("")
         if self.return_type.is_pyobject:
@@ -1456,17 +1456,15 @@ class FuncDefNode(StatNode, BlockNode):
             if entry.type.is_pyobject:
                 if entry.used and not entry.in_closure:
                     code.put_var_decref(entry)
-                elif entry.in_closure and self.needs_closure:
-                    code.put_giveref(entry.cname)
         # Decref any increfed args
         for entry in lenv.arg_entries:
             if entry.type.is_pyobject:
-                if entry.in_closure:
-                    code.put_var_giveref(entry)
-                elif entry.assignments:
+                if entry.assignments and not entry.in_closure:
                     code.put_var_decref(entry)
-        if self.needs_closure:
+        if self.needs_closure and not self.is_generator:
             code.put_decref(Naming.cur_scope_cname, lenv.scope_class.type)
+        if self.is_generator:
+            code.putln('%s->%s.resume_label = -1;' % (Naming.cur_scope_cname, Naming.obj_base_cname))
 
         # ----- Return
         # This code is duplicated in ModuleNode.generate_module_init_func
@@ -1504,15 +1502,55 @@ class FuncDefNode(StatNode, BlockNode):
 
         if preprocessor_guard:
             code.putln("#endif /*!(%s)*/" % preprocessor_guard)
-
         # ----- Go back and insert temp variable declarations
         tempvardecl_code.put_temp_declarations(code.funcstate)
+        # ----- Generator resume code
+        if self.is_generator:
+            resume_code.putln("switch (%s->%s.resume_label) {" % (Naming.cur_scope_cname, Naming.obj_base_cname));
+            resume_code.putln("case 0: goto %s;" % first_run_label)
+            for yield_expr in self.yields:
+                resume_code.putln("case %d: goto %s;" % (yield_expr.label_num, yield_expr.label_name));
+            resume_code.putln("default: /* CPython raises the right error here */");
+            resume_code.putln("return NULL;");
+            resume_code.putln("}");
         # ----- Python version
         code.exit_cfunc_scope()
         if self.py_func:
             self.py_func.generate_function_definitions(env, code)
         self.generate_wrapper_functions(code)
 
+        if self.is_generator:
+            self.generator.generate_function_body(self.local_scope, code)
+
+    def generate_preamble(self, env, code):
+        """Parse arguments and prepare scope"""
+        import Buffer
+
+        lenv = self.local_scope
+
+        self.generate_argument_parsing_code(env, code)
+        # If an argument is assigned to in the body, we must
+        # incref it to properly keep track of refcounts.
+        for entry in lenv.arg_entries:
+            if entry.type.is_pyobject:
+                if entry.assignments and not entry.in_closure:
+                    code.put_var_incref(entry)
+        # ----- Initialise local variables
+        for entry in lenv.var_entries:
+            if entry.type.is_pyobject and entry.init_to_none and entry.used:
+                code.put_init_var_to_py_none(entry)
+        # ----- Initialise local buffer auxiliary variables
+        for entry in lenv.var_entries + lenv.arg_entries:
+            if entry.type.is_buffer and entry.buffer_aux.buffer_info_var.used:
+                code.putln("%s.buf = NULL;" %
+                           entry.buffer_aux.buffer_info_var.cname)
+        # ----- Check and convert arguments
+        self.generate_argument_type_tests(code)
+        # ----- Acquire buffer arguments
+        for entry in lenv.arg_entries:
+            if entry.type.is_buffer:
+                Buffer.put_acquire_arg_buffer(entry, code, self.pos)
+
     def declare_argument(self, env, arg):
         if arg.type.is_void:
             error(arg.pos, "Invalid use of 'void'")
@@ -1863,6 +1901,61 @@ class DecoratorNode(Node):
     child_attrs = ['decorator']
 
 
+class GeneratorWrapperNode(object):
+    # Wrapper
+    def __init__(self, def_node, func_cname=None, body_cname=None, header=None):
+        self.def_node = def_node
+        self.func_cname = func_cname
+        self.body_cname = body_cname
+        self.header = header
+
+    def generate_function_body(self, env, code):
+        code.mark_pos(self.def_node.pos)
+        cenv = env.outer_scope # XXX: correct?
+        while cenv.is_py_class_scope or cenv.is_c_class_scope:
+            cenv = cenv.outer_scope
+        lenv = self.def_node.local_scope
+        code.enter_cfunc_scope()
+        code.putln()
+        code.putln('%s {' % self.header)
+        code.put_declare_refcount_context()
+        self.def_node.generate_keyword_list(code)
+        code.put(lenv.scope_class.type.declaration_code(Naming.cur_scope_cname))
+        code.putln(";")
+        code.put_setup_refcount_context(self.def_node.entry.name)
+        code.putln("%s = (%s)%s->tp_new(%s, %s, NULL);" % (
+            Naming.cur_scope_cname,
+            lenv.scope_class.type.declaration_code(''),
+            lenv.scope_class.type.typeptr_cname,
+            lenv.scope_class.type.typeptr_cname,
+            Naming.empty_tuple))
+        code.putln("if (unlikely(!%s)) {" % Naming.cur_scope_cname)
+        code.put_finish_refcount_context()
+        code.putln("return NULL;");
+        code.putln("}");
+        code.put_gotref(Naming.cur_scope_cname)
+
+        if self.def_node.needs_outer_scope:
+            outer_scope_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.outer_scope_cname)
+            code.putln("%s = (%s)%s;" % (
+                            outer_scope_cname,
+                            cenv.scope_class.type.declaration_code(''),
+                            Naming.self_cname))
+            code.put_incref(outer_scope_cname, cenv.scope_class.type)
+            code.put_giveref(outer_scope_cname)
+
+        self.def_node.generate_preamble(env, code)
+
+        generator_cname = '%s->%s' % (Naming.cur_scope_cname, Naming.obj_base_cname)
+
+        code.putln('%s.resume_label = 0;' % generator_cname)
+        code.putln('%s.body = %s;' % (generator_cname, self.body_cname))
+        code.put_giveref(Naming.cur_scope_cname)
+        code.put_finish_refcount_context()
+        code.putln("return (PyObject *) %s;" % Naming.cur_scope_cname);
+        code.putln('}\n')
+        code.exit_cfunc_scope()
+
 class DefNode(FuncDefNode):
     # A Python function definition.
     #
@@ -2156,6 +2249,10 @@ class DefNode(FuncDefNode):
             Naming.pyfunc_prefix + prefix + name
         entry.pymethdef_cname = \
             Naming.pymethdef_prefix + prefix + name
+
+        if self.is_generator:
+            self.generator_body_cname = Naming.genbody_prefix + env.next_id(env.scope_prefix) + name
+
         if Options.docstrings:
             entry.doc = embed_position(self.pos, self.doc)
             entry.doc_cname = \
@@ -2303,7 +2400,15 @@ class DefNode(FuncDefNode):
                 "static PyMethodDef %s = " %
                     self.entry.pymethdef_cname)
             code.put_pymethoddef(self.entry, ";", allow_skip=False)
-        code.putln("%s {" % header)
+        if self.is_generator:
+            code.putln("static PyObject *%s(PyObject *%s, PyObject *%s) /* generator body */\n{" %
+                       (self.generator_body_cname, Naming.self_cname, Naming.sent_value_cname))
+            self.generator = GeneratorWrapperNode(self,
+                                                  func_cname=self.entry.func_cname,
+                                                  body_cname=self.generator_body_cname,
+                                                  header=header)
+        else:
+            code.putln("%s {" % header)
 
     def generate_argument_declarations(self, env, code):
         for arg in self.args:
@@ -2326,8 +2431,8 @@ class DefNode(FuncDefNode):
             code.putln("0};")
 
     def generate_argument_parsing_code(self, env, code):
-        # Generate PyArg_ParseTuple call for generic
-        # arguments, if any.
+        # Generate fast equivalent of PyArg_ParseTuple call for
+        # generic arguments, if any, including args/kwargs
         if self.entry.signature.has_dummy_arg and not self.self_in_stararg:
             # get rid of unused argument warning
             code.putln("%s = %s;" % (Naming.self_cname, Naming.self_cname))
@@ -2383,9 +2488,9 @@ class DefNode(FuncDefNode):
                 self.generate_arg_decref(self.star_arg, code)
                 if self.starstar_arg:
                     if self.starstar_arg.entry.xdecref_cleanup:
-                        code.put_var_xdecref(self.starstar_arg.entry)
+                        code.put_var_xdecref_clear(self.starstar_arg.entry)
                     else:
-                        code.put_var_decref(self.starstar_arg.entry)
+                        code.put_var_decref_clear(self.starstar_arg.entry)
             code.putln('__Pyx_AddTraceback("%s");' % self.entry.qualified_name)
             # The arguments are put into the closure one after the
             # other, so when type errors are found, all references in
@@ -2404,14 +2509,24 @@ class DefNode(FuncDefNode):
         if code.label_used(end_label):
             code.put_label(end_label)
 
+        # fix refnanny view on closure variables here, instead of
+        # doing it separately for each arg parsing special case
+        if self.star_arg and self.star_arg.entry.in_closure:
+            code.put_var_giveref(self.star_arg.entry)
+        if self.starstar_arg and self.starstar_arg.entry.in_closure:
+            code.put_var_giveref(self.starstar_arg.entry)
+        for arg in self.args:
+            if arg.type.is_pyobject and arg.entry.in_closure:
+                code.put_var_giveref(arg.entry)
+
     def generate_arg_assignment(self, arg, item, code):
         if arg.type.is_pyobject:
             if arg.is_generic:
                 item = PyrexTypes.typecast(arg.type, PyrexTypes.py_object_type, item)
             entry = arg.entry
-            code.putln("%s = %s;" % (entry.cname, item))
             if entry.in_closure:
-                code.put_var_incref(entry)
+                code.put_incref(item, PyrexTypes.py_object_type)
+            code.putln("%s = %s;" % (entry.cname, item))
         else:
             func = arg.type.from_py_function
             if func:
@@ -2425,11 +2540,11 @@ class DefNode(FuncDefNode):
 
     def generate_arg_xdecref(self, arg, code):
         if arg:
-            code.put_var_xdecref(arg.entry)
+            code.put_var_xdecref_clear(arg.entry)
 
     def generate_arg_decref(self, arg, code):
         if arg:
-            code.put_var_decref(arg.entry)
+            code.put_var_decref_clear(arg.entry)
 
     def generate_stararg_copy_code(self, code):
         if not self.star_arg:
@@ -2632,19 +2747,18 @@ class DefNode(FuncDefNode):
             code.putln('if (PyTuple_GET_SIZE(%s) > %d) {' % (
                     Naming.args_cname,
                     max_positional_args))
-            code.put('%s = PyTuple_GetSlice(%s, %d, PyTuple_GET_SIZE(%s)); ' % (
+            code.putln('%s = PyTuple_GetSlice(%s, %d, PyTuple_GET_SIZE(%s));' % (
                     self.star_arg.entry.cname, Naming.args_cname,
                     max_positional_args, Naming.args_cname))
-            code.put_gotref(self.star_arg.entry.cname)
+            code.putln("if (unlikely(!%s)) {" % self.star_arg.entry.cname)
             if self.starstar_arg:
-                code.putln("")
-                code.putln("if (unlikely(!%s)) {" % self.star_arg.entry.cname)
                 code.put_decref(self.starstar_arg.entry.cname, py_object_type)
-                code.putln('return %s;' % self.error_value())
-                code.putln('}')
-            else:
-                code.putln("if (unlikely(!%s)) return %s;" % (
-                        self.star_arg.entry.cname, self.error_value()))
+            if self.needs_closure:
+                code.put_decref(Naming.cur_scope_cname, self.local_scope.scope_class.type)
+            code.put_finish_refcount_context()
+            code.putln('return %s;' % self.error_value())
+            code.putln('}')
+            code.put_gotref(self.star_arg.entry.cname)
             code.putln('} else {')
             code.put("%s = %s; " % (self.star_arg.entry.cname, Naming.empty_tuple))
             code.put_incref(Naming.empty_tuple, py_object_type)
@@ -2813,9 +2927,9 @@ class DefNode(FuncDefNode):
             if arg.needs_conversion:
                 self.generate_arg_conversion(arg, code)
             elif arg.entry.in_closure:
-                code.putln('%s = %s;' % (arg.entry.cname, arg.hdr_cname))
                 if arg.type.is_pyobject:
-                    code.put_var_incref(arg.entry)
+                    code.put_incref(arg.hdr_cname, py_object_type)
+                code.putln('%s = %s;' % (arg.entry.cname, arg.hdr_cname))
 
     def generate_arg_conversion(self, arg, code):
         # Generate conversion code for one argument.
@@ -3320,6 +3434,24 @@ class GlobalNode(StatNode):
         pass
 
 
+class NonlocalNode(StatNode):
+    # Nonlocal variable declaration via the 'nonlocal' keyword.
+    #
+    # names    [string]
+
+    child_attrs = []
+
+    def analyse_declarations(self, env):
+        for name in self.names:
+            env.declare_nonlocal(name, self.pos)
+
+    def analyse_expressions(self, env):
+        pass
+
+    def generate_execution_code(self, code):
+        pass
+
+
 class ExprStatNode(StatNode):
     #  Expression used as a statement.
     #
@@ -3344,6 +3476,7 @@ class ExprStatNode(StatNode):
                 self.__class__ = PassStatNode
 
     def analyse_expressions(self, env):
+        self.expr.result_is_used = False # hint that .result() may safely be left empty
         self.expr.analyse_expressions(env)
 
     def generate_execution_code(self, code):
@@ -4647,12 +4780,12 @@ class TryExceptStatNode(StatNode):
         try_continue_label = code.new_label('try_continue')
         try_end_label = code.new_label('try_end')
 
+        exc_save_vars = [code.funcstate.allocate_temp(py_object_type, False)
+                         for i in xrange(3)]
         code.putln("{")
-        code.putln("PyObject %s;" %
-                   ', '.join(['*%s' % var for var in Naming.exc_save_vars]))
         code.putln("__Pyx_ExceptionSave(%s);" %
-                   ', '.join(['&%s' % var for var in Naming.exc_save_vars]))
-        for var in Naming.exc_save_vars:
+                   ', '.join(['&%s' % var for var in exc_save_vars]))
+        for var in exc_save_vars:
             code.put_xgotref(var)
         code.putln(
             "/*try:*/ {")
@@ -4671,14 +4804,14 @@ class TryExceptStatNode(StatNode):
             self.else_clause.generate_execution_code(code)
             code.putln(
                 "}")
-        for var in Naming.exc_save_vars:
+        for var in exc_save_vars:
             code.put_xdecref_clear(var, py_object_type)
         code.put_goto(try_end_label)
         if code.label_used(try_return_label):
             code.put_label(try_return_label)
-            for var in Naming.exc_save_vars: code.put_xgiveref(var)
+            for var in exc_save_vars: code.put_xgiveref(var)
             code.putln("__Pyx_ExceptionReset(%s);" %
-                       ', '.join(Naming.exc_save_vars))
+                       ', '.join(exc_save_vars))
             code.put_goto(old_return_label)
         code.put_label(our_error_label)
         for temp_name, type in temps_to_clean_up:
@@ -4690,9 +4823,9 @@ class TryExceptStatNode(StatNode):
         if error_label_used or not self.has_default_clause:
             if error_label_used:
                 code.put_label(except_error_label)
-            for var in Naming.exc_save_vars: code.put_xgiveref(var)
+            for var in exc_save_vars: code.put_xgiveref(var)
             code.putln("__Pyx_ExceptionReset(%s);" %
-                       ', '.join(Naming.exc_save_vars))
+                       ', '.join(exc_save_vars))
             code.put_goto(old_error_label)
 
         for exit_label, old_label in zip(
@@ -4701,19 +4834,22 @@ class TryExceptStatNode(StatNode):
 
             if code.label_used(exit_label):
                 code.put_label(exit_label)
-                for var in Naming.exc_save_vars: code.put_xgiveref(var)
+                for var in exc_save_vars: code.put_xgiveref(var)
                 code.putln("__Pyx_ExceptionReset(%s);" %
-                           ', '.join(Naming.exc_save_vars))
+                           ', '.join(exc_save_vars))
                 code.put_goto(old_label)
 
         if code.label_used(except_end_label):
             code.put_label(except_end_label)
-            for var in Naming.exc_save_vars: code.put_xgiveref(var)
+            for var in exc_save_vars: code.put_xgiveref(var)
             code.putln("__Pyx_ExceptionReset(%s);" %
-                       ', '.join(Naming.exc_save_vars))
+                       ', '.join(exc_save_vars))
         code.put_label(try_end_label)
         code.putln("}")
 
+        for cname in exc_save_vars:
+            code.funcstate.release_temp(cname)
+
         code.return_label = old_return_label
         code.break_label = old_break_label
         code.continue_label = old_continue_label