explicit execption chaining
authorHaoyu Bai <baihaoyu@gmail.com>
Wed, 30 Mar 2011 16:41:22 +0000 (00:41 +0800)
committerHaoyu Bai <baihaoyu@gmail.com>
Wed, 30 Mar 2011 16:41:22 +0000 (00:41 +0800)
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
tests/run/funcexceptraisefrom.pyx [new file with mode: 0644]

index 609489f1758a09bde72f7a7476c64ed54c8d809f..4835f8fe45c1d0dc2e306c5a08840010883d441d 100644 (file)
@@ -3942,8 +3942,9 @@ class RaiseStatNode(StatNode):
     #  exc_type    ExprNode or None
     #  exc_value   ExprNode or None
     #  exc_tb      ExprNode or None
+    #  cause       ExprNode or None
 
-    child_attrs = ["exc_type", "exc_value", "exc_tb"]
+    child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
 
     def analyse_expressions(self, env):
         if self.exc_type:
@@ -3955,6 +3956,9 @@ class RaiseStatNode(StatNode):
         if self.exc_tb:
             self.exc_tb.analyse_types(env)
             self.exc_tb = self.exc_tb.coerce_to_pyobject(env)
+        if self.cause:
+            self.cause.analyse_types(env)
+            self.cause = self.cause.coerce_to_pyobject(env)
         # special cases for builtin exceptions
         self.builtin_exc_name = None
         if self.exc_type and not self.exc_value and not self.exc_tb:
@@ -3990,13 +3994,19 @@ class RaiseStatNode(StatNode):
             tb_code = self.exc_tb.py_result()
         else:
             tb_code = "0"
+        if self.cause:
+            self.cause.generate_evaluation_code(code)
+            cause_code = self.cause.py_result()
+        else:
+            cause_code = "0"
         code.globalstate.use_utility_code(raise_utility_code)
         code.putln(
-            "__Pyx_Raise(%s, %s, %s);" % (
+            "__Pyx_Raise(%s, %s, %s, %s);" % (
                 type_code,
                 value_code,
-                tb_code))
-        for obj in (self.exc_type, self.exc_value, self.exc_tb):
+                tb_code,
+                cause_code))
+        for obj in (self.exc_type, self.exc_value, self.exc_tb, self.cause):
             if obj:
                 obj.generate_disposal_code(code)
                 obj.free_temps(code)
@@ -4010,6 +4020,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.generate_function_definitions(env, code)
         if self.exc_tb is not None:
             self.exc_tb.generate_function_definitions(env, code)
+        if self.cause is not None:
+            self.cause.generate_function_definitions(env, code)
 
     def annotate(self, code):
         if self.exc_type:
@@ -4018,6 +4030,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.annotate(code)
         if self.exc_tb:
             self.exc_tb.annotate(code)
+        if self.cause:
+            self.cause.annotate(code)
 
 
 class ReraiseStatNode(StatNode):
@@ -5652,11 +5666,12 @@ static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyOb
 
 raise_utility_code = UtilityCode(
 proto = """
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb); /*proto*/
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); /*proto*/
 """,
 impl = """
 #if PY_MAJOR_VERSION < 3
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
+    /* cause is unused */
     Py_XINCREF(type);
     Py_XINCREF(value);
     Py_XINCREF(tb);
@@ -5723,7 +5738,7 @@ raise_error:
 
 #else /* Python 3+ */
 
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
     if (tb == Py_None) {
         tb = 0;
     } else if (tb && !PyTraceBack_Check(tb)) {
@@ -5748,6 +5763,29 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
         goto bad;
     }
 
+    if (cause) {
+        PyObject *fixed_cause;
+        if (PyExceptionClass_Check(cause)) {
+            fixed_cause = PyObject_CallObject(cause, NULL);
+            if (fixed_cause == NULL)
+                goto bad;
+        }
+        else if (PyExceptionInstance_Check(cause)) {
+            fixed_cause = cause;
+            Py_INCREF(fixed_cause);
+        }
+        else {
+            PyErr_SetString(PyExc_TypeError,
+                            "exception causes must derive from "
+                            "BaseException");
+            goto bad;
+        }
+        if (!value) {
+            value = PyObject_CallObject(type, NULL);
+        }
+        PyException_SetCause(value, fixed_cause);
+    }
+
     PyErr_SetObject(type, value);
 
     if (tb) {
index 91ddf6a7f9ae25881bb97947779054e8f70b5747..9caaa39920b6bdc6952890aad35c41b8b16680db 100644 (file)
@@ -1166,6 +1166,7 @@ def p_raise_statement(s):
     exc_type = None
     exc_value = None
     exc_tb = None
+    cause = None
     if s.sy not in statement_terminators:
         exc_type = p_test(s)
         if s.sy == ',':
@@ -1174,11 +1175,15 @@ def p_raise_statement(s):
             if s.sy == ',':
                 s.next()
                 exc_tb = p_test(s)
+        elif s.sy == 'from':
+            s.next()
+            cause = p_test(s)
     if exc_type or exc_value or exc_tb:
         return Nodes.RaiseStatNode(pos,
             exc_type = exc_type,
             exc_value = exc_value,
-            exc_tb = exc_tb)
+            exc_tb = exc_tb,
+            cause = cause)
     else:
         return Nodes.ReraiseStatNode(pos)
 
diff --git a/tests/run/funcexceptraisefrom.pyx b/tests/run/funcexceptraisefrom.pyx
new file mode 100644 (file)
index 0000000..ea83add
--- /dev/null
@@ -0,0 +1,51 @@
+__doc__ = u"""
+>>> def bar():
+...     try:
+...         foo()
+...     except ValueError:
+...         if IS_PY3:
+...             print(isinstance(sys.exc_info()[1].__cause__, TypeError))
+...         else:
+...             print(True)
+
+>>> bar()
+True
+
+>>> print(sys.exc_info())
+(None, None, None)
+
+>>> def bar2():
+...     try:
+...         foo2()
+...     except ValueError:
+...         if IS_PY3:
+...             cause = sys.exc_info()[1].__cause__
+...             print(isinstance(cause, TypeError))
+...             print(cause.args==('value',))
+...             pass
+...         else:
+...             print(True)
+...             print(True)
+
+>>> bar2()
+True
+True
+"""
+
+import sys
+IS_PY3 = sys.version_info[0] >= 3
+if not IS_PY3:
+    sys.exc_clear()
+
+def foo():
+    try:
+        raise TypeError
+    except TypeError:
+        raise ValueError from TypeError
+
+def foo2():
+    try:
+        raise TypeError
+    except TypeError:
+        raise ValueError() from TypeError('value')
+