Merged pull request #12 from bhy/T423.
[cython.git] / Cython / Compiler / Nodes.py
index a8ed64e67f119c4df0b99c675a079633662f8bcb..4405fefa46fc911c2d57d0aeb6096787ab1f4a4a 100644 (file)
@@ -548,7 +548,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
                 other_type = type_node.analyse_as_type(env)
                 if other_type is None:
                     error(type_node.pos, "Not a type")
-                elif (type is not PyrexTypes.py_object_type 
+                elif (type is not PyrexTypes.py_object_type
                       and not type.same_as(other_type)):
                     error(self.base.pos, "Signature does not agree with previous declaration")
                     error(type_node.pos, "Previous declaration here")
@@ -1099,7 +1099,7 @@ class CEnumDefNode(StatNode):
     #  api            boolean
     #  in_pxd         boolean
     #  entry          Entry
-    
+
     child_attrs = ["items"]
 
     def analyse_declarations(self, env):
@@ -1146,7 +1146,7 @@ class CEnumDefItemNode(StatNode):
             if not self.value.type.is_int:
                 self.value = self.value.coerce_to(PyrexTypes.c_int_type, env)
                 self.value.analyse_const_expression(env)
-        entry = env.declare_const(self.name, enum_entry.type, 
+        entry = env.declare_const(self.name, enum_entry.type,
             self.value, self.pos, cname = self.cname,
             visibility = enum_entry.visibility, api = enum_entry.api)
         enum_entry.enum_values.append(entry)
@@ -1170,7 +1170,7 @@ class CTypeDefNode(StatNode):
             cname = cname, visibility = self.visibility, api = self.api)
         if self.in_pxd and not env.in_cinclude:
             entry.defined_in_pxd = 1
-    
+
     def analyse_expressions(self, env):
         pass
     def generate_execution_code(self, code):
@@ -1221,7 +1221,7 @@ class FuncDefNode(StatNode, BlockNode):
             other_type = type_node.analyse_as_type(env)
             if other_type is None:
                 error(type_node.pos, "Not a type")
-            elif (type is not PyrexTypes.py_object_type 
+            elif (type is not PyrexTypes.py_object_type
                     and not type.same_as(other_type)):
                 error(arg.base_type.pos, "Signature does not agree with previous declaration")
                 error(type_node.pos, "Previous declaration here")
@@ -1623,20 +1623,24 @@ class FuncDefNode(StatNode, BlockNode):
         info = self.local_scope.arg_entries[1].cname
         # Python 3.0 betas have a bug in memoryview which makes it call
         # getbuffer with a NULL parameter. For now we work around this;
-        # the following line should be removed when this bug is fixed.
-        code.putln("if (%s == NULL) return 0;" % info)
+        # the following block should be removed when this bug is fixed.
+        code.putln("if (%s != NULL) {" % info)
         code.putln("%s->obj = Py_None; __Pyx_INCREF(Py_None);" % info)
         code.put_giveref("%s->obj" % info) # Do not refnanny object within structs
+        code.putln("}")
 
     def getbuffer_error_cleanup(self, code):
         info = self.local_scope.arg_entries[1].cname
+        code.putln("if (%s != NULL && %s->obj != NULL) {"
+                   % (info, info))
         code.put_gotref("%s->obj" % info)
-        code.putln("__Pyx_DECREF(%s->obj); %s->obj = NULL;" %
-                   (info, info))
+        code.putln("__Pyx_DECREF(%s->obj); %s->obj = NULL;"
+                   % (info, info))
+        code.putln("}")
 
     def getbuffer_normal_cleanup(self, code):
         info = self.local_scope.arg_entries[1].cname
-        code.putln("if (%s->obj == Py_None) {" % info)
+        code.putln("if (%s != NULL && %s->obj == Py_None) {" % (info, info))
         code.put_gotref("Py_None")
         code.putln("__Pyx_DECREF(Py_None); %s->obj = NULL;" % info)
         code.putln("}")
@@ -1673,7 +1677,12 @@ class CFuncDefNode(FuncDefNode):
         self.directive_locals.update(env.directives['locals'])
         base_type = self.base_type.analyse(env)
         # The 2 here is because we need both function and argument names.
-        name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
+        if isinstance(self.declarator, CFuncDeclaratorNode):
+            name_declarator, type = self.declarator.analyse(base_type, env,
+                                                            nonempty = 2 * (self.body is not None),
+                                                            directive_locals = self.directive_locals)
+        else:
+            name_declarator, type = self.declarator.analyse(base_type, env, nonempty = 2 * (self.body is not None))
         if not type.is_cfunction:
             error(self.pos,
                 "Suite attached to non-function declaration")
@@ -1806,7 +1815,7 @@ class CFuncDefNode(FuncDefNode):
             self.modifiers[self.modifiers.index('inline')] = 'cython_inline'
         if self.modifiers:
             modifiers = "%s " % ' '.join(self.modifiers).upper()
-        
+
         header = self.return_type.declaration_code(entity, dll_linkage=dll_linkage)
         #print (storage_class, modifiers, header)
         code.putln("%s%s%s {" % (storage_class, modifiers, header))
@@ -2023,6 +2032,17 @@ class DefNode(FuncDefNode):
                             api = False,
                             directive_locals = getattr(cfunc, 'directive_locals', {}))
 
+    def is_cdef_func_compatible(self):
+        """Determines if the function's signature is compatible with a
+        cdef function.  This can be used before calling
+        .as_cfunction() to see if that will be successful.
+        """
+        if self.needs_closure:
+            return False
+        if self.star_arg or self.starstar_arg:
+            return False
+        return True
+
     def analyse_declarations(self, env):
         self.is_classmethod = self.is_staticmethod = False
         if self.decorators:
@@ -2212,14 +2232,7 @@ class DefNode(FuncDefNode):
             entry.doc = None
 
     def declare_lambda_function(self, env):
-        name = self.name
-        prefix = env.scope_prefix
-        func_cname = \
-            Naming.lambda_func_prefix + u'funcdef' + prefix + self.lambda_name
-        entry = env.declare_lambda_function(func_cname, self.pos)
-        entry.pymethdef_cname = \
-            Naming.lambda_func_prefix + u'methdef' + prefix + self.lambda_name
-        entry.qualified_name = env.qualify_name(self.lambda_name)
+        entry = env.declare_lambda_function(self.lambda_name, self.pos)
         entry.doc = None
         self.entry = entry
 
@@ -3433,7 +3446,7 @@ class CClassDefNode(ClassDefNode):
             visibility = self.visibility,
             typedef_flag = self.typedef_flag,
             api = self.api,
-            buffer_defaults = buffer_defaults, 
+            buffer_defaults = buffer_defaults,
             shadow = self.shadow)
         if self.shadow:
             home_scope.lookup(self.class_name).as_variable = self.entry
