From ab3d64a207f2975d8ce8100c6a90ef8c5b66f260 Mon Sep 17 00:00:00 2001
From: Stefan Behnel <scoder@users.berlios.de>
Date: Mon, 25 Apr 2011 17:08:26 +0200
Subject: [PATCH] enable forked test runs for pyregr test suite

---
 runtests.py | 139 ++++++++++++++++++++++++++++------------------------
 1 file changed, 74 insertions(+), 65 deletions(-)

diff --git a/runtests.py b/runtests.py
index 9f3259f8..4d2c1bb2 100644
--- a/runtests.py
+++ b/runtests.py
@@ -588,64 +588,70 @@ class CythonRunTestCase(CythonCompileTestCase):
             self.run_doctests(self.module, result)
 
     def run_doctests(self, module_name, result):
-        if sys.version_info[0] >= 3 or not hasattr(os, 'fork') or not self.fork:
-            doctest.DocTestSuite(module_name).run(result)
-            gc.collect()
-            return
-
-        # fork to make sure we do not keep the tested module loaded
-        result_handle, result_file = tempfile.mkstemp()
-        os.close(result_handle)
-        child_id = os.fork()
-        if not child_id:
-            result_code = 0
-            try:
-                try:
-                    tests = None
-                    try:
-                        partial_result = PartialTestResult(result)
-                        tests = doctest.DocTestSuite(module_name)
-                        tests.run(partial_result)
-                        gc.collect()
-                    except Exception:
-                        if tests is None:
-                            # importing failed, try to fake a test class
-                            tests = _FakeClass(
-                                failureException=sys.exc_info()[1],
-                                _shortDescription=self.shortDescription(),
-                                module_name=None)
-                        partial_result.addError(tests, sys.exc_info())
-                        result_code = 1
-                    output = open(result_file, 'wb')
-                    pickle.dump(partial_result.data(), output)
-                except:
-                    traceback.print_exc()
-            finally:
-                try: output.close()
-                except: pass
-                os._exit(result_code)
+        def run_test(result):
+            tests = doctest.DocTestSuite(module_name)
+            tests.run(result)
+        run_forked_test(result, run_test, self.shortDescription(), self.fork)
 
+
+def run_forked_test(result, run_func, test_name, fork=True):
+    if sys.version_info[0] >= 3 or not hasattr(os, 'fork') or not fork:
+        run_test(result)
+        gc.collect()
+        return
+
+    # fork to make sure we do not keep the tested module loaded
+    result_handle, result_file = tempfile.mkstemp()
+    os.close(result_handle)
+    child_id = os.fork()
+    if not child_id:
+        result_code = 0
         try:
-            cid, result_code = os.waitpid(child_id, 0)
-            # os.waitpid returns the child's result code in the
-            # upper byte of result_code, and the signal it was
-            # killed by in the lower byte
-            if result_code & 255:
-                raise Exception("Tests in module '%s' were unexpectedly killed by signal %d"%
-                                (module_name, result_code & 255))
-            result_code = result_code >> 8
-            if result_code in (0,1):
-                input = open(result_file, 'rb')
+            try:
+                tests = None
                 try:
-                    PartialTestResult.join_results(result, pickle.load(input))
-                finally:
-                    input.close()
-            if result_code:
-                raise Exception("Tests in module '%s' exited with status %d" %
-                                (module_name, result_code))
+                    partial_result = PartialTestResult(result)
+                    run_func(partial_result)
+                    gc.collect()
+                except Exception:
+                    if tests is None:
+                        # importing failed, try to fake a test class
+                        tests = _FakeClass(
+                            failureException=sys.exc_info()[1],
+                            _shortDescription=test_name,
+                            module_name=None)
+                    partial_result.addError(tests, sys.exc_info())
+                    result_code = 1
+                output = open(result_file, 'wb')
+                pickle.dump(partial_result.data(), output)
+            except:
+                traceback.print_exc()
         finally:
-            try: os.unlink(result_file)
+            try: output.close()
             except: pass
+            os._exit(result_code)
+
+    try:
+        cid, result_code = os.waitpid(child_id, 0)
+        # os.waitpid returns the child's result code in the
+        # upper byte of result_code, and the signal it was
+        # killed by in the lower byte
+        if result_code & 255:
+            raise Exception("Tests in module '%s' were unexpectedly killed by signal %d"%
+                            (module_name, result_code & 255))
+        result_code = result_code >> 8
+        if result_code in (0,1):
+            input = open(result_file, 'rb')
+            try:
+                PartialTestResult.join_results(result, pickle.load(input))
+            finally:
+                input.close()
+        if result_code:
+            raise Exception("Tests in module '%s' exited with status %d" %
+                            (module_name, result_code))
+    finally:
+        try: os.unlink(result_file)
+        except: pass
 
 class PureDoctestTestCase(unittest.TestCase):
     def __init__(self, module_name, module_path):
@@ -773,20 +779,23 @@ class CythonPyregrTestCase(CythonRunTestCase):
         except ImportError: # Py3k
             from test import support
 
-        def run_unittest(*classes):
-            return self._run_unittest(result, *classes)
-        def run_doctest(module, verbosity=None):
-            return self._run_doctest(result, module)
+        def run_test(result):
+            def run_unittest(*classes):
+                return self._run_unittest(result, *classes)
+            def run_doctest(module, verbosity=None):
+                return self._run_doctest(result, module)
 
-        support.run_unittest = run_unittest
-        support.run_doctest = run_doctest
+            support.run_unittest = run_unittest
+            support.run_doctest = run_doctest
 
-        try:
-            module = __import__(self.module)
-            if hasattr(module, 'test_main'):
-                module.test_main()
-        except (unittest.SkipTest, support.ResourceDenied):
-            result.addSkip(self, 'ok')
+            try:
+                module = __import__(self.module)
+                if hasattr(module, 'test_main'):
+                    module.test_main()
+            except (unittest.SkipTest, support.ResourceDenied):
+                result.addSkip(self, 'ok')
+
+        run_forked_test(result, run_test, self.shortDescription(), self.fork)
 
 include_debugger = sys.version_info[:2] > (2, 5)
 
-- 
2.26.2