From f982ac038ed4282fa9dbf77bcebfa717e449b696 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Mon, 25 Apr 2011 00:15:34 +0200 Subject: [PATCH] reimplement 'with' statement using dedicated nodes, instead of a generic transform --HG-- rename : tests/run/withstat.pyx => tests/run/withstat_py.py --- Cython/Compiler/ExprNodes.py | 34 ++++ Cython/Compiler/Main.py | 3 +- Cython/Compiler/Nodes.py | 177 ++++++++++++++++- Cython/Compiler/ParseTreeTransforms.py | 70 ------- runtests.py | 1 + tests/run/withstat.pyx | 11 -- tests/run/withstat_py.py | 252 +++++++++++++++++++++++++ 7 files changed, 461 insertions(+), 87 deletions(-) create mode 100644 tests/run/withstat_py.py diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 8c201061..5a5f0c2c 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1910,6 +1910,40 @@ class NextNode(AtomicExprNode): code.putln("}") +class WithExitCallNode(ExprNode): + # The __exit__() call of a 'with' statement. Used in both the + # except and finally clauses. + + # with_stat WithStatNode the surrounding 'with' statement + # args TupleNode or ResultStatNode the exception info tuple + + subexprs = ['args'] + + def analyse_types(self, env): + self.args.analyse_types(env) + self.type = PyrexTypes.c_bint_type + self.is_temp = True + + def generate_result_code(self, code): + if isinstance(self.args, TupleNode): + # call only if it was not already called (and decref-cleared) + code.putln("if (%s) {" % self.with_stat.exit_var) + result_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False) + code.putln("%s = PyObject_Call(%s, %s, NULL);" % ( + result_var, + self.with_stat.exit_var, + self.args.result())) + code.put_decref_clear(self.with_stat.exit_var, type=py_object_type) + code.putln(code.error_goto_if_null(result_var, self.pos)) + code.put_gotref(result_var) + code.putln("%s = __Pyx_PyObject_IsTrue(%s);" % (self.result(), result_var)) + code.put_decref_clear(result_var, type=py_object_type) + code.putln(code.error_goto_if_neg(self.result(), self.pos)) + code.funcstate.release_temp(result_var) + if isinstance(self.args, TupleNode): + code.putln("}") + + class ExcValueNode(AtomicExprNode): # Node created during analyse_types phase # of an ExceptClauseNode to fetch the current diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index fb8d6fcb..34d8b22f 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -102,7 +102,7 @@ class Context(object): def create_pipeline(self, pxd, py=False): from Visitor import PrintTree - from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse + from ParseTreeTransforms import NormalizeTree, PostParse, PxdPostParse from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods @@ -139,7 +139,6 @@ class Context(object): _align_function_definitions, ConstantFolding(), FlattenInListTransform(), - WithTransform(self), DecoratorTransform(self), AnalyseDeclarationsTransform(self), AutoTestDictTransform(self), diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 5358a3c3..8007d83a 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -4835,13 +4835,177 @@ class WithStatNode(StatNode): """ Represents a Python with statement. - This is only used at parse tree level; and is not present in - analysis or generation phases. + Implemented as follows: + + MGR = EXPR + EXIT = MGR.__exit__ + VALUE = MGR.__enter__() + EXC = True + try: + try: + TARGET = VALUE # optional + BODY + except: + EXC = False + if not EXIT(*EXCINFO): + raise + finally: + if EXC: + EXIT(None, None, None) + MGR = EXIT = VALUE = None """ # manager The with statement manager object - # target Node (lhs expression) # body StatNode - child_attrs = ["manager", "target", "body"] + + child_attrs = ["manager", "body"] + + has_target = False + + def __init__(self, pos, manager, target, body): + StatNode.__init__(self, pos, manager = manager) + + import ExprNodes + self.target_temp = ExprNodes.TempNode(pos, type=py_object_type) + if target is not None: + self.has_target = True + body = StatListNode( + pos, stats = [ + WithTargetAssignmentStatNode( + pos, lhs = target, rhs = self.target_temp), + body + ]) + + import UtilNodes + excinfo_target = UtilNodes.ResultRefNode( + pos=pos, type=Builtin.tuple_type, may_hold_none=False) + except_clause = ExceptClauseNode( + pos, body = IfStatNode( + pos, if_clauses = [ + IfClauseNode( + pos, condition = ExprNodes.NotNode( + pos, operand = ExprNodes.WithExitCallNode( + pos, with_stat = self, + args = excinfo_target)), + body = ReraiseStatNode(pos), + ), + ], + else_clause = None), + pattern = None, + target = None, + excinfo_target = excinfo_target, + ) + + self.body = TryFinallyStatNode( + pos, body = TryExceptStatNode( + pos, body = body, + except_clauses = [except_clause], + else_clause = None, + ), + finally_clause = ExprStatNode( + pos, expr = ExprNodes.WithExitCallNode( + pos, with_stat = self, + args = ExprNodes.TupleNode( + pos, args = [ExprNodes.NoneNode(pos) for _ in range(3)] + ))), + handle_error_case = False, + ) + + def analyse_declarations(self, env): + self.manager.analyse_declarations(env) + self.body.analyse_declarations(env) + + def analyse_expressions(self, env): + self.manager.analyse_types(env) + self.body.analyse_expressions(env) + + def generate_execution_code(self, code): + code.putln("/*with:*/ {") + self.manager.generate_evaluation_code(code) + self.exit_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False) + code.putln("%s = PyObject_GetAttr(%s, %s); %s" % ( + self.exit_var, + self.manager.py_result(), + code.get_py_string_const(EncodedString('__exit__'), identifier=True), + code.error_goto_if_null(self.exit_var, self.pos), + )) + code.put_gotref(self.exit_var) + + # need to free exit_var in the face of exceptions during setup + old_error_label = code.new_error_label() + intermediate_error_label = code.error_label + + enter_func = code.funcstate.allocate_temp(py_object_type, manage_ref=True) + code.putln("%s = PyObject_GetAttr(%s, %s); %s" % ( + enter_func, + self.manager.py_result(), + code.get_py_string_const(EncodedString('__enter__'), identifier=True), + code.error_goto_if_null(enter_func, self.pos), + )) + code.put_gotref(enter_func) + self.manager.generate_disposal_code(code) + self.manager.free_temps(code) + self.target_temp.allocate(code) + code.putln('%s = PyObject_Call(%s, ((PyObject *)%s), NULL); %s' % ( + self.target_temp.result(), + enter_func, + Naming.empty_tuple, + code.error_goto_if_null(self.target_temp.result(), self.pos), + )) + code.put_gotref(self.target_temp.result()) + code.put_decref_clear(enter_func, py_object_type) + code.funcstate.release_temp(enter_func) + if not self.has_target: + code.put_decref_clear(self.target_temp.result(), type=py_object_type) + self.target_temp.release(code) + # otherwise, WithTargetAssignmentStatNode will do it for us + + code.error_label = old_error_label + self.body.generate_execution_code(code) + + step_over_label = code.new_label() + code.put_goto(step_over_label) + code.put_label(intermediate_error_label) + code.put_decref_clear(self.exit_var, py_object_type) + code.put_goto(old_error_label) + code.put_label(step_over_label) + + code.funcstate.release_temp(self.exit_var) + code.putln('}') + +class WithTargetAssignmentStatNode(AssignmentNode): + # The target assignment of the 'with' statement value (return + # value of the __enter__() call). + # + # This is a special cased assignment that steals the RHS reference + # and frees its temp. + # + # lhs ExprNode the assignment target + # rhs TempNode the return value of the __enter__() call + + child_attrs = ["lhs", "rhs"] + + def analyse_declarations(self, env): + self.lhs.analyse_target_declaration(env) + + def analyse_types(self, env): + self.rhs.analyse_types(env) + self.lhs.analyse_target_types(env) + self.lhs.gil_assignment_check(env) + self.orig_rhs = self.rhs + self.rhs = self.rhs.coerce_to(self.lhs.type, env) + + def generate_execution_code(self, code): + self.rhs.generate_evaluation_code(code) + self.lhs.generate_assignment_code(self.rhs, code) + self.orig_rhs.release(code) + + def generate_function_definitions(self, env, code): + self.rhs.generate_function_definitions(env, code) + + def annotate(self, code): + self.lhs.annotate(code) + self.rhs.annotate(code) + class TryExceptStatNode(StatNode): # try .. except statement @@ -5175,6 +5339,9 @@ class TryFinallyStatNode(StatNode): preserve_exception = 1 + # handle exception case, in addition to return/break/continue + handle_error_case = True + disallow_continue_in_try_finally = 0 # There doesn't seem to be any point in disallowing # continue in the try block, since we have no problem @@ -5208,6 +5375,8 @@ class TryFinallyStatNode(StatNode): old_labels = code.all_new_labels() new_labels = code.get_all_labels() new_error_label = code.error_label + if not self.handle_error_case: + code.error_label = old_error_label catch_label = code.new_label() code.putln( "/*try:*/ {") diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 79d6e369..8283094a 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -897,76 +897,6 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations): return self.visit_with_directives(node.body, directive_dict) return self.visit_Node(node) -class WithTransform(CythonTransform, SkipDeclarations): - - # EXCINFO is manually set to a variable that contains - # the exc_info() tuple that can be generated by the enclosing except - # statement. - template_without_target = TreeFragment(u""" - MGR = EXPR - EXIT = MGR.__exit__ - MGR.__enter__() - EXC = True - try: - try: - BODY - except: - EXC = False - if not EXIT(*EXCINFO): - raise - finally: - if EXC: - EXIT(None, None, None) - """, temps=[u'MGR', u'EXC', u"EXIT"], - pipeline=[NormalizeTree(None)]) - - template_with_target = TreeFragment(u""" - MGR = EXPR - EXIT = MGR.__exit__ - VALUE = MGR.__enter__() - EXC = True - try: - try: - TARGET = VALUE - BODY - except: - EXC = False - if not EXIT(*EXCINFO): - raise - finally: - if EXC: - EXIT(None, None, None) - MGR = EXIT = VALUE = None - """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"], - pipeline=[NormalizeTree(None)]) - - def visit_WithStatNode(self, node): - exc_info = ResultRefNode(pos=node.pos, type=Builtin.tuple_type, may_hold_none=False) - self.visitchildren(node, ['body']) - if node.target is not None: - result = self.template_with_target.substitute({ - u'EXPR' : node.manager, - u'BODY' : node.body, - u'TARGET' : node.target, - u'EXCINFO' : exc_info, - }, pos=node.pos) - else: - result = self.template_without_target.substitute({ - u'EXPR' : node.manager, - u'BODY' : node.body, - u'EXCINFO' : exc_info, - }, pos=node.pos) - - # Set except excinfo target to EXCINFO - try_except = result.body.stats[-1].body.stats[-1] - try_except.except_clauses[0].excinfo_target = exc_info - - return result - - def visit_ExprNode(self, node): - # With statements are never inside expressions. - return node - class DecoratorTransform(CythonTransform, SkipDeclarations): diff --git a/runtests.py b/runtests.py index 1aed0fcc..791f7156 100644 --- a/runtests.py +++ b/runtests.py @@ -98,6 +98,7 @@ VER_DEP_MODULES = { ]), (2,6) : (operator.lt, lambda x: x in ['run.print_function', 'run.cython3', + 'run.withstat_py', 'run.generators_py', # generators, with statement 'run.pure_py', # decorators, with statement ]), diff --git a/tests/run/withstat.pyx b/tests/run/withstat.pyx index ad037e28..67bef982 100644 --- a/tests/run/withstat.pyx +++ b/tests/run/withstat.pyx @@ -56,17 +56,6 @@ def with_pass(): with ContextManager(u"value") as x: pass -def with_return(): - """ - >>> with_return() - enter - exit - """ - with ContextManager(u"value") as x: - # FIXME: DISABLED - currently crashes!! - # return x - pass - def with_exception(exit_ret): """ >>> with_exception(None) diff --git a/tests/run/withstat_py.py b/tests/run/withstat_py.py new file mode 100644 index 00000000..3df6ad63 --- /dev/null +++ b/tests/run/withstat_py.py @@ -0,0 +1,252 @@ +import sys + +def typename(t): + name = type(t).__name__ + if sys.version_info < (2,5): + if name == 'classobj' and issubclass(t, MyException): + name = 'type' + elif name == 'instance' and isinstance(t, MyException): + name = 'MyException' + return "" % name + +class MyException(Exception): + pass + +class ContextManager(object): + def __init__(self, value, exit_ret = None): + self.value = value + self.exit_ret = exit_ret + + def __exit__(self, a, b, tb): + print("exit %s %s %s" % (typename(a), typename(b), typename(tb))) + return self.exit_ret + + def __enter__(self): + print("enter") + return self.value + +def no_as(): + """ + >>> no_as() + enter + hello + exit + """ + with ContextManager("value"): + print("hello") + +def basic(): + """ + >>> basic() + enter + value + exit + """ + with ContextManager("value") as x: + print(x) + +def with_pass(): + """ + >>> with_pass() + enter + exit + """ + with ContextManager("value") as x: + pass + +def with_return(): + """ + >>> print(with_return()) + enter + exit + value + """ + with ContextManager("value") as x: + return x + +def with_break(): + """ + >>> print(with_break()) + enter + exit + a + """ + for c in list("abc"): + with ContextManager("value") as x: + break + print("FAILED") + return c + +def with_continue(): + """ + >>> print(with_continue()) + enter + exit + enter + exit + enter + exit + c + """ + for c in list("abc"): + with ContextManager("value") as x: + continue + print("FAILED") + return c + +def with_exception(exit_ret): + """ + >>> with_exception(None) + enter + value + exit + outer except + >>> with_exception(True) + enter + value + exit + """ + try: + with ContextManager("value", exit_ret=exit_ret) as value: + print(value) + raise MyException() + except: + print("outer except") + +def multitarget(): + """ + >>> multitarget() + enter + 1 2 3 4 5 + exit + """ + with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))): + print('%s %s %s %s %s' % (a, b, c, d, e)) + +def tupletarget(): + """ + >>> tupletarget() + enter + (1, 2, (3, (4, 5))) + exit + """ + with ContextManager((1, 2, (3, (4, 5)))) as t: + print(t) + +def multimanager(): + """ + >>> multimanager() + enter + enter + enter + enter + enter + enter + 2 + value + 1 2 3 4 5 + nested + exit + exit + exit + exit + exit + exit + """ + with ContextManager(1), ContextManager(2) as x, ContextManager('value') as y,\ + ContextManager(3), ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))): + with ContextManager('nested') as nested: + print(x) + print(y) + print('%s %s %s %s %s' % (a, b, c, d, e)) + print(nested) + +# Tests borrowed from pyregr test_with.py, +# modified to follow the constraints of Cython. +import unittest + +class Dummy(object): + def __init__(self, value=None, gobble=False): + if value is None: + value = self + self.value = value + self.gobble = gobble + self.enter_called = False + self.exit_called = False + + def __enter__(self): + self.enter_called = True + return self.value + + def __exit__(self, *exc_info): + self.exit_called = True + self.exc_info = exc_info + if self.gobble: + return True + +class InitRaises(object): + def __init__(self): raise RuntimeError() + +class EnterRaises(object): + def __enter__(self): raise RuntimeError() + def __exit__(self, *exc_info): pass + +class ExitRaises(object): + def __enter__(self): pass + def __exit__(self, *exc_info): raise RuntimeError() + +class NestedWith(unittest.TestCase): + """ + >>> NestedWith().runTest() + """ + + def runTest(self): + self.testNoExceptions() + self.testExceptionInExprList() + self.testExceptionInEnter() + self.testExceptionInExit() + self.testEnterReturnsTuple() + + def testNoExceptions(self): + with Dummy() as a, Dummy() as b: + self.assertTrue(a.enter_called) + self.assertTrue(b.enter_called) + self.assertTrue(a.exit_called) + self.assertTrue(b.exit_called) + + def testExceptionInExprList(self): + try: + with Dummy() as a, InitRaises(): + pass + except: + pass + self.assertTrue(a.enter_called) + self.assertTrue(a.exit_called) + + def testExceptionInEnter(self): + try: + with Dummy() as a, EnterRaises(): + self.fail('body of bad with executed') + except RuntimeError: + pass + else: + self.fail('RuntimeError not reraised') + self.assertTrue(a.enter_called) + self.assertTrue(a.exit_called) + + def testExceptionInExit(self): + body_executed = False + with Dummy(gobble=True) as a, ExitRaises(): + body_executed = True + self.assertTrue(a.enter_called) + self.assertTrue(a.exit_called) + self.assertTrue(body_executed) + self.assertNotEqual(a.exc_info[0], None) + + def testEnterReturnsTuple(self): + with Dummy(value=(1,2)) as (a1, a2), \ + Dummy(value=(10, 20)) as (b1, b2): + self.assertEquals(1, a1) + self.assertEquals(2, a2) + self.assertEquals(10, b1) + self.assertEquals(20, b2) -- 2.26.2