Support for with statement
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 18 Jun 2008 06:22:49 +0000 (23:22 -0700)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Wed, 18 Jun 2008 06:22:49 +0000 (23:22 -0700)
Cython/Compiler/Future.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py [new file with mode: 0644]
Cython/Compiler/Parsing.py
Cython/Compiler/Tests/TestParseTreeTransforms.py [new file with mode: 0644]
Cython/TestUtils.py
Cython/Tests/TestCodeWriter.py
tests/run/withstat.pyx [new file with mode: 0644]

index 854849575bffffbcb37b97ab608e6ae52e5d0189..a517fea6ca8d97503259c1abc8136a83e2929343 100644 (file)
@@ -7,5 +7,6 @@ def _get_feature(name):
         return object()
 
 unicode_literals = _get_feature("unicode_literals")
+with_statement = _get_feature("with_statement")
 
 del _get_feature
index 5d45408741604ae2d7fe3720cbcd7cbd10f77454..4230814b39ee2387e27c0fee137fe1c66646439c 100644 (file)
@@ -232,6 +232,10 @@ class Context:
         errors_occurred = False
         try:
             tree = self.parse(source, scope.type_names, pxd = 0, full_module_name = full_module_name)
+            # This is of course going to change and be refactored real soon
+            from ParseTreeTransforms import WithTransform, PostParse
+            tree = PostParse()(tree)
+            tree = WithTransform()(tree)
             tree.process_implementation(scope, options, result)
         except CompileError:
             errors_occurred = True
index c99c9ba907b80cc39571f8c397f9de25b3512167..f9c99ae6388386f2e021f31069abe188b20b5329 100644 (file)
@@ -2414,7 +2414,7 @@ class InPlaceAssignmentNode(AssignmentNode):
     #  Fortunately, the type of the lhs node is fairly constrained 
     #  (it must be a NameNode, AttributeNode, or IndexNode).     
     
-    child_attrs = ["lhs", "rhs", "dup"]
+    child_attrs = ["lhs", "rhs"]
 
     def analyse_declarations(self, env):
         self.lhs.analyse_target_declaration(env)
@@ -2998,7 +2998,7 @@ class ForInStatNode(LoopNode, StatNode):
     #  else_clause   StatNode
     #  item          NextNode       used internally
     
-    child_attrs = ["target", "iterator", "body", "else_clause", "item"]
+    child_attrs = ["target", "iterator", "body", "else_clause"]
     
     def analyse_declarations(self, env):
         self.target.analyse_target_declaration(env)
@@ -3115,7 +3115,7 @@ class ForFromStatNode(LoopNode, StatNode):
     #  is_py_target       bool
     #  loopvar_name       string
     #  py_loopvar_node    PyTempNode or None
-    child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause", "py_loopvar_node"]
+    child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"]
     
     def analyse_declarations(self, env):
         self.target.analyse_target_declaration(env)
@@ -3231,6 +3231,18 @@ class ForFromStatNode(LoopNode, StatNode):
             self.else_clause.annotate(code)
 
 
+class WithStatNode(StatNode):
+    """
+    Represents a Python with statement.
+    
+    This is only used at parse tree level; and is not present in
+    analysis or generation phases.
+    """
+    #  manager          The with statement manager object
+    #  target            Node (lhs expression)
+    #  body             StatNode
+    child_attrs = ["manager", "target", "body"]
+
 class TryExceptStatNode(StatNode):
     #  try .. except statement
     #
@@ -3326,6 +3338,8 @@ class ExceptClauseNode(Node):
     
     child_attrs = ["pattern", "target", "body", "exc_value"]
 
+    exc_value = None
+
     def analyse_declarations(self, env):
         if self.target:
             self.target.analyse_target_declaration(env)
diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py
new file mode 100644 (file)
index 0000000..e014b5d
--- /dev/null
@@ -0,0 +1,149 @@
+from Cython.Compiler.Visitor import VisitorTransform
+from Cython.Compiler.Nodes import *
+from Cython.Compiler.TreeFragment import TreeFragment
+
+
+class PostParse(VisitorTransform):
+    """
+    This transform fixes up a few things after parsing
+    in order to make the parse tree more suitable for
+    transforms.
+
+    a) After parsing, blocks with only one statement will
+    be represented by that statement, not by a StatListNode.
+    When doing transforms this is annoying and inconsistent,
+    as one cannot in general remove a statement in a consistent
+    way and so on. This transform wraps any single statements
+    in a StatListNode containing a single statement.
+
+    b) The PassStatNode is a noop and serves no purpose beyond
+    plugging such one-statement blocks; i.e., once parsed a
+`    "pass" can just as well be represented using an empty
+    StatListNode. This means less special cases to worry about
+    in subsequent transforms (one always checks to see if a
+    StatListNode has no children to see if the block is empty).
+    """
+
+    def __init__(self):
+        super(PostParse, self).__init__()
+        self.is_in_statlist = False
+        self.is_in_expr = False
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+
+    def visit_ExprNode(self, node):
+        stacktmp = self.is_in_expr
+        self.is_in_expr = True
+        self.visitchildren(node)
+        self.is_in_expr = stacktmp
+        return node
+
+    def visit_StatNode(self, node, is_listcontainer=False):
+        stacktmp = self.is_in_statlist
+        self.is_in_statlist = is_listcontainer
+        self.visitchildren(node)
+        self.is_in_statlist = stacktmp
+        if not self.is_in_statlist and not self.is_in_expr:
+            return StatListNode(pos=node.pos, stats=[node])
+        else:
+            return node
+
+    def visit_PassStatNode(self, node):
+        if not self.is_in_statlist:
+            return StatListNode(pos=node.pos, stats=[])
+        else:
+            return []
+
+    def visit_StatListNode(self, node):
+        self.is_in_statlist = True
+        self.visitchildren(node)
+        self.is_in_statlist = False
+        return node
+
+    def visit_ParallelAssignmentNode(self, node):
+        return self.visit_StatNode(node, True)
+    
+    def visit_CEnumDefNode(self, node):
+        return self.visit_StatNode(node, True)
+
+    def visit_CStructOrUnionDefNode(self, node):
+        return self.visit_StatNode(node, True)
+
+class WithTransform(VisitorTransform):
+
+    template_without_target = TreeFragment(u"""
+        import sys as SYS
+        MGR = EXPR
+        EXIT = MGR.__exit__
+        MGR.__enter__()
+        EXC = True
+        try:
+            try:
+                BODY
+            except:
+                EXC = False
+                if not EXIT(*SYS.exc_info()):
+                    raise
+        finally:
+            if EXC:
+                EXIT(None, None, None)
+    """, u"WithTransformFragment")
+
+    template_with_target = TreeFragment(u"""
+        import sys as SYS
+        MGR = EXPR
+        EXIT = MGR.__exit__
+        VALUE = MGR.__enter__()
+        EXC = True
+        try:
+            try:
+                TARGET = VALUE
+                BODY
+            except:
+                EXC = False
+                if not EXIT(*SYS.exc_info()):
+                    raise
+        finally:
+            if EXC:
+                EXIT(None, None, None)
+    """, u"WithTransformFragment")
+
+    def visit_Node(self, node):
+       self.visitchildren(node)
+       return node
+
+    def visit_WithStatNode(self, node):
+        if node.target is not None:
+            result = self.template_with_target.substitute({
+                u'EXPR' : node.manager,
+                u'BODY' : node.body,
+                u'TARGET' : node.target
+                }, temps=(u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"),
+                pos = node.pos)
+        else:
+            result = self.template_without_target.substitute({
+                u'EXPR' : node.manager,
+                u'BODY' : node.body,
+                }, temps=(u'MGR', u'EXC', u"EXIT", u"SYS"),
+                pos = node.pos)
+        
+        return result.body.stats
+
+
+class CallExitFuncNode(Node):
+    def analyse_types(self, env):
+        pass
+    def analyse_expressions(self, env):
+        self.exc_vars = [
+            env.allocate_temp(PyrexTypes.py_object_type)
+            for x in xrange(3)
+        ]
+        
+        
+    def generate_result(self, code):
+        code.putln("""{
+        PyObject* type; PyObject* value; PyObject* tb;
+        __Pyx_GetException(
+        }""")
index 7bdfa4d2d690371f258a36558184bbddc289de1a..985b95b0dcffa75e41a954f2bdce7cbd24c6af32 100644 (file)
@@ -1134,13 +1134,13 @@ def p_for_from_step(s):
 
 inequality_relations = ('<', '<=', '>', '>=')
 
-def p_for_target(s):
+def p_target(s, terminator):
     pos = s.position()
     expr = p_bit_expr(s)
     if s.sy == ',':
         s.next()
         exprs = [expr]
-        while s.sy != 'in':
+        while s.sy != terminator:
             exprs.append(p_bit_expr(s))
             if s.sy != ',':
                 break
@@ -1149,6 +1149,9 @@ def p_for_target(s):
     else:
         return expr
 
+def p_for_target(s):
+    return p_target(s, 'in')
+
 def p_for_iterator(s):
     pos = s.position()
     expr = p_testlist(s)
@@ -1227,8 +1230,17 @@ def p_with_statement(s):
         body = p_suite(s)
         return Nodes.GILStatNode(pos, state = state, body = body)
     else:
-        s.error("Only 'with gil' and 'with nogil' implemented",
-                pos = pos)
+        manager = p_expr(s)
+        target = None
+        if s.sy == 'IDENT' and s.systring == 'as':
+            s.next()
+            allow_multi = (s.sy == '(')
+            target = p_target(s, ':')
+            if not allow_multi and isinstance(target, ExprNodes.TupleNode):
+                s.error("Multiple with statement target values not allowed without paranthesis")
+        body = p_suite(s)
+       return Nodes.WithStatNode(pos, manager = manager, 
+                                      target = target, body = body)
     
 def p_simple_statement(s, first_statement = 0):
     #print "p_simple_statement:", s.sy, s.systring ###
diff --git a/Cython/Compiler/Tests/TestParseTreeTransforms.py b/Cython/Compiler/Tests/TestParseTreeTransforms.py
new file mode 100644 (file)
index 0000000..cbd60ef
--- /dev/null
@@ -0,0 +1,149 @@
+from Cython.TestUtils import TransformTest
+from Cython.Compiler.ParseTreeTransforms import *
+from Cython.Compiler.Nodes import *
+
+class TestPostParse(TransformTest):
+    def test_parserbehaviour_is_what_we_coded_for(self):
+        t = self.fragment(u"if x: y").root
+        self.assertLines(u"""
+(root): ModuleNode
+  body: IfStatNode
+    if_clauses[0]: IfClauseNode
+      condition: NameNode
+      body: ExprStatNode
+        expr: NameNode
+""", self.treetypes(t))
+        
+    def test_wrap_singlestat(self):
+       t = self.run_pipeline([PostParse()], u"if x: y")
+        self.assertLines(u"""
+(root): ModuleNode
+  body: StatListNode
+    stats[0]: IfStatNode
+      if_clauses[0]: IfClauseNode
+        condition: NameNode
+        body: StatListNode
+          stats[0]: ExprStatNode
+            expr: NameNode
+""", self.treetypes(t))
+
+    def test_wrap_multistat(self):
+        t = self.run_pipeline([PostParse()], u"""
+            if z:
+                x
+                y
+        """)
+        self.assertLines(u"""
+(root): ModuleNode
+  body: StatListNode
+    stats[0]: IfStatNode
+      if_clauses[0]: IfClauseNode
+        condition: NameNode
+        body: StatListNode
+          stats[0]: ExprStatNode
+            expr: NameNode
+          stats[1]: ExprStatNode
+            expr: NameNode
+""", self.treetypes(t))
+
+    def test_statinexpr(self):
+        t = self.run_pipeline([PostParse()], u"""
+            a, b = x, y
+        """)
+        self.assertLines(u"""
+(root): ModuleNode
+  body: StatListNode
+    stats[0]: ParallelAssignmentNode
+      stats[0]: SingleAssignmentNode
+        lhs: NameNode
+        rhs: NameNode
+      stats[1]: SingleAssignmentNode
+        lhs: NameNode
+        rhs: NameNode
+""", self.treetypes(t))
+
+    def test_wrap_offagain(self):
+        t = self.run_pipeline([PostParse()], u"""
+            x
+            y
+            if z:
+                x
+        """)
+        self.assertLines(u"""
+(root): ModuleNode
+  body: StatListNode
+    stats[0]: ExprStatNode
+      expr: NameNode
+    stats[1]: ExprStatNode
+      expr: NameNode
+    stats[2]: IfStatNode
+      if_clauses[0]: IfClauseNode
+        condition: NameNode
+        body: StatListNode
+          stats[0]: ExprStatNode
+            expr: NameNode
+""", self.treetypes(t))
+        
+
+    def test_pass_eliminated(self):
+        t = self.run_pipeline([PostParse()], u"pass")
+        self.assert_(len(t.body.stats) == 0)
+
+class TestWithTransform(TransformTest):
+
+    def test_simplified(self):
+        t = self.run_pipeline([WithTransform()], u"""
+        with x:
+            y = z ** 3
+        """)
+        
+        self.assertCode(u"""
+
+        $SYS = (import sys)
+        $MGR = x
+        $EXIT = $MGR.__exit__
+        $MGR.__enter__()
+        $EXC = True
+        try:
+            try:
+                y = z ** 3
+            except:
+                $EXC = False
+                if (not $EXIT($SYS.exc_info())):
+                    raise
+        finally:
+            if $EXC:
+                $EXIT(None, None, None)
+
+        """, t)
+
+    def test_basic(self):
+        t = self.run_pipeline([WithTransform()], u"""
+        with x as y:
+            y = z ** 3
+        """)
+        self.assertCode(u"""
+
+        $SYS = (import sys)
+        $MGR = x
+        $EXIT = $MGR.__exit__
+        $VALUE = $MGR.__enter__()
+        $EXC = True
+        try:
+            try:
+                y = $VALUE
+                y = z ** 3
+            except:
+                $EXC = False
+                if (not $EXIT($SYS.exc_info())):
+                    raise
+        finally:
+            if $EXC:
+                $EXIT(None, None, None)
+
+        """, t)
+                          
+
+if __name__ == "__main__":
+    import unittest
+    unittest.main()
index 8681ee39ab37c73d73e710851356d1eb9ad5a2ca..9fe9d29bfcdb26bb54bd15e4308d2d2e9e9cf0dc 100644 (file)
@@ -28,6 +28,16 @@ class NodeTypeWriter(TreeVisitor):
         self._indents -= 1
 
 class CythonTest(unittest.TestCase):
+
+    def assertLines(self, expected, result):
+        "Checks that the given strings or lists of strings are equal line by line"
+        if not isinstance(expected, list): expected = expected.split(u"\n")
+        if not isinstance(result, list): result = result.split(u"\n")
+        for idx, (expected_line, result_line) in enumerate(zip(expected, result)):
+            self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line))
+        self.assertEqual(len(expected), len(result),
+            "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result)))
+
     def assertCode(self, expected, result_tree):
         writer = CodeWriter()
         writer.write(result_tree)
index 25fc2d477a4815732571a770416fe033ed8233ef..030730728ed955306a7ee7d6b2c836599e6ee19b 100644 (file)
@@ -72,6 +72,9 @@ class TestCodeWriter(CythonTest):
 
     def test_inplace_assignment(self):
         self.t(u"x += 43")
+
+    def test_attribute(self):
+        self.t(u"a.x")
     
 if __name__ == "__main__":
     import unittest
diff --git a/tests/run/withstat.pyx b/tests/run/withstat.pyx
new file mode 100644 (file)
index 0000000..7203009
--- /dev/null
@@ -0,0 +1,58 @@
+from __future__ import with_statement
+
+__doc__ = u"""
+>>> basic()
+enter
+value
+exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+>>> with_exception(None)
+enter
+value
+exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
+outer except
+>>> with_exception(True)
+enter
+value
+exit <type 'type'> <type 'exceptions.Exception'> <type 'traceback'>
+>>> multitarget()
+enter
+1 2 3 4 5
+exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+>>> tupletarget()
+enter
+(1, 2, (3, (4, 5)))
+exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+"""
+
+class ContextManager:
+    def __init__(self, value, exit_ret = None):
+        self.value = value
+        self.exit_ret = exit_ret
+
+    def __exit__(self, a, b, c):
+        print "exit", type(a), type(b), type(c)
+        return self.exit_ret
+        
+    def __enter__(self):
+        print "enter"
+        return self.value
+        
+def basic():
+    with ContextManager("value") as x:
+        print x
+
+def with_exception(exit_ret):
+    try:
+        with ContextManager("value", exit_ret=exit_ret) as value:
+            print value
+            raise Exception()
+    except:
+        print "outer except"
+
+def multitarget():
+    with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
+        print a, b, c, d, e
+
+def tupletarget():
+    with ContextManager((1, 2, (3, (4, 5)))) as t:
+        print t