Better exception info reading for with statement
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 19 Jun 2008 00:49:58 +0000 (17:49 -0700)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Thu, 19 Jun 2008 00:49:58 +0000 (17:49 -0700)
Cython/CodeWriter.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Symtab.py
Cython/Compiler/Tests/TestParseTreeTransforms.py
Cython/Compiler/TreeFragment.py
Cython/Compiler/Visitor.py
Cython/TestUtils.py
tests/run/withstat.pyx

index aad485c6a3c2568cfe3a0177b35f99f0320df00d..bf9e3cc6e8190551e023cf44209d70b5d9637822 100644 (file)
@@ -1,7 +1,6 @@
-from Cython.Compiler.Visitor import TreeVisitor
+from Cython.Compiler.Visitor import TreeVisitor, get_temp_name_handle_desc
 from Cython.Compiler.Nodes import *
 from Cython.Compiler.ExprNodes import *
-from Cython.Compiler.Symtab import TempName
 
 """
 Serializes a Cython code tree to Cython code. This is primarily useful for
@@ -62,8 +61,9 @@ class CodeWriter(TreeVisitor):
         self.endline()
 
     def putname(self, name):
-        if isinstance(name, TempName):
-            name = self.tempnames.setdefault(name, u"$" + name.description)
+        tmpdesc = get_temp_name_handle_desc(name)
+        if tmpdesc is not None:
+            name = self.tempnames.setdefault(tmpdesc, u"$" +tmpdesc)
         self.put(name)
     
     def comma_seperated_list(self, items, output_rhs=False):
index 758de510a9e6adfecc08aa9cd303968b5664769c..c61f4a28b165538dc3f825fc2bf4e8730a340f38 100644 (file)
@@ -1060,7 +1060,6 @@ class NameNode(AtomicExprNode):
             else:
                 code.annotate(pos, AnnotationItem('c_call', 'c function', size=len(self.name)))
             
-
 class BackquoteNode(ExprNode):
     #  `expr`
     #
@@ -1212,6 +1211,9 @@ class ExcValueNode(AtomicExprNode):
     def generate_result_code(self, code):
         pass
 
+    def analyse_types(self, env):
+        pass
+
 
 class TempNode(AtomicExprNode):
     #  Node created during analyse_types phase
index dd88dd6bc4233963828f66bb3324f61faf1cd1c1..c985e0d1954dd9c93fce936f9d8a0b906c8e4213 100644 (file)
@@ -3329,18 +3329,24 @@ class ExceptClauseNode(Node):
     #  pattern        ExprNode
     #  target         ExprNode or None
     #  body           StatNode
+    #  excinfo_target NameNode or None   optional target for exception info
+    #  excinfo_target NameNode or None   used internally
     #  match_flag     string             result of exception match
     #  exc_value      ExcValueNode       used internally
     #  function_name  string             qualified name of enclosing function
     #  exc_vars       (string * 3)       local exception variables
     
-    child_attrs = ["pattern", "target", "body", "exc_value"]
+    child_attrs = ["pattern", "target", "body", "exc_value", "excinfo_target"]
 
     exc_value = None
+    excinfo_target = None
+    excinfo_assignment = None
 
     def analyse_declarations(self, env):
         if self.target:
             self.target.analyse_target_declaration(env)
+        if self.excinfo_target is not None:
+            self.excinfo_target.analyse_target_declaration(env)
         self.body.analyse_declarations(env)
     
     def analyse_expressions(self, env):
@@ -3358,6 +3364,17 @@ class ExceptClauseNode(Node):
             self.exc_value = ExprNodes.ExcValueNode(self.pos, env, self.exc_vars[1])
             self.exc_value.allocate_temps(env)
             self.target.analyse_target_expression(env, self.exc_value)
+        if self.excinfo_target is not None:
+            import ExprNodes
+            self.excinfo_tuple = ExprNodes.TupleNode(pos=self.pos, args=[
+                ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[0]),
+                ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[1]),
+                ExprNodes.ExcValueNode(pos=self.pos, env=env, var=self.exc_vars[2])
+            ])
+            self.excinfo_tuple.analyse_expressions(env)
+            self.excinfo_tuple.allocate_temps(env)
+            self.excinfo_target.analyse_target_expression(env, self.excinfo_tuple)
+
         self.body.analyse_expressions(env)
         for var in self.exc_vars:
             env.release_temp(var)
@@ -3387,6 +3404,10 @@ class ExceptClauseNode(Node):
         if self.target:
             self.exc_value.generate_evaluation_code(code)
             self.target.generate_assignment_code(self.exc_value, code)
+        if self.excinfo_target is not None:
+            self.excinfo_tuple.generate_evaluation_code(code)
+            self.excinfo_target.generate_assignment_code(self.excinfo_tuple, code)
+
         old_exc_vars = code.exc_vars
         code.exc_vars = self.exc_vars
         self.body.generate_execution_code(code)
@@ -4497,6 +4518,7 @@ bad:
     Py_XDECREF(*tb);
     return -1;
 }
+
 """]
 
 #------------------------------------------------------------------------------------
