From 5d470de568833c6a0fca3a4e76d097f8c4c972e9 Mon Sep 17 00:00:00 2001
From: Stefan Behnel <scoder@users.berlios.de>
Date: Sun, 31 Aug 2008 13:08:38 +0200
Subject: [PATCH] faster exception handling/try-finally/etc. by inlining exc
 fetch/restore code

---
 Cython/Compiler/Nodes.py | 60 +++++++++++++++++++++++++++++++++-------
 1 file changed, 50 insertions(+), 10 deletions(-)

diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py
index 24d4f7ba..b8a29c74 100644
--- a/Cython/Compiler/Nodes.py
+++ b/Cython/Compiler/Nodes.py
@@ -951,12 +951,13 @@ class FuncDefNode(StatNode, BlockNode):
             # so need to save and restore error state
             buffers_present = len(lenv.buffer_entries) > 0
             if buffers_present:
+                code.globalstate.use_utility_code(restore_exception_utility_code)
                 code.putln("{ PyObject *__pyx_type, *__pyx_value, *__pyx_tb;")
-                code.putln("PyErr_Fetch(&__pyx_type, &__pyx_value, &__pyx_tb);")
+                code.putln("__Pyx_ErrFetch(&__pyx_type, &__pyx_value, &__pyx_tb);")
                 for entry in lenv.buffer_entries:
                     code.putln("%s;" % Buffer.get_release_buffer_code(entry))
                     #code.putln("%s = 0;" % entry.cname)
-                code.putln("PyErr_Restore(__pyx_type, __pyx_value, __pyx_tb);}")
+                code.putln("__Pyx_ErrRestore(__pyx_type, __pyx_value, __pyx_tb);}")
 
             err_val = self.error_value()
             exc_check = self.caller_will_check_exceptions()
@@ -969,6 +970,7 @@ class FuncDefNode(StatNode, BlockNode):
                     '__Pyx_WriteUnraisable("%s");' % 
                         self.entry.qualified_name)
                 env.use_utility_code(unraisable_exception_utility_code)
+                env.use_utility_code(restore_exception_utility_code)
             default_retval = self.return_type.default_value
             if err_val is None and default_retval:
                 err_val = default_retval
@@ -2928,6 +2930,7 @@ class RaiseStatNode(StatNode):
         if self.exc_tb:
             self.exc_tb.release_temp(env)
         env.use_utility_code(raise_utility_code)
+        env.use_utility_code(restore_exception_utility_code)
         self.gil_check(env)
 
     gil_message = "Raising exception"
@@ -2982,6 +2985,7 @@ class ReraiseStatNode(StatNode):
     def analyse_expressions(self, env):
         self.gil_check(env)
         env.use_utility_code(raise_utility_code)
+        env.use_utility_code(restore_exception_utility_code)
 
     gil_message = "Raising exception"
 
@@ -3641,6 +3645,7 @@ class ExceptClauseNode(Node):
         for var in self.exc_vars:
             env.release_temp(var)
         env.use_utility_code(get_exception_utility_code)
+        env.use_utility_code(restore_exception_utility_code)
     
     def generate_handling_code(self, code, end_label):
         code.mark_pos(self.pos)
@@ -3825,6 +3830,7 @@ class TryFinallyStatNode(StatNode):
             "}")
 
     def put_error_catcher(self, code, error_label, i, catch_label):
+        code.globalstate.use_utility_code(restore_exception_utility_code)
         code.putln(
             "%s: {" %
                 error_label)