@@ -4141,8 +4154,9 @@ class RaiseStatNode(StatNode):
     #  exc_type    ExprNode or None
     #  exc_value   ExprNode or None
     #  exc_tb      ExprNode or None
+    #  cause       ExprNode or None
 
-    child_attrs = ["exc_type", "exc_value", "exc_tb"]
+    child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
 
     def analyse_expressions(self, env):
         if self.exc_type:
@@ -4154,13 +4168,16 @@ class RaiseStatNode(StatNode):
         if self.exc_tb:
             self.exc_tb.analyse_types(env)
             self.exc_tb = self.exc_tb.coerce_to_pyobject(env)
+        if self.cause:
+            self.cause.analyse_types(env)
+            self.cause = self.cause.coerce_to_pyobject(env)
         # special cases for builtin exceptions
         self.builtin_exc_name = None
         if self.exc_type and not self.exc_value and not self.exc_tb:
             exc = self.exc_type
             import ExprNodes
-            if (isinstance(exc, ExprNodes.SimpleCallNode) and 
-                not (exc.args or (exc.arg_tuple is not None and 
+            if (isinstance(exc, ExprNodes.SimpleCallNode) and
+                not (exc.args or (exc.arg_tuple is not None and
                                   exc.arg_tuple.args))):
                 exc = exc.function # extract the exception type
             if exc.is_name and exc.entry.is_builtin:
@@ -4191,13 +4208,19 @@ class RaiseStatNode(StatNode):
             tb_code = self.exc_tb.py_result()
         else:
             tb_code = "0"
+        if self.cause:
+            self.cause.generate_evaluation_code(code)
+            cause_code = self.cause.py_result()
+        else:
+            cause_code = "0"
         code.globalstate.use_utility_code(raise_utility_code)
         code.putln(
-            "__Pyx_Raise(%s, %s, %s);" % (
+            "__Pyx_Raise(%s, %s, %s, %s);" % (
                 type_code,
                 value_code,
-                tb_code))
-        for obj in (self.exc_type, self.exc_value, self.exc_tb):
+                tb_code,
+                cause_code))
+        for obj in (self.exc_type, self.exc_value, self.exc_tb, self.cause):
             if obj:
                 obj.generate_disposal_code(code)
                 obj.free_temps(code)
@@ -4211,6 +4234,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.generate_function_definitions(env, code)
         if self.exc_tb is not None:
             self.exc_tb.generate_function_definitions(env, code)
+        if self.cause is not None:
+            self.cause.generate_function_definitions(env, code)
 
     def annotate(self, code):
         if self.exc_type:
@@ -4219,6 +4244,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.annotate(code)
         if self.exc_tb:
             self.exc_tb.annotate(code)
+        if self.cause:
+            self.cause.annotate(code)
 
 
 class ReraiseStatNode(StatNode):
@@ -4822,14 +4849,134 @@ class WithStatNode(StatNode):
     """
     Represents a Python with statement.
 
-    This is only used at parse tree level; and is not present in
-    analysis or generation phases.
+    Implemented by the WithTransform as follows:
+
+        MGR = EXPR
+        EXIT = MGR.__exit__
+        VALUE = MGR.__enter__()
+        EXC = True
+        try:
+            try:
+                TARGET = VALUE  # optional
+                BODY
+            except:
+                EXC = False
+                if not EXIT(*EXCINFO):
+                    raise
+        finally:
+            if EXC:
+                EXIT(None, None, None)
+            MGR = EXIT = VALUE = None
     """
     #  manager          The with statement manager object
-    #  target            Node (lhs expression)
+    #  target           ExprNode  the target lhs of the __enter__() call
     #  body             StatNode
+
     child_attrs = ["manager", "target", "body"]
 
+    has_target = False
+
+    def analyse_declarations(self, env):
+        self.manager.analyse_declarations(env)
+        self.body.analyse_declarations(env)
+
+    def analyse_expressions(self, env):
+        self.manager.analyse_types(env)
+        self.body.analyse_expressions(env)
+
+    def generate_function_definitions(self, env, code):
+        self.manager.generate_function_definitions(env, code)
+        self.body.generate_function_definitions(env, code)
+
+    def generate_execution_code(self, code):
+        code.putln("/*with:*/ {")
+        self.manager.generate_evaluation_code(code)
+        self.exit_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False)
+        code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (
+            self.exit_var,
+            self.manager.py_result(),
+            code.get_py_string_const(EncodedString('__exit__'), identifier=True),
+            code.error_goto_if_null(self.exit_var, self.pos),
+            ))
+        code.put_gotref(self.exit_var)
+
+        # need to free exit_var in the face of exceptions during setup
+        old_error_label = code.new_error_label()
+        intermediate_error_label = code.error_label
+
+        enter_func = code.funcstate.allocate_temp(py_object_type, manage_ref=True)
+        code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (
+            enter_func,
+            self.manager.py_result(),
+            code.get_py_string_const(EncodedString('__enter__'), identifier=True),
+            code.error_goto_if_null(enter_func, self.pos),
+            ))
+        code.put_gotref(enter_func)
+        self.manager.generate_disposal_code(code)
+        self.manager.free_temps(code)
+        self.target_temp.allocate(code)
+        code.putln('%s = PyObject_Call(%s, ((PyObject *)%s), NULL); %s' % (
+            self.target_temp.result(),
+            enter_func,
+            Naming.empty_tuple,
+            code.error_goto_if_null(self.target_temp.result(), self.pos),
+            ))
+        code.put_gotref(self.target_temp.result())
+        code.put_decref_clear(enter_func, py_object_type)
+        code.funcstate.release_temp(enter_func)
+        if not self.has_target:
+            code.put_decref_clear(self.target_temp.result(), type=py_object_type)
+            self.target_temp.release(code)
+            # otherwise, WithTargetAssignmentStatNode will do it for us
+
+        code.error_label = old_error_label
+        self.body.generate_execution_code(code)
+
+        step_over_label = code.new_label()
+        code.put_goto(step_over_label)
+        code.put_label(intermediate_error_label)
+        code.put_decref_clear(self.exit_var, py_object_type)
+        code.put_goto(old_error_label)
+        code.put_label(step_over_label)
+
+        code.funcstate.release_temp(self.exit_var)
+        code.putln('}')
+
+class WithTargetAssignmentStatNode(AssignmentNode):
+    # The target assignment of the 'with' statement value (return
+    # value of the __enter__() call).
+    #
+    # This is a special cased assignment that steals the RHS reference
+    # and frees its temp.
+    #
+    # lhs  ExprNode  the assignment target
+    # rhs  TempNode  the return value of the __enter__() call
+
+    child_attrs = ["lhs", "rhs"]
+
+    def analyse_declarations(self, env):
+        self.lhs.analyse_target_declaration(env)
+
+    def analyse_types(self, env):
+        self.rhs.analyse_types(env)
+        self.lhs.analyse_target_types(env)
+        self.lhs.gil_assignment_check(env)
+        self.orig_rhs = self.rhs
+        self.rhs = self.rhs.coerce_to(self.lhs.type, env)
+
+    def generate_execution_code(self, code):
+        self.rhs.generate_evaluation_code(code)
+        self.lhs.generate_assignment_code(self.rhs, code)
+        self.orig_rhs.release(code)
+
+    def generate_function_definitions(self, env, code):
+        self.rhs.generate_function_definitions(env, code)
+
+    def annotate(self, code):
+        self.lhs.annotate(code)
+        self.rhs.annotate(code)
+
+
 class TryExceptStatNode(StatNode):
     #  try .. except statement
     #