index e014b5df103b76c84e56f640865b66b9fba5b74f..598b19cb7d9aedbc7375a2e89fdea5f39a5b05f5 100644 (file)
@@ -1,5 +1,6 @@
-from Cython.Compiler.Visitor import VisitorTransform
+from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle
 from Cython.Compiler.Nodes import *
+from Cython.Compiler.ExprNodes import *
 from Cython.Compiler.TreeFragment import TreeFragment
 
 
@@ -71,10 +72,13 @@ class PostParse(VisitorTransform):
     def visit_CStructOrUnionDefNode(self, node):
         return self.visit_StatNode(node, True)
 
+
 class WithTransform(VisitorTransform):
 
+    # EXCINFO is manually set to a variable that contains
+    # the exc_info() tuple that can be generated by the enclosing except
+    # statement.
     template_without_target = TreeFragment(u"""
-        import sys as SYS
         MGR = EXPR
         EXIT = MGR.__exit__
         MGR.__enter__()
@@ -84,15 +88,15 @@ class WithTransform(VisitorTransform):
                 BODY
             except:
                 EXC = False
-                if not EXIT(*SYS.exc_info()):
+                if not EXIT(*EXCINFO):
                     raise
         finally:
             if EXC:
                 EXIT(None, None, None)
-    """, u"WithTransformFragment")
+    """, temps=[u'MGR', u'EXC', u"EXIT", u"SYS)"],
+    pipeline=[PostParse()])
 
     template_with_target = TreeFragment(u"""
-        import sys as SYS
         MGR = EXPR
         EXIT = MGR.__exit__
         VALUE = MGR.__enter__()
@@ -103,47 +107,38 @@ class WithTransform(VisitorTransform):
                 BODY
             except:
                 EXC = False
-                if not EXIT(*SYS.exc_info()):
+                if not EXIT(*EXCINFO):
                     raise
         finally:
             if EXC:
                 EXIT(None, None, None)
