From: Dag Sverre Seljebotn Date: Wed, 18 Jun 2008 06:22:49 +0000 (-0700) Subject: Support for with statement X-Git-Tag: 0.9.8.1~49^2~126^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=6c1aa761d2b4128b587fc7c29e4c8cac651849df;p=cython.git Support for with statement --- diff --git a/Cython/Compiler/Future.py b/Cython/Compiler/Future.py index 85484957..a517fea6 100644 --- a/Cython/Compiler/Future.py +++ b/Cython/Compiler/Future.py @@ -7,5 +7,6 @@ def _get_feature(name): return object() unicode_literals = _get_feature("unicode_literals") +with_statement = _get_feature("with_statement") del _get_feature diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 5d454087..4230814b 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -232,6 +232,10 @@ class Context: errors_occurred = False try: tree = self.parse(source, scope.type_names, pxd = 0, full_module_name = full_module_name) + # This is of course going to change and be refactored real soon + from ParseTreeTransforms import WithTransform, PostParse + tree = PostParse()(tree) + tree = WithTransform()(tree) tree.process_implementation(scope, options, result) except CompileError: errors_occurred = True diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index c99c9ba9..f9c99ae6 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -2414,7 +2414,7 @@ class InPlaceAssignmentNode(AssignmentNode): # Fortunately, the type of the lhs node is fairly constrained # (it must be a NameNode, AttributeNode, or IndexNode). - child_attrs = ["lhs", "rhs", "dup"] + child_attrs = ["lhs", "rhs"] def analyse_declarations(self, env): self.lhs.analyse_target_declaration(env) @@ -2998,7 +2998,7 @@ class ForInStatNode(LoopNode, StatNode): # else_clause StatNode # item NextNode used internally - child_attrs = ["target", "iterator", "body", "else_clause", "item"] + child_attrs = ["target", "iterator", "body", "else_clause"] def analyse_declarations(self, env): self.target.analyse_target_declaration(env) @@ -3115,7 +3115,7 @@ class ForFromStatNode(LoopNode, StatNode): # is_py_target bool # loopvar_name string # py_loopvar_node PyTempNode or None - child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause", "py_loopvar_node"] + child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause"] def analyse_declarations(self, env): self.target.analyse_target_declaration(env) @@ -3231,6 +3231,18 @@ class ForFromStatNode(LoopNode, StatNode): self.else_clause.annotate(code) +class WithStatNode(StatNode): + """ + Represents a Python with statement. + + This is only used at parse tree level; and is not present in + analysis or generation phases. + """ + # manager The with statement manager object + # target Node (lhs expression) + # body StatNode + child_attrs = ["manager", "target", "body"] + class TryExceptStatNode(StatNode): # try .. except statement # @@ -3326,6 +3338,8 @@ class ExceptClauseNode(Node): child_attrs = ["pattern", "target", "body", "exc_value"] + exc_value = None + def analyse_declarations(self, env): if self.target: self.target.analyse_target_declaration(env) diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py new file mode 100644 index 00000000..e014b5df --- /dev/null +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -0,0 +1,149 @@ +from Cython.Compiler.Visitor import VisitorTransform +from Cython.Compiler.Nodes import * +from Cython.Compiler.TreeFragment import TreeFragment + + +class PostParse(VisitorTransform): + """ + This transform fixes up a few things after parsing + in order to make the parse tree more suitable for + transforms. + + a) After parsing, blocks with only one statement will + be represented by that statement, not by a StatListNode. + When doing transforms this is annoying and inconsistent, + as one cannot in general remove a statement in a consistent + way and so on. This transform wraps any single statements + in a StatListNode containing a single statement. + + b) The PassStatNode is a noop and serves no purpose beyond + plugging such one-statement blocks; i.e., once parsed a +` "pass" can just as well be represented using an empty + StatListNode. This means less special cases to worry about + in subsequent transforms (one always checks to see if a + StatListNode has no children to see if the block is empty). + """ + + def __init__(self): + super(PostParse, self).__init__() + self.is_in_statlist = False + self.is_in_expr = False + + def visit_Node(self, node): + self.visitchildren(node) + return node + + def visit_ExprNode(self, node): + stacktmp = self.is_in_expr + self.is_in_expr = True + self.visitchildren(node) + self.is_in_expr = stacktmp + return node + + def visit_StatNode(self, node, is_listcontainer=False): + stacktmp = self.is_in_statlist + self.is_in_statlist = is_listcontainer + self.visitchildren(node) + self.is_in_statlist = stacktmp + if not self.is_in_statlist and not self.is_in_expr: + return StatListNode(pos=node.pos, stats=[node]) + else: + return node + + def visit_PassStatNode(self, node): + if not self.is_in_statlist: + return StatListNode(pos=node.pos, stats=[]) + else: + return [] + + def visit_StatListNode(self, node): + self.is_in_statlist = True + self.visitchildren(node) + self.is_in_statlist = False + return node + + def visit_ParallelAssignmentNode(self, node): + return self.visit_StatNode(node, True) + + def visit_CEnumDefNode(self, node): + return self.visit_StatNode(node, True) + + def visit_CStructOrUnionDefNode(self, node): + return self.visit_StatNode(node, True) + +class WithTransform(VisitorTransform): + + template_without_target = TreeFragment(u""" + import sys as SYS + MGR = EXPR + EXIT = MGR.__exit__ + MGR.__enter__() + EXC = True + try: + try: + BODY + except: + EXC = False + if not EXIT(*SYS.exc_info()): + raise + finally: + if EXC: + EXIT(None, None, None) + """, u"WithTransformFragment") + + template_with_target = TreeFragment(u""" + import sys as SYS + MGR = EXPR + EXIT = MGR.__exit__ + VALUE = MGR.__enter__() + EXC = True + try: + try: + TARGET = VALUE + BODY + except: + EXC = False + if not EXIT(*SYS.exc_info()): + raise + finally: + if EXC: + EXIT(None, None, None) + """, u"WithTransformFragment") + + def visit_Node(self, node): + self.visitchildren(node) + return node + + def visit_WithStatNode(self, node): + if node.target is not None: + result = self.template_with_target.substitute({ + u'EXPR' : node.manager, + u'BODY' : node.body, + u'TARGET' : node.target + }, temps=(u'MGR', u'EXC', u"EXIT", u"VALUE", u"SYS"), + pos = node.pos) + else: + result = self.template_without_target.substitute({ + u'EXPR' : node.manager, + u'BODY' : node.body, + }, temps=(u'MGR', u'EXC', u"EXIT", u"SYS"), + pos = node.pos) + + return result.body.stats + + +class CallExitFuncNode(Node): + def analyse_types(self, env): + pass + def analyse_expressions(self, env): + self.exc_vars = [ + env.allocate_temp(PyrexTypes.py_object_type) + for x in xrange(3) + ] + + + def generate_result(self, code): + code.putln("""{ + PyObject* type; PyObject* value; PyObject* tb; + __Pyx_GetException( + }""") diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 7bdfa4d2..985b95b0 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -1134,13 +1134,13 @@ def p_for_from_step(s): inequality_relations = ('<', '<=', '>', '>=') -def p_for_target(s): +def p_target(s, terminator): pos = s.position() expr = p_bit_expr(s) if s.sy == ',': s.next() exprs = [expr] - while s.sy != 'in': + while s.sy != terminator: exprs.append(p_bit_expr(s)) if s.sy != ',': break @@ -1149,6 +1149,9 @@ def p_for_target(s): else: return expr +def p_for_target(s): + return p_target(s, 'in') + def p_for_iterator(s): pos = s.position() expr = p_testlist(s) @@ -1227,8 +1230,17 @@ def p_with_statement(s): body = p_suite(s) return Nodes.GILStatNode(pos, state = state, body = body) else: - s.error("Only 'with gil' and 'with nogil' implemented", - pos = pos) + manager = p_expr(s) + target = None + if s.sy == 'IDENT' and s.systring == 'as': + s.next() + allow_multi = (s.sy == '(') + target = p_target(s, ':') + if not allow_multi and isinstance(target, ExprNodes.TupleNode): + s.error("Multiple with statement target values not allowed without paranthesis") + body = p_suite(s) + return Nodes.WithStatNode(pos, manager = manager, + target = target, body = body) def p_simple_statement(s, first_statement = 0): #print "p_simple_statement:", s.sy, s.systring ### diff --git a/Cython/Compiler/Tests/TestParseTreeTransforms.py b/Cython/Compiler/Tests/TestParseTreeTransforms.py new file mode 100644 index 00000000..cbd60eff --- /dev/null +++ b/Cython/Compiler/Tests/TestParseTreeTransforms.py @@ -0,0 +1,149 @@ +from Cython.TestUtils import TransformTest +from Cython.Compiler.ParseTreeTransforms import * +from Cython.Compiler.Nodes import * + +class TestPostParse(TransformTest): + def test_parserbehaviour_is_what_we_coded_for(self): + t = self.fragment(u"if x: y").root + self.assertLines(u""" +(root): ModuleNode + body: IfStatNode + if_clauses[0]: IfClauseNode + condition: NameNode + body: ExprStatNode + expr: NameNode +""", self.treetypes(t)) + + def test_wrap_singlestat(self): + t = self.run_pipeline([PostParse()], u"if x: y") + self.assertLines(u""" +(root): ModuleNode + body: StatListNode + stats[0]: IfStatNode + if_clauses[0]: IfClauseNode + condition: NameNode + body: StatListNode + stats[0]: ExprStatNode + expr: NameNode +""", self.treetypes(t)) + + def test_wrap_multistat(self): + t = self.run_pipeline([PostParse()], u""" + if z: + x + y + """) + self.assertLines(u""" +(root): ModuleNode + body: StatListNode + stats[0]: IfStatNode + if_clauses[0]: IfClauseNode + condition: NameNode + body: StatListNode + stats[0]: ExprStatNode + expr: NameNode + stats[1]: ExprStatNode + expr: NameNode +""", self.treetypes(t)) + + def test_statinexpr(self): + t = self.run_pipeline([PostParse()], u""" + a, b = x, y + """) + self.assertLines(u""" +(root): ModuleNode + body: StatListNode + stats[0]: ParallelAssignmentNode + stats[0]: SingleAssignmentNode + lhs: NameNode + rhs: NameNode + stats[1]: SingleAssignmentNode + lhs: NameNode + rhs: NameNode +""", self.treetypes(t)) + + def test_wrap_offagain(self): + t = self.run_pipeline([PostParse()], u""" + x + y + if z: + x + """) + self.assertLines(u""" +(root): ModuleNode + body: StatListNode + stats[0]: ExprStatNode + expr: NameNode + stats[1]: ExprStatNode + expr: NameNode + stats[2]: IfStatNode + if_clauses[0]: IfClauseNode + condition: NameNode + body: StatListNode + stats[0]: ExprStatNode + expr: NameNode +""", self.treetypes(t)) + + + def test_pass_eliminated(self): + t = self.run_pipeline([PostParse()], u"pass") + self.assert_(len(t.body.stats) == 0) + +class TestWithTransform(TransformTest): + + def test_simplified(self): + t = self.run_pipeline([WithTransform()], u""" + with x: + y = z ** 3 + """) + + self.assertCode(u""" + + $SYS = (import sys) + $MGR = x + $EXIT = $MGR.__exit__ + $MGR.__enter__() + $EXC = True + try: + try: + y = z ** 3 + except: + $EXC = False + if (not $EXIT($SYS.exc_info())): + raise + finally: + if $EXC: + $EXIT(None, None, None) + + """, t) + + def test_basic(self): + t = self.run_pipeline([WithTransform()], u""" + with x as y: + y = z ** 3 + """) + self.assertCode(u""" + + $SYS = (import sys) + $MGR = x + $EXIT = $MGR.__exit__ + $VALUE = $MGR.__enter__() + $EXC = True + try: + try: + y = $VALUE + y = z ** 3 + except: + $EXC = False + if (not $EXIT($SYS.exc_info())): + raise + finally: + if $EXC: + $EXIT(None, None, None) + + """, t) + + +if __name__ == "__main__": + import unittest + unittest.main() diff --git a/Cython/TestUtils.py b/Cython/TestUtils.py index 8681ee39..9fe9d29b 100644 --- a/Cython/TestUtils.py +++ b/Cython/TestUtils.py @@ -28,6 +28,16 @@ class NodeTypeWriter(TreeVisitor): self._indents -= 1 class CythonTest(unittest.TestCase): + + def assertLines(self, expected, result): + "Checks that the given strings or lists of strings are equal line by line" + if not isinstance(expected, list): expected = expected.split(u"\n") + if not isinstance(result, list): result = result.split(u"\n") + for idx, (expected_line, result_line) in enumerate(zip(expected, result)): + self.assertEqual(expected_line, result_line, "Line %d:\nExp: %s\nGot: %s" % (idx, expected_line, result_line)) + self.assertEqual(len(expected), len(result), + "Unmatched lines. Got:\n%s\nExpected:\n%s" % ("\n".join(expected), u"\n".join(result))) + def assertCode(self, expected, result_tree): writer = CodeWriter() writer.write(result_tree) diff --git a/Cython/Tests/TestCodeWriter.py b/Cython/Tests/TestCodeWriter.py index 25fc2d47..03073072 100644 --- a/Cython/Tests/TestCodeWriter.py +++ b/Cython/Tests/TestCodeWriter.py @@ -72,6 +72,9 @@ class TestCodeWriter(CythonTest): def test_inplace_assignment(self): self.t(u"x += 43") + + def test_attribute(self): + self.t(u"a.x") if __name__ == "__main__": import unittest diff --git a/tests/run/withstat.pyx b/tests/run/withstat.pyx new file mode 100644 index 00000000..72030095 --- /dev/null +++ b/tests/run/withstat.pyx @@ -0,0 +1,58 @@ +from __future__ import with_statement + +__doc__ = u""" +>>> basic() +enter +value +exit +>>> with_exception(None) +enter +value +exit +outer except +>>> with_exception(True) +enter +value +exit +>>> multitarget() +enter +1 2 3 4 5 +exit +>>> tupletarget() +enter +(1, 2, (3, (4, 5))) +exit +""" + +class ContextManager: + def __init__(self, value, exit_ret = None): + self.value = value + self.exit_ret = exit_ret + + def __exit__(self, a, b, c): + print "exit", type(a), type(b), type(c) + return self.exit_ret + + def __enter__(self): + print "enter" + return self.value + +def basic(): + with ContextManager("value") as x: + print x + +def with_exception(exit_ret): + try: + with ContextManager("value", exit_ret=exit_ret) as value: + print value + raise Exception() + except: + print "outer except" + +def multitarget(): + with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))): + print a, b, c, d, e + +def tupletarget(): + with ContextManager((1, 2, (3, (4, 5)))) as t: + print t