@@ -3833,7 +3839,7 @@ class TryFinallyStatNode(StatNode):
                     i)
         code.put_var_xdecrefs_clear(self.cleanup_list)
         code.putln(
-                "PyErr_Fetch(&%s, &%s, &%s);" %
+                "__Pyx_ErrFetch(&%s, &%s, &%s);" %
                     Naming.exc_vars)
         code.putln(
                 "%s = %s;" % (
@@ -3846,11 +3852,12 @@ class TryFinallyStatNode(StatNode):
             "}")
             
     def put_error_uncatcher(self, code, i, error_label):
+        code.globalstate.use_utility_code(restore_exception_utility_code)
         code.putln(
             "case %s: {" %
                 i)
         code.putln(
-                "PyErr_Restore(%s, %s, %s);" %
+                "__Pyx_ErrRestore(%s, %s, %s);" %
                     Naming.exc_vars)
         code.putln(
                 "%s = %s;" % (
@@ -4285,7 +4292,7 @@ static void __Pyx_Raise(PyObject *type, PyObject *value, PyObject *tb) {
             }
         #endif
     }
-    PyErr_Restore(type, value, tb);
+    __Pyx_ErrRestore(type, value, tb);
     return;
 raise_error:
     Py_XDECREF(value);
@@ -4309,7 +4316,7 @@ static void __Pyx_ReRaise(void) {
     Py_XINCREF(type);
     Py_XINCREF(value);
     Py_XINCREF(tb);
-    PyErr_Restore(type, value, tb);
+    __Pyx_ErrRestore(type, value, tb);
 }
 """]
 
@@ -4559,13 +4566,13 @@ static void __Pyx_WriteUnraisable(const char *name); /*proto*/
 static void __Pyx_WriteUnraisable(const char *name) {
     PyObject *old_exc, *old_val, *old_tb;
     PyObject *ctx;
-    PyErr_Fetch(&old_exc, &old_val, &old_tb);
+    __Pyx_ErrFetch(&old_exc, &old_val, &old_tb);
     #if PY_MAJOR_VERSION < 3
     ctx = PyString_FromString(name);
     #else
     ctx = PyUnicode_FromString(name);
     #endif
-    PyErr_Restore(old_exc, old_val, old_tb);
+    __Pyx_ErrRestore(old_exc, old_val, old_tb);
     if (!ctx)
         ctx = Py_None;
     PyErr_WriteUnraisable(ctx);
@@ -4664,6 +4671,39 @@ bad:
     'EMPTY_TUPLE' : Naming.empty_tuple,
 }]
 
+restore_exception_utility_code = [
+"""
+void INLINE __Pyx_ErrRestore(PyObject *type, PyObject *value, PyObject *tb); /*proto*/
+void INLINE __Pyx_ErrFetch(PyObject **type, PyObject **value, PyObject **tb); /*proto*/
+""","""
+void INLINE __Pyx_ErrRestore(PyObject *type, PyObject *value, PyObject *tb) {
+    PyObject *tmp_type, *tmp_value, *tmp_tb;
+    PyThreadState *tstate = PyThreadState_GET();
+
+    tmp_type = tstate->curexc_type;
+    tmp_value = tstate->curexc_value;
+    tmp_tb = tstate->curexc_traceback;
+    tstate->curexc_type = type;
+    tstate->curexc_value = value;
+    tstate->curexc_traceback = tb;
+    Py_XDECREF(tmp_type);
+    Py_XDECREF(tmp_value);
+    Py_XDECREF(tmp_tb);
+}
+
+void INLINE __Pyx_ErrFetch(PyObject **type, PyObject **value, PyObject **tb) {
+    PyThreadState *tstate = PyThreadState_GET();
+    *type = tstate->curexc_type;
+    *value = tstate->curexc_value;
+    *tb = tstate->curexc_traceback;
+
+    tstate->curexc_type = 0;
+    tstate->curexc_value = 0;
+    tstate->curexc_traceback = 0;
+}
+
+"""]
+
 #------------------------------------------------------------------------------------
 
 set_vtable_utility_code = [
@@ -4758,8 +4798,8 @@ static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb);
 ""","""
 static int __Pyx_GetException(PyObject **type, PyObject **value, PyObject **tb) {
     PyObject *tmp_type, *tmp_value, *tmp_tb;
-    PyThreadState *tstate = PyThreadState_Get();
-    PyErr_Fetch(type, value, tb);
+    PyThreadState *tstate = PyThreadState_GET();
+    __Pyx_ErrFetch(type, value, tb);
     PyErr_NormalizeException(type, value, tb);
     if (PyErr_Occurred())
         goto bad;
-- 
2.26.2