-    """, u"WithTransformFragment")
+    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"],
+    pipeline=[PostParse()])
 
     def visit_Node(self, node):
        self.visitchildren(node)
        return node
 
     def visit_WithStatNode(self, node):
+        excinfo_name = temp_name_handle('EXCINFO')
+        excinfo_namenode = NameNode(pos=node.pos, name=excinfo_name)
+        excinfo_target = NameNode(pos=node.pos, name=excinfo_name)
         if node.target is not None:
             result = self.template_with_target.substitute({
                 u'EXPR' : node.manager,
                 u'BODY' : node.body,
-                u'TARGET' : node.target
-                }, temps=(u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"),
-                pos = node.pos)
+                u'TARGET' : node.target,
+                u'EXCINFO' : excinfo_namenode
+                }, pos = node.pos)
+            # Set except excinfo target to EXCINFO
+            result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
         else:
             result = self.template_without_target.substitute({
                 u'EXPR' : node.manager,
                 u'BODY' : node.body,
-                }, temps=(u'MGR', u'EXC', u"EXIT", u"SYS"),
-                pos = node.pos)
-        
-        return result.body.stats
-
-
-class CallExitFuncNode(Node):
-    def analyse_types(self, env):
-        pass
-    def analyse_expressions(self, env):
-        self.exc_vars = [
-            env.allocate_temp(PyrexTypes.py_object_type)
-            for x in xrange(3)
-        ]
-        
+                u'EXCINFO' : excinfo_namenode
+                }, pos = node.pos)
+            # Set except excinfo target to EXCINFO
+            result.stats[4].body.stats[0].except_clauses[0].excinfo_target = excinfo_target
         
-    def generate_result(self, code):
-        code.putln("""{
-        PyObject* type; PyObject* value; PyObject* tb;
-        __Pyx_GetException(
-        }""")
+        return result.stats
index 0adb29543ad7eb778c8061a586731df84c9972a2..529deadeb0b1d8ffb44a703389af5855f467eca8 100644 (file)
@@ -16,29 +16,6 @@ from TypeSlots import \
 import ControlFlow
 import __builtin__
 
-class TempName(object):
-    """
-    Use instances of this class in order to provide a name for
-    anonymous, temporary functions. Each instance is considered
-    a seperate name, which are guaranteed not to clash with one
-    another or with names explicitly given as strings.
-
-    The argument to the constructor is simply a describing string
-    for debugging purposes and does not affect name clashes at all.
-
-    NOTE: Support for these TempNames are introduced on an as-needed
-    basis and will not "just work" everywhere. Places where they work:
-    - (none)
-    """
-    def __init__(self, description):
-        self.description = description
-
-    # Spoon-feed operators for documentation purposes
-    def __hash__(self):
-        return id(self)
-    def __cmp__(self, other):
-        return cmp(id(self), id(other))
-
 possible_identifier = re.compile(ur"(?![0-9])\w+$", re.U).match
 nice_identifier = re.compile('^[a-zA-Z0-0_]+$').match
 
@@ -1098,20 +1075,13 @@ class ModuleScope(Scope):
         var_entry.is_readonly = 1
         entry.as_variable = var_entry
         
-tempctr = 0
-
 class LocalScope(Scope):    
 
     def __init__(self, name, outer_scope):
         Scope.__init__(self, name, outer_scope, outer_scope)
     
     def mangle(self, prefix, name):
-        if isinstance(name, TempName):
-            global tempctr
-            tempctr += 1
-            return u"%s%s%d" % (Naming.temp_prefix, name.description, tempctr)
-        else:
-            return prefix + name
+        return prefix + name
 
     def declare_arg(self, name, type, pos):
         # Add an entry for an argument of a function.
index cbd60eff12b91acb88198274f2adc263067cbce5..3455dbc483e86330dbba6f20565c3f0b1515c15a 100644 (file)
@@ -6,8 +6,8 @@ class TestPostParse(TransformTest):
     def test_parserbehaviour_is_what_we_coded_for(self):
         t = self.fragment(u"if x: y").root
         self.assertLines(u"""
-(root): ModuleNode
-  body: IfStatNode
+(root): StatListNode
+  stats[0]: IfStatNode
     if_clauses[0]: IfClauseNode
       condition: NameNode
       body: ExprStatNode
@@ -17,14 +17,13 @@ class TestPostParse(TransformTest):
     def test_wrap_singlestat(self):
        t = self.run_pipeline([PostParse()], u"if x: y")
         self.assertLines(u"""
-(root): ModuleNode
-  body: StatListNode
-    stats[0]: IfStatNode
-      if_clauses[0]: IfClauseNode
-        condition: NameNode
-        body: StatListNode
-          stats[0]: ExprStatNode
-            expr: NameNode
+(root): StatListNode
+  stats[0]: IfStatNode
+    if_clauses[0]: IfClauseNode
+      condition: NameNode
+      body: StatListNode
+        stats[0]: ExprStatNode
+          expr: NameNode
 """, self.treetypes(t))
 
     def test_wrap_multistat(self):
@@ -34,16 +33,15 @@ class TestPostParse(TransformTest):
                 y
         """)
         self.assertLines(u"""
-(root): ModuleNode
-  body: StatListNode
-    stats[0]: IfStatNode
-      if_clauses[0]: IfClauseNode
-        condition: NameNode
-        body: StatListNode
-          stats[0]: ExprStatNode
-            expr: NameNode
-          stats[1]: ExprStatNode
-            expr: NameNode
+(root): StatListNode
+  stats[0]: IfStatNode
+    if_clauses[0]: IfClauseNode
+      condition: NameNode
+      body: StatListNode
+        stats[0]: ExprStatNode
+          expr: NameNode
+        stats[1]: ExprStatNode
+          expr: NameNode
 """, self.treetypes(t))
 
     def test_statinexpr(self):
@@ -51,15 +49,14 @@ class TestPostParse(TransformTest):
             a, b = x, y
         """)
         self.assertLines(u"""
-(root): ModuleNode
-  body: StatListNode
-    stats[0]: ParallelAssignmentNode
-      stats[0]: SingleAssignmentNode
-        lhs: NameNode
-        rhs: NameNode
-      stats[1]: SingleAssignmentNode
-        lhs: NameNode
-        rhs: NameNode
+(root): StatListNode
+  stats[0]: ParallelAssignmentNode
+    stats[0]: SingleAssignmentNode
+      lhs: NameNode
+      rhs: NameNode
+    stats[1]: SingleAssignmentNode
+      lhs: NameNode
+      rhs: NameNode
 """, self.treetypes(t))
 
     def test_wrap_offagain(self):
