Merged pull request #12 from bhy/T423.
authorrobertwb <robertwb@gmail.com>
Tue, 26 Apr 2011 19:25:23 +0000 (12:25 -0700)
committerrobertwb <robertwb@gmail.com>
Tue, 26 Apr 2011 19:25:23 +0000 (12:25 -0700)
T423 explicit execption chaining

Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
runtests.py
tests/run/funcexceptraisefrom.pyx [new file with mode: 0644]
tests/run/test_raisefrom.pyx [new file with mode: 0644]

index 51ef88f3f0be72446ed67045d2afc8d274834a52..4405fefa46fc911c2d57d0aeb6096787ab1f4a4a 100644 (file)
@@ -4154,8 +4154,9 @@ class RaiseStatNode(StatNode):
     #  exc_type    ExprNode or None
     #  exc_value   ExprNode or None
     #  exc_tb      ExprNode or None
+    #  cause       ExprNode or None
 
-    child_attrs = ["exc_type", "exc_value", "exc_tb"]
+    child_attrs = ["exc_type", "exc_value", "exc_tb", "cause"]
 
     def analyse_expressions(self, env):
         if self.exc_type:
@@ -4167,6 +4168,9 @@ class RaiseStatNode(StatNode):
         if self.exc_tb:
             self.exc_tb.analyse_types(env)
             self.exc_tb = self.exc_tb.coerce_to_pyobject(env)
+        if self.cause:
+            self.cause.analyse_types(env)
+            self.cause = self.cause.coerce_to_pyobject(env)
         # special cases for builtin exceptions
         self.builtin_exc_name = None
         if self.exc_type and not self.exc_value and not self.exc_tb:
@@ -4204,13 +4208,19 @@ class RaiseStatNode(StatNode):
             tb_code = self.exc_tb.py_result()
         else:
             tb_code = "0"
+        if self.cause:
+            self.cause.generate_evaluation_code(code)
+            cause_code = self.cause.py_result()
+        else:
+            cause_code = "0"
         code.globalstate.use_utility_code(raise_utility_code)
         code.putln(
-            "__Pyx_Raise(%s, %s, %s);" % (
+            "__Pyx_Raise(%s, %s, %s, %s);" % (
                 type_code,
                 value_code,
-                tb_code))
-        for obj in (self.exc_type, self.exc_value, self.exc_tb):
+                tb_code,
+                cause_code))
+        for obj in (self.exc_type, self.exc_value, self.exc_tb, self.cause):
             if obj:
                 obj.generate_disposal_code(code)
                 obj.free_temps(code)
@@ -4224,6 +4234,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.generate_function_definitions(env, code)
         if self.exc_tb is not None:
             self.exc_tb.generate_function_definitions(env, code)
+        if self.cause is not None:
+            self.cause.generate_function_definitions(env, code)
 
     def annotate(self, code):
         if self.exc_type:
@@ -4232,6 +4244,8 @@ class RaiseStatNode(StatNode):
             self.exc_value.annotate(code)
         if self.exc_tb:
             self.exc_tb.annotate(code)
+        if self.cause:
+            self.cause.annotate(code)
 
 
 class ReraiseStatNode(StatNode):
@@ -6003,11 +6017,12 @@ static CYTHON_INLINE void __Pyx_ErrFetch(PyObject **type, PyObject **value, PyOb
 
 raise_utility_code = UtilityCode(
 proto = """
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb); /*proto*/
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause); /*proto*/
 """,
 impl = """
 #if PY_MAJOR_VERSION < 3
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
+    /* cause is unused */
     Py_XINCREF(type);
     Py_XINCREF(value);
     Py_XINCREF(tb);
@@ -6074,7 +6089,7 @@ raise_error:
 
 #else /* Python 3+ */
 
-static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
+static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb, PyObject *cause) {
     if (tb == Py_None) {
         tb = 0;
     } else if (tb && !PyTraceBack_Check(tb)) {
@@ -6099,6 +6114,29 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
         goto bad;
     }
 
+    if (cause) {
+        PyObject *fixed_cause;
+        if (PyExceptionClass_Check(cause)) {
+            fixed_cause = PyObject_CallObject(cause, NULL);
+            if (fixed_cause == NULL)
+                goto bad;
+        }
+        else if (PyExceptionInstance_Check(cause)) {
+            fixed_cause = cause;
+            Py_INCREF(fixed_cause);
+        }
+        else {
+            PyErr_SetString(PyExc_TypeError,
+                            "exception causes must derive from "
+                            "BaseException");
+            goto bad;
+        }
+        if (!value) {
+            value = PyObject_CallObject(type, NULL);
+        }
+        PyException_SetCause(value, fixed_cause);
+    }
+
     PyErr_SetObject(type, value);
 
     if (tb) {
index e111bc13862a6703c6198408f62a8dd1c1cfa358..4e2c55ad2c2aee350487bbac841dac8b4e0695cc 100644 (file)
@@ -1172,6 +1172,7 @@ def p_raise_statement(s):
     exc_type = None
     exc_value = None
     exc_tb = None
+    cause = None
     if s.sy not in statement_terminators:
         exc_type = p_test(s)
         if s.sy == ',':
@@ -1180,11 +1181,15 @@ def p_raise_statement(s):
             if s.sy == ',':
                 s.next()
                 exc_tb = p_test(s)
+        elif s.sy == 'from':
+            s.next()
+            cause = p_test(s)
     if exc_type or exc_value or exc_tb:
         return Nodes.RaiseStatNode(pos,
             exc_type = exc_type,
             exc_value = exc_value,
-            exc_tb = exc_tb)
+            exc_tb = exc_tb,
+            cause = cause)
     else:
         return Nodes.ReraiseStatNode(pos)
 
index a0e256100f91c8ba732f0129bdaa86cd80cd9d93..e9d09f1562f12ea5d12ce986601e06812ffa9f19 100644 (file)
@@ -107,7 +107,9 @@ VER_DEP_MODULES = {
     # we can only have one (3,) key.  Since 2.7 is supposed to be the
     # last 2.x release, things would have to change drastically for this
     # to be unsafe...
-    (2,999): (operator.lt, lambda x: x in ['run.special_methods_T561_py3']),
+    (2,999): (operator.lt, lambda x: x in ['run.special_methods_T561_py3',
+                                           'run.test_raisefrom',
+                                           ]),
     (3,): (operator.ge, lambda x: x in ['run.non_future_division',
                                         'compile.extsetslice',
                                         'compile.extdelslice',
diff --git a/tests/run/funcexceptraisefrom.pyx b/tests/run/funcexceptraisefrom.pyx
new file mode 100644 (file)
index 0000000..ea83add
--- /dev/null
@@ -0,0 +1,51 @@
+__doc__ = u"""
+>>> def bar():
+...     try:
+...         foo()
+...     except ValueError:
+...         if IS_PY3:
+...             print(isinstance(sys.exc_info()[1].__cause__, TypeError))
+...         else:
+...             print(True)
+
+>>> bar()
+True
+
+>>> print(sys.exc_info())
+(None, None, None)
+
+>>> def bar2():
+...     try:
+...         foo2()
+...     except ValueError:
+...         if IS_PY3:
+...             cause = sys.exc_info()[1].__cause__
+...             print(isinstance(cause, TypeError))
+...             print(cause.args==('value',))
+...             pass
+...         else:
+...             print(True)
+...             print(True)
+
+>>> bar2()
+True
+True
+"""
+
+import sys
+IS_PY3 = sys.version_info[0] >= 3
+if not IS_PY3:
+    sys.exc_clear()
+
+def foo():
+    try:
+        raise TypeError
+    except TypeError:
+        raise ValueError from TypeError
+
+def foo2():
+    try:
+        raise TypeError
+    except TypeError:
+        raise ValueError() from TypeError('value')
+
diff --git a/tests/run/test_raisefrom.pyx b/tests/run/test_raisefrom.pyx
new file mode 100644 (file)
index 0000000..3a6110f
--- /dev/null
@@ -0,0 +1,39 @@
+import unittest
+# adapted from pyregr
+class TestCause(unittest.TestCase):
+    def test_invalid_cause(self):
+        try:
+            raise IndexError from 5
+        except TypeError as e:
+            self.assertIn("exception cause", str(e))
+        else:
+            self.fail("No exception raised")
+
+    def test_class_cause(self):
+        try:
+            raise IndexError from KeyError
+        except IndexError as e:
+            self.assertIsInstance(e.__cause__, KeyError)
+        else:
+            self.fail("No exception raised")
+
+    def test_instance_cause(self):
+        cause = KeyError()
+        try:
+            raise IndexError from cause
+        except IndexError as e:
+            self.assertTrue(e.__cause__ is cause)
+        else:
+            self.fail("No exception raised")
+
+    def test_erroneous_cause(self):
+        class MyException(Exception):
+            def __init__(self):
+                raise RuntimeError()
+
+        try:
+            raise IndexError from MyException
+        except RuntimeError:
+            pass
+        else:
+            self.fail("No exception raised")