Merged pull request #12 from bhy/T423.
[cython.git] / Cython / Compiler / Nodes.py
index 51ef88f3f0be72446ed67045d2afc8d274834a52..4405fefa46fc911c2d57d0aeb6096787ab1f4a4a 100644 (file)
@@ -4154,8 +4154,9 @@ class RaiseStatNode(StatNode):
     #  exc_type    ExprNode or None
     #  exc_value   ExprNode or None
     #  exc_tb      ExprNode or None
+    #  cause       ExprNode or None
 
-    child_attrs = ["exc_type", "exc_value", "exc_tb"]
+    child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
 
     def analyse_expressions(self, env):
         if self.exc_type:
@@ -4167,6 +4168,9 @@ class RaiseStatNode(StatNode):
         if self.exc_tb:
             self.exc_tb.analyse_types(env)
             self.exc_tb = self.exc_tb.coerce_to_pyobject(env)
+        if self.cause:
+            self.cause.analyse_types(env)
+            self.cause = self.cause.coerce_to_pyobject(env)
         # special cases for builtin exceptions
         self.builtin_exc_name = None
         if self.exc_type and not self.exc_value and not self.exc_tb:
@@ -4204,13 +4208,19 @@ class RaiseStatNode(StatNode):
             tb_code = self.exc_tb.py_result()
         else:
             tb_code = "0"
+        if self.cause:
+            self.cause.generate_evaluation_code(code)
+            cause_code = self.cause.py_result()
+        else:
+            cause_code = "0"
         code.globalstate.use_utility_code(raise_utility_code)
         code.putln(
-            "__Pyx_Raise(%s, %s, %s);" % (
+            "__Pyx_Raise(%s, %s, %s, %s);" % (
                 type_code,
                 value_code,
-                tb_code))
-        for obj in (self.exc_type, self.exc_value, self.exc_tb):
+                tb_code,
+                cause_code))
+        for obj in (self.exc_type, self.exc_value, self.exc_tb, self.cause):
             if obj:
                 obj.generate_disposal_code(code)
                 obj.free_temps(code)
@@ -4224,6 +4234,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.generate_function_definitions(env, code)
         if self.exc_tb is not None:
             self.exc_tb.generate_function_definitions(env, code)
+        if self.cause is not None:
+            self.cause.generate_function_definitions(env, code)
 
     def annotate(self, code):
         if self.exc_type:
@@ -4232,6 +4244,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.annotate(code)
         if self.exc_tb:
             self.exc_tb.annotate(code)
+        if self.cause:
+            self.cause.annotate(code)
 
 
 class ReraiseStatNode(StatNode):
@@ -6003,11 +6017,12 @@ static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyOb
 
 raise_utility_code = UtilityCode(
 proto = """
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb); /*proto*/
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); /*proto*/
 """,
 impl = """
 #if PY_MAJOR_VERSION < 3
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
+    /* cause is unused */
     Py_XINCREF(type);
     Py_XINCREF(value);
     Py_XINCREF(tb);
@@ -6074,7 +6089,7 @@ raise_error:
 
 #else /* Python 3+ */
 
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
     if (tb == Py_None) {
         tb = 0;
     } else if (tb && !PyTraceBack_Check(tb)) {
@@ -6099,6 +6114,29 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
         goto bad;
     }
 
+    if (cause) {
+        PyObject *fixed_cause;
+        if (PyExceptionClass_Check(cause)) {
+            fixed_cause = PyObject_CallObject(cause, NULL);
+            if (fixed_cause == NULL)
+                goto bad;
+        }
+        else if (PyExceptionInstance_Check(cause)) {
+            fixed_cause = cause;
+            Py_INCREF(fixed_cause);
+        }
+        else {
+            PyErr_SetString(PyExc_TypeError,
+                            "exception causes must derive from "
+                            "BaseException");
+            goto bad;
+        }
+        if (!value) {
+            value = PyObject_CallObject(type, NULL);
+        }
+        PyException_SetCause(value, fixed_cause);
+    }
+
     PyErr_SetObject(type, value);
 
     if (tb) {