@@ -70,24 +67,23 @@ class TestPostParse(TransformTest):
                 x
         """)
         self.assertLines(u"""
-(root): ModuleNode
-  body: StatListNode
-    stats[0]: ExprStatNode
-      expr: NameNode
-    stats[1]: ExprStatNode
-      expr: NameNode
-    stats[2]: IfStatNode
-      if_clauses[0]: IfClauseNode
-        condition: NameNode
-        body: StatListNode
-          stats[0]: ExprStatNode
-            expr: NameNode
+(root): StatListNode
+  stats[0]: ExprStatNode
+    expr: NameNode
+  stats[1]: ExprStatNode
+    expr: NameNode
+  stats[2]: IfStatNode
+    if_clauses[0]: IfClauseNode
+      condition: NameNode
+      body: StatListNode
+        stats[0]: ExprStatNode
+          expr: NameNode
 """, self.treetypes(t))
         
 
     def test_pass_eliminated(self):
         t = self.run_pipeline([PostParse()], u"pass")
-        self.assert_(len(t.body.stats) == 0)
+        self.assert_(len(t.stats) == 0)
 
 class TestWithTransform(TransformTest):
 
@@ -99,7 +95,6 @@ class TestWithTransform(TransformTest):
         
         self.assertCode(u"""
 
-        $SYS = (import sys)
         $MGR = x
         $EXIT = $MGR.__exit__
         $MGR.__enter__()
@@ -109,7 +104,7 @@ class TestWithTransform(TransformTest):
                 y = z ** 3
             except:
                 $EXC = False
-                if (not $EXIT($SYS.exc_info())):
+                if (not $EXIT($EXCINFO)):
                     raise
         finally:
             if $EXC:
@@ -124,7 +119,6 @@ class TestWithTransform(TransformTest):
         """)
         self.assertCode(u"""
 
-        $SYS = (import sys)
         $MGR = x
         $EXIT = $MGR.__exit__
         $VALUE = $MGR.__enter__()
@@ -135,7 +129,7 @@ class TestWithTransform(TransformTest):
                 y = z ** 3
             except:
                 $EXC = False
-                if (not $EXIT($SYS.exc_info())):
+                if (not $EXIT($EXCINFO)):
                     raise
         finally:
             if $EXC:
index c5a574bc09082565c86d337cf5d47e46d689a9e3..0eff1b82fbe1e88040169060eeb05582b8a08ce7 100644 (file)
@@ -6,9 +6,8 @@ import re
 from cStringIO import StringIO
 from Scanning import PyrexScanner, StringSourceDescriptor
 from Symtab import BuiltinScope, ModuleScope
-from Visitor import VisitorTransform
-from Nodes import Node
-from Symtab import TempName
+from Visitor import VisitorTransform, temp_name_handle
+from Nodes import Node, StatListNode
 from ExprNodes import NameNode
 import Parsing
 import Main
@@ -109,7 +108,7 @@ class TemplateTransform(VisitorTransform):
         self.substitutions = substitutions
         tempdict = {}
         for key in temps:
-            tempdict[key] = TempName(key)
+            tempdict[key] = temp_name_handle(key)
         self.temps = tempdict
         self.pos = pos
         return super(TemplateTransform, self).__call__(node)
@@ -164,7 +163,7 @@ def strip_common_indent(lines):
     return lines
     
 class TreeFragment(object):
-    def __init__(self, code, name, pxds={}):
+    def __init__(self, code, name="(tree fragment)", pxds={}, temps=[], pipeline=[]):
         if isinstance(code, unicode):
             def fmt(x): return u"\n".join(strip_common_indent(x.split(u"\n"))) 
             
@@ -173,12 +172,20 @@ class TreeFragment(object):
             for key, value in pxds.iteritems():
                 fmt_pxds[key] = fmt(value)
                 
-            self.root = parse_from_strings(name, fmt_code, fmt_pxds)
+            t = parse_from_strings(name, fmt_code, fmt_pxds)
+            mod = t
+            t = t.body # Make sure a StatListNode is at the top
+            if not isinstance(t, StatListNode):
+                t = StatListNode(pos=mod.pos, stats=[t])
+            for transform in pipeline:
+                t = transform(t)
+            self.root = t
         elif isinstance(code, Node):
             if pxds != {}: raise NotImplementedError()
             self.root = code
         else:
             raise ValueError("Unrecognized code format (accepts unicode and Node)")
+        self.temps = temps
 
     def copy(self):
         return copy_code_tree(self.root)
@@ -186,7 +193,7 @@ class TreeFragment(object):
     def substitute(self, nodes={}, temps=[], pos = None):
         return TemplateTransform()(self.root,
                                    substitutions = nodes,
-                                   temps = temps, pos = pos)
+                                   temps = self.temps + temps, pos = pos)
 
 
 
index 6bbcd1e3bcea251f91864c01c65e83ded09226d0..44ad2ad794a7dc57c643db7c1c741e1c14d0e6cd 100644 (file)
@@ -166,6 +166,19 @@ def replace_node(ptr, value):
     else:
         getattr(parent, attrname)[listidx] = value
 
+tmpnamectr = 0
+def temp_name_handle(description):
+    global tmpnamectr
+    tmpnamectr += 1
+    return u"__cyt_%d_%s" % (tmpnamectr, description)
+
+def get_temp_name_handle_desc(handle):
+    if not handle.startswith(u"__cyt_"):
+        return None
+    else:
+        idx = handle.find(u"_", 6)
+        return handle[idx+1:]
+    
 class PrintTree(TreeVisitor):
     """Prints a representation of the tree to standard output.
     Subclass and override repr_of to provide more information
index 9fe9d29bfcdb26bb54bd15e4308d2d2e9e9cf0dc..87f072f811466f1f66834b9614ebb0280d5b96b3 100644 (file)
@@ -77,8 +77,8 @@ class TransformTest(CythonTest):
     To create a test case:
      - Call run_pipeline. The pipeline should at least contain the transform you
        are testing; pyx should be either a string (passed to the parser to
-       create a post-parse tree) or a ModuleNode representing input to pipeline.
-       The result will be a transformed result (usually a ModuleNode).
+       create a post-parse tree) or a node representing input to pipeline.
+       The result will be a transformed result.
        
      - Check that the tree is correct. If wanted, assertCode can be used, which
        takes a code string as expected, and a ModuleNode in result_tree
@@ -93,7 +93,6 @@ class TransformTest(CythonTest):
     
     def run_pipeline(self, pipeline, pyx, pxds={}):
         tree = self.fragment(pyx, pxds).root
-        assert isinstance(tree, ModuleNode)
         # Run pipeline
         for T in pipeline:
             tree = T(tree)
index 720300950566aece96c2616af3051516f17005c2..eb23902d981c5428663020b6d1a6f2e7f0cd619a 100644 (file)
@@ -1,6 +1,10 @@
 from __future__ import with_statement
 
 __doc__ = u"""
+>>> no_as()
+enter
+hello
+exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
 >>> basic()
 enter
 value
@@ -8,12 +12,12 @@ exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
 >>> with_exception(None)
 enter
 value
-exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
+exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
 outer except
 >>> with_exception(True)
 enter
 value
-exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
+exit <type 'type'> <class 'withstat.MyException'> <type 'traceback'>
 >>> multitarget()
 enter
 1 2 3 4 5
@@ -24,18 +28,25 @@ enter
 exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
 """
 
+class MyException(Exception):
+    pass
+
 class ContextManager:
     def __init__(self, value, exit_ret = None):
         self.value = value
         self.exit_ret = exit_ret
 
-    def __exit__(self, a, b, c):
-        print "exit", type(a), type(b), type(c)
+    def __exit__(self, a, b, tb):
+        print "exit", type(a), type(b), type(tb)
         return self.exit_ret
         
     def __enter__(self):
         print "enter"
         return self.value
+
+def no_as():
+    with ContextManager("value"):
+        print "hello"
         
 def basic():
     with ContextManager("value") as x:
@@ -45,7 +56,7 @@ def with_exception(exit_ret):
     try:
         with ContextManager("value", exit_ret=exit_ret) as value:
             print value
-            raise Exception()
+            raise MyException()
     except:
         print "outer except"
 
@@ -56,3 +67,4 @@ def multitarget():
 def tupletarget():
     with ContextManager((1, 2, (3, (4, 5)))) as t:
         print t
+