@@ -4995,7 +5142,7 @@ class ExceptClauseNode(Node):
     #  pattern        [ExprNode]
     #  target         ExprNode or None
     #  body           StatNode
-    #  excinfo_target NameNode or None   optional target for exception info
+    #  excinfo_target ResultRefNode or None   optional target for exception info
     #  match_flag     string             result of exception match
     #  exc_value      ExcValueNode       used internally
     #  function_name  string             qualified name of enclosing function
@@ -5013,8 +5160,6 @@ class ExceptClauseNode(Node):
     def analyse_declarations(self, env):
         if self.target:
             self.target.analyse_target_declaration(env)
-        if self.excinfo_target is not None:
-            self.excinfo_target.analyse_target_declaration(env)
         self.body.analyse_declarations(env)
 
     def analyse_expressions(self, env):
@@ -5035,7 +5180,6 @@ class ExceptClauseNode(Node):
             self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
                 ExprNodes.ExcValueNode(pos=self.pos, env=env) for x in range(3)])
             self.excinfo_tuple.analyse_expressions(env)
-            self.excinfo_target.analyse_target_expression(env, self.excinfo_tuple)
 
         self.body.analyse_expressions(env)
 
@@ -5090,7 +5234,7 @@ class ExceptClauseNode(Node):
             for tempvar, node in zip(exc_vars, self.excinfo_tuple.args):
                 node.set_var(tempvar)
             self.excinfo_tuple.generate_evaluation_code(code)
-            self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code)
+            self.excinfo_target.result_code = self.excinfo_tuple.result()
 
         old_break_label, old_continue_label = code.break_label, code.continue_label
         code.break_label = code.new_label('except_break')
@@ -5100,24 +5244,32 @@ class ExceptClauseNode(Node):
         code.funcstate.exc_vars = exc_vars
         self.body.generate_execution_code(code)
         code.funcstate.exc_vars = old_exc_vars
+        if self.excinfo_target is not None:
+            self.excinfo_tuple.generate_disposal_code(code)
         for var in exc_vars:
-            code.putln("__Pyx_DECREF(%s); %s = 0;" % (var, var))
+            code.put_decref_clear(var, py_object_type)
         code.put_goto(end_label)
 
         if code.label_used(code.break_label):
             code.put_label(code.break_label)
+            if self.excinfo_target is not None:
+                self.excinfo_tuple.generate_disposal_code(code)
             for var in exc_vars:
-                code.putln("__Pyx_DECREF(%s); %s = 0;" % (var, var))
+                code.put_decref_clear(var, py_object_type)
             code.put_goto(old_break_label)
         code.break_label = old_break_label
 
         if code.label_used(code.continue_label):
             code.put_label(code.continue_label)
+            if self.excinfo_target is not None:
+                self.excinfo_tuple.generate_disposal_code(code)
             for var in exc_vars:
-                code.putln("__Pyx_DECREF(%s); %s = 0;" % (var, var))
+                code.put_decref_clear(var, py_object_type)
             code.put_goto(old_continue_label)
         code.continue_label = old_continue_label
 
+        if self.excinfo_target is not None:
+            self.excinfo_tuple.free_temps(code)
         for temp in exc_vars:
             code.funcstate.release_temp(temp)
 
@@ -5157,6 +5309,9 @@ class TryFinallyStatNode(StatNode):
 
     preserve_exception = 1
 
+    # handle exception case, in addition to return/break/continue
+    handle_error_case = True
+
     disallow_continue_in_try_finally = 0
     # There doesn't seem to be any point in disallowing
     # continue in the try block, since we have no problem
@@ -5190,6 +5345,8 @@ class TryFinallyStatNode(StatNode):
         old_labels = code.all_new_labels()
         new_labels = code.get_all_labels()
         new_error_label = code.error_label
+        if not self.handle_error_case:
+            code.error_label = old_error_label
         catch_label = code.new_label()
         code.putln(
             "/*try:*/ {")
@@ -5860,11 +6017,12 @@ static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyOb
 
 raise_utility_code = UtilityCode(
 proto = """
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb); /*proto*/
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); /*proto*/
 """,
 impl = """
 #if PY_MAJOR_VERSION < 3
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
+    /* cause is unused */
     Py_XINCREF(type);
     Py_XINCREF(value);
     Py_XINCREF(tb);
@@ -5931,7 +6089,7 @@ raise_error:
 
 #else /* Python 3+ */
 
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
     if (tb == Py_None) {
         tb = 0;
     } else if (tb && !PyTraceBack_Check(tb)) {
@@ -5956,6 +6114,29 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
         goto bad;
     }
 
+    if (cause) {
+        PyObject *fixed_cause;
+        if (PyExceptionClass_Check(cause)) {
+            fixed_cause = PyObject_CallObject(cause, NULL);
+            if (fixed_cause == NULL)
+                goto bad;
+        }
+        else if (PyExceptionInstance_Check(cause)) {
+            fixed_cause = cause;
+            Py_INCREF(fixed_cause);
+        }
+        else {
+            PyErr_SetString(PyExc_TypeError,
+                            "exception causes must derive from "
+                            "BaseException");
+            goto bad;
+        }
+        if (!value) {
+            value = PyObject_CallObject(type, NULL);
+        }
+        PyException_SetCause(value, fixed_cause);
+    }
+
     PyErr_SetObject(type, value);
 
     if (tb) {
@@ -6092,6 +6273,31 @@ static void __Pyx_ExceptionReset(PyObject *type, PyObject *value, PyObject *tb)
 
 #------------------------------------------------------------------------------------
 
+swap_exception_utility_code = UtilityCode(
+proto = """
+static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb); /*proto*/
+""",
+impl = """
+static CYTHON_INLINE void __Pyx_ExceptionSwap(PyObject **type, PyObject **value, PyObject **tb) {
+    PyObject *tmp_type, *tmp_value, *tmp_tb;
+    PyThreadState *tstate = PyThreadState_GET();
+
+    tmp_type = tstate->exc_type;
+    tmp_value = tstate->exc_value;
+    tmp_tb = tstate->exc_traceback;
+
+    tstate->exc_type = *type;
+    tstate->exc_value = *value;
+    tstate->exc_traceback = *tb;
+
+    *type = tmp_type;
+    *value = tmp_value;
+    *tb = tmp_tb;
+}
+""")
+
+#------------------------------------------------------------------------------------
+
 arg_type_test_utility_code = UtilityCode(
 proto = """
 static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed,