reimplement 'with' statement using dedicated nodes, instead of a generic transform
authorStefan Behnel <scoder@users.berlios.de>
Sun, 24 Apr 2011 22:15:34 +0000 (00:15 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 24 Apr 2011 22:15:34 +0000 (00:15 +0200)
--HG--
rename : tests/run/withstat.pyx => tests/run/withstat_py.py

Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
runtests.py
tests/run/withstat.pyx
tests/run/withstat_py.py [new file with mode: 0644]

index 8c2010611dbb1af9985b114b285a710f7f9a55cd..5a5f0c2c07280a33d5870abc0f8e918315140a40 100755 (executable)
@@ -1910,6 +1910,40 @@ class NextNode(AtomicExprNode):
         code.putln("}")
 
 
+class WithExitCallNode(ExprNode):
+    # The __exit__() call of a 'with' statement.  Used in both the
+    # except and finally clauses.
+
+    # with_stat  WithStatNode                the surrounding 'with' statement
+    # args       TupleNode or ResultStatNode the exception info tuple
+
+    subexprs = ['args']
+
+    def analyse_types(self, env):
+        self.args.analyse_types(env)
+        self.type = PyrexTypes.c_bint_type
+        self.is_temp = True
+
+    def generate_result_code(self, code):
+        if isinstance(self.args, TupleNode):
+            # call only if it was not already called (and decref-cleared)
+            code.putln("if (%s) {" % self.with_stat.exit_var)
+        result_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False)
+        code.putln("%s = PyObject_Call(%s, %s, NULL);" % (
+            result_var,
+            self.with_stat.exit_var,
+            self.args.result()))
+        code.put_decref_clear(self.with_stat.exit_var, type=py_object_type)
+        code.putln(code.error_goto_if_null(result_var, self.pos))
+        code.put_gotref(result_var)
+        code.putln("%s = __Pyx_PyObject_IsTrue(%s);" % (self.result(), result_var))
+        code.put_decref_clear(result_var, type=py_object_type)
+        code.putln(code.error_goto_if_neg(self.result(), self.pos))
+        code.funcstate.release_temp(result_var)
+        if isinstance(self.args, TupleNode):
+            code.putln("}")
+
+
 class ExcValueNode(AtomicExprNode):
     #  Node created during analyse_types phase
     #  of an ExceptClauseNode to fetch the current
index fb8d6fcb02d5d9ffdc7fea9b4774e763164ac3e4..34d8b22f922b05c422e242044731f15c289d7820 100644 (file)
@@ -102,7 +102,7 @@ class Context(object):
 
     def create_pipeline(self, pxd, py=False):
         from Visitor import PrintTree
-        from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, PxdPostParse
+        from ParseTreeTransforms import NormalizeTree, PostParse, PxdPostParse
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
@@ -139,7 +139,6 @@ class Context(object):
             _align_function_definitions,
             ConstantFolding(),
             FlattenInListTransform(),
-            WithTransform(self),
             DecoratorTransform(self),
             AnalyseDeclarationsTransform(self),
             AutoTestDictTransform(self),
index 5358a3c3972554ef13b1c5b2e0d68ff2869b3571..8007d83a98d9afc553f36f9d0cba01672d20a1ab 100644 (file)
@@ -4835,13 +4835,177 @@ class WithStatNode(StatNode):
     """
     Represents a Python with statement.
 
-    This is only used at parse tree level; and is not present in
-    analysis or generation phases.
+    Implemented as follows:
+
+        MGR = EXPR
+        EXIT = MGR.__exit__
+        VALUE = MGR.__enter__()
+        EXC = True
+        try:
+            try:
+                TARGET = VALUE  # optional
+                BODY
+            except:
+                EXC = False
+                if not EXIT(*EXCINFO):
+                    raise
+        finally:
+            if EXC:
+                EXIT(None, None, None)
+            MGR = EXIT = VALUE = None
     """
     #  manager          The with statement manager object
-    #  target            Node (lhs expression)
     #  body             StatNode
-    child_attrs = ["manager", "target", "body"]
+
+    child_attrs = ["manager", "body"]
+
+    has_target = False
+
+    def __init__(self, pos, manager, target, body):
+        StatNode.__init__(self, pos, manager = manager)
+
+        import ExprNodes
+        self.target_temp = ExprNodes.TempNode(pos, type=py_object_type)
+        if target is not None:
+            self.has_target = True
+            body = StatListNode(
+                pos, stats = [
+                    WithTargetAssignmentStatNode(
+                        pos, lhs = target, rhs = self.target_temp),
+                    body
+                    ])
+
+        import UtilNodes
+        excinfo_target = UtilNodes.ResultRefNode(
+            pos=pos, type=Builtin.tuple_type, may_hold_none=False)
+        except_clause = ExceptClauseNode(
+            pos, body = IfStatNode(
+                pos, if_clauses = [
+                    IfClauseNode(
+                        pos, condition = ExprNodes.NotNode(
+                            pos, operand = ExprNodes.WithExitCallNode(
+                                pos, with_stat = self,
+                                args = excinfo_target)),
+                        body = ReraiseStatNode(pos),
+                        ),
+                    ],
+                else_clause = None),
+            pattern = None,
+            target = None,
+            excinfo_target = excinfo_target,
+            )
+
+        self.body = TryFinallyStatNode(
+            pos, body = TryExceptStatNode(
+                pos, body = body,
+                except_clauses = [except_clause],
+                else_clause = None,
+                ),
+            finally_clause = ExprStatNode(
+                pos, expr = ExprNodes.WithExitCallNode(
+                    pos, with_stat = self,
+                    args = ExprNodes.TupleNode(
+                        pos, args = [ExprNodes.NoneNode(pos) for _ in range(3)]
+                        ))),
+            handle_error_case = False,
+            )
+
+    def analyse_declarations(self, env):
+        self.manager.analyse_declarations(env)
+        self.body.analyse_declarations(env)
+
+    def analyse_expressions(self, env):
+        self.manager.analyse_types(env)
+        self.body.analyse_expressions(env)
+
+    def generate_execution_code(self, code):
+        code.putln("/*with:*/ {")
+        self.manager.generate_evaluation_code(code)
+        self.exit_var = code.funcstate.allocate_temp(py_object_type, manage_ref=False)
+        code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (
+            self.exit_var,
+            self.manager.py_result(),
+            code.get_py_string_const(EncodedString('__exit__'), identifier=True),
+            code.error_goto_if_null(self.exit_var, self.pos),
+            ))
+        code.put_gotref(self.exit_var)
+
+        # need to free exit_var in the face of exceptions during setup
+        old_error_label = code.new_error_label()
+        intermediate_error_label = code.error_label
+
+        enter_func = code.funcstate.allocate_temp(py_object_type, manage_ref=True)
+        code.putln("%s = PyObject_GetAttr(%s, %s); %s" % (
+            enter_func,
+            self.manager.py_result(),
+            code.get_py_string_const(EncodedString('__enter__'), identifier=True),
+            code.error_goto_if_null(enter_func, self.pos),
+            ))
+        code.put_gotref(enter_func)
+        self.manager.generate_disposal_code(code)
+        self.manager.free_temps(code)
+        self.target_temp.allocate(code)
+        code.putln('%s = PyObject_Call(%s, ((PyObject *)%s), NULL); %s' % (
+            self.target_temp.result(),
+            enter_func,
+            Naming.empty_tuple,
+            code.error_goto_if_null(self.target_temp.result(), self.pos),
+            ))
+        code.put_gotref(self.target_temp.result())
+        code.put_decref_clear(enter_func, py_object_type)
+        code.funcstate.release_temp(enter_func)
+        if not self.has_target:
+            code.put_decref_clear(self.target_temp.result(), type=py_object_type)
+            self.target_temp.release(code)
+            # otherwise, WithTargetAssignmentStatNode will do it for us
+
+        code.error_label = old_error_label
+        self.body.generate_execution_code(code)
+
+        step_over_label = code.new_label()
+        code.put_goto(step_over_label)
+        code.put_label(intermediate_error_label)
+        code.put_decref_clear(self.exit_var, py_object_type)
+        code.put_goto(old_error_label)
+        code.put_label(step_over_label)
+
+        code.funcstate.release_temp(self.exit_var)
+        code.putln('}')
+
+class WithTargetAssignmentStatNode(AssignmentNode):
+    # The target assignment of the 'with' statement value (return
+    # value of the __enter__() call).
+    #
+    # This is a special cased assignment that steals the RHS reference
+    # and frees its temp.
+    #
+    # lhs  ExprNode  the assignment target
+    # rhs  TempNode  the return value of the __enter__() call
+
+    child_attrs = ["lhs", "rhs"]
+
+    def analyse_declarations(self, env):
+        self.lhs.analyse_target_declaration(env)
+
+    def analyse_types(self, env):
+        self.rhs.analyse_types(env)
+        self.lhs.analyse_target_types(env)
+        self.lhs.gil_assignment_check(env)
+        self.orig_rhs = self.rhs
+        self.rhs = self.rhs.coerce_to(self.lhs.type, env)
+
+    def generate_execution_code(self, code):
+        self.rhs.generate_evaluation_code(code)
+        self.lhs.generate_assignment_code(self.rhs, code)
+        self.orig_rhs.release(code)
+
+    def generate_function_definitions(self, env, code):
+        self.rhs.generate_function_definitions(env, code)
+
+    def annotate(self, code):
+        self.lhs.annotate(code)
+        self.rhs.annotate(code)
+
 
 class TryExceptStatNode(StatNode):
     #  try .. except statement
@@ -5175,6 +5339,9 @@ class TryFinallyStatNode(StatNode):
 
     preserve_exception = 1
 
+    # handle exception case, in addition to return/break/continue
+    handle_error_case = True
+
     disallow_continue_in_try_finally = 0
     # There doesn't seem to be any point in disallowing
     # continue in the try block, since we have no problem
@@ -5208,6 +5375,8 @@ class TryFinallyStatNode(StatNode):
         old_labels = code.all_new_labels()
         new_labels = code.get_all_labels()
         new_error_label = code.error_label
+        if not self.handle_error_case:
+            code.error_label = old_error_label
         catch_label = code.new_label()
         code.putln(
             "/*try:*/ {")
index 79d6e3690e135c3a361fc92b446f8b98336c8185..8283094a8530934d0b49fb30f3249a51e01907bb 100644 (file)
@@ -897,76 +897,6 @@ class InterpretCompilerDirectives(CythonTransform, SkipDeclarations):
             return self.visit_with_directives(node.body, directive_dict)
         return self.visit_Node(node)
 
-class WithTransform(CythonTransform, SkipDeclarations):
-
-    # EXCINFO is manually set to a variable that contains
-    # the exc_info() tuple that can be generated by the enclosing except
-    # statement.
-    template_without_target = TreeFragment(u"""
-        MGR = EXPR
-        EXIT = MGR.__exit__
-        MGR.__enter__()
-        EXC = True
-        try:
-            try:
-                BODY
-            except:
-                EXC = False
-                if not EXIT(*EXCINFO):
-                    raise
-        finally:
-            if EXC:
-                EXIT(None, None, None)
-    """, temps=[u'MGR', u'EXC', u"EXIT"],
-    pipeline=[NormalizeTree(None)])
-
-    template_with_target = TreeFragment(u"""
-        MGR = EXPR
-        EXIT = MGR.__exit__
-        VALUE = MGR.__enter__()
-        EXC = True
-        try:
-            try:
-                TARGET = VALUE
-                BODY
-            except:
-                EXC = False
-                if not EXIT(*EXCINFO):
-                    raise
-        finally:
-            if EXC:
-                EXIT(None, None, None)
-            MGR = EXIT = VALUE = None
-    """, temps=[u'MGR', u'EXC', u"EXIT", u"VALUE"],
-    pipeline=[NormalizeTree(None)])
-
-    def visit_WithStatNode(self, node):
-        exc_info = ResultRefNode(pos=node.pos, type=Builtin.tuple_type, may_hold_none=False)
-        self.visitchildren(node, ['body'])
-        if node.target is not None:
-            result = self.template_with_target.substitute({
-                u'EXPR' : node.manager,
-                u'BODY' : node.body,
-                u'TARGET' : node.target,
-                u'EXCINFO' : exc_info,
-                }, pos=node.pos)
-        else:
-            result = self.template_without_target.substitute({
-                u'EXPR' : node.manager,
-                u'BODY' : node.body,
-                u'EXCINFO' : exc_info,
-                }, pos=node.pos)
-
-        # Set except excinfo target to EXCINFO
-        try_except = result.body.stats[-1].body.stats[-1]
-        try_except.except_clauses[0].excinfo_target = exc_info
-
-        return result
-
-    def visit_ExprNode(self, node):
-        # With statements are never inside expressions.
-        return node
-
 
 class DecoratorTransform(CythonTransform, SkipDeclarations):
 
index 1aed0fcc7f81208276431aa930480f66bb945700..791f7156a49f9f7fd307a38696befb40dbe9a90d 100644 (file)
@@ -98,6 +98,7 @@ VER_DEP_MODULES = {
                                           ]),
     (2,6) : (operator.lt, lambda x: x in ['run.print_function',
                                           'run.cython3',
+                                          'run.withstat_py',
                                           'run.generators_py', # generators, with statement
                                           'run.pure_py', # decorators, with statement
                                           ]),
index ad037e2880bd8664d1fb4fae2f2fff6200cb0ae9..67bef98276f630231f111c8ceef7882fa1f50518 100644 (file)
@@ -56,17 +56,6 @@ def with_pass():
     with ContextManager(u"value") as x:
         pass
 
-def with_return():
-    """
-    >>> with_return()
-    enter
-    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
-    """
-    with ContextManager(u"value") as x:
-        # FIXME: DISABLED - currently crashes!!
-        # return x
-        pass
-
 def with_exception(exit_ret):
     """
     >>> with_exception(None)
diff --git a/tests/run/withstat_py.py b/tests/run/withstat_py.py
new file mode 100644 (file)
index 0000000..3df6ad6
--- /dev/null
@@ -0,0 +1,252 @@
+import sys
+
+def typename(t):
+    name = type(t).__name__
+    if sys.version_info < (2,5):
+        if name == 'classobj' and issubclass(t, MyException):
+            name = 'type'
+        elif name == 'instance' and isinstance(t, MyException):
+            name = 'MyException'
+    return "<type '%s'>" % name
+
+class MyException(Exception):
+    pass
+
+class ContextManager(object):
+    def __init__(self, value, exit_ret = None):
+        self.value = value
+        self.exit_ret = exit_ret
+
+    def __exit__(self, a, b, tb):
+        print("exit %s %s %s" % (typename(a), typename(b), typename(tb)))
+        return self.exit_ret
+
+    def __enter__(self):
+        print("enter")
+        return self.value
+
+def no_as():
+    """
+    >>> no_as()
+    enter
+    hello
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    """
+    with ContextManager("value"):
+        print("hello")
+
+def basic():
+    """
+    >>> basic()
+    enter
+    value
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    """
+    with ContextManager("value") as x:
+        print(x)
+
+def with_pass():
+    """
+    >>> with_pass()
+    enter
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    """
+    with ContextManager("value") as x:
+        pass
+
+def with_return():
+    """
+    >>> print(with_return())
+    enter
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    value
+    """
+    with ContextManager("value") as x:
+        return x
+
+def with_break():
+    """
+    >>> print(with_break())
+    enter
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    a
+    """
+    for c in list("abc"):
+        with ContextManager("value") as x:
+            break
+        print("FAILED")
+    return c
+
+def with_continue():
+    """
+    >>> print(with_continue())
+    enter
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    enter
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    enter
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    c
+    """
+    for c in list("abc"):
+        with ContextManager("value") as x:
+            continue
+        print("FAILED")
+    return c
+
+def with_exception(exit_ret):
+    """
+    >>> with_exception(None)
+    enter
+    value
+    exit <type 'type'> <type 'MyException'> <type 'traceback'>
+    outer except
+    >>> with_exception(True)
+    enter
+    value
+    exit <type 'type'> <type 'MyException'> <type 'traceback'>
+    """
+    try:
+        with ContextManager("value", exit_ret=exit_ret) as value:
+            print(value)
+            raise MyException()
+    except:
+        print("outer except")
+
+def multitarget():
+    """
+    >>> multitarget()
+    enter
+    1 2 3 4 5
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    """
+    with ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
+        print('%s %s %s %s %s' % (a, b, c, d, e))
+
+def tupletarget():
+    """
+    >>> tupletarget()
+    enter
+    (1, 2, (3, (4, 5)))
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    """
+    with ContextManager((1, 2, (3, (4, 5)))) as t:
+        print(t)
+
+def multimanager():
+    """
+    >>> multimanager()
+    enter
+    enter
+    enter
+    enter
+    enter
+    enter
+    2
+    value
+    1 2 3 4 5
+    nested
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    exit <type 'NoneType'> <type 'NoneType'> <type 'NoneType'>
+    """
+    with ContextManager(1), ContextManager(2) as x, ContextManager('value') as y,\
+            ContextManager(3), ContextManager((1, 2, (3, (4, 5)))) as (a, b, (c, (d, e))):
+        with ContextManager('nested') as nested:
+            print(x)
+            print(y)
+            print('%s %s %s %s %s' % (a, b, c, d, e))
+            print(nested)
+
+# Tests borrowed from pyregr test_with.py,
+# modified to follow the constraints of Cython.
+import unittest
+
+class Dummy(object):
+    def __init__(self, value=None, gobble=False):
+        if value is None:
+            value = self
+        self.value = value
+        self.gobble = gobble
+        self.enter_called = False
+        self.exit_called = False
+
+    def __enter__(self):
+        self.enter_called = True
+        return self.value
+
+    def __exit__(self, *exc_info):
+        self.exit_called = True
+        self.exc_info = exc_info
+        if self.gobble:
+            return True
+
+class InitRaises(object):
+    def __init__(self): raise RuntimeError()
+
+class EnterRaises(object):
+    def __enter__(self): raise RuntimeError()
+    def __exit__(self, *exc_info): pass
+
+class ExitRaises(object):
+    def __enter__(self): pass
+    def __exit__(self, *exc_info): raise RuntimeError()
+
+class NestedWith(unittest.TestCase):
+    """
+    >>> NestedWith().runTest()
+    """
+
+    def runTest(self):
+        self.testNoExceptions()
+        self.testExceptionInExprList()
+        self.testExceptionInEnter()
+        self.testExceptionInExit()
+        self.testEnterReturnsTuple()
+
+    def testNoExceptions(self):
+        with Dummy() as a, Dummy() as b:
+            self.assertTrue(a.enter_called)
+            self.assertTrue(b.enter_called)
+        self.assertTrue(a.exit_called)
+        self.assertTrue(b.exit_called)
+
+    def testExceptionInExprList(self):
+        try:
+            with Dummy() as a, InitRaises():
+                pass
+        except:
+            pass
+        self.assertTrue(a.enter_called)
+        self.assertTrue(a.exit_called)
+
+    def testExceptionInEnter(self):
+        try:
+            with Dummy() as a, EnterRaises():
+                self.fail('body of bad with executed')
+        except RuntimeError:
+            pass
+        else:
+            self.fail('RuntimeError not reraised')
+        self.assertTrue(a.enter_called)
+        self.assertTrue(a.exit_called)
+
+    def testExceptionInExit(self):
+        body_executed = False
+        with Dummy(gobble=True) as a, ExitRaises():
+            body_executed = True
+        self.assertTrue(a.enter_called)
+        self.assertTrue(a.exit_called)
+        self.assertTrue(body_executed)
+        self.assertNotEqual(a.exc_info[0], None)
+
+    def testEnterReturnsTuple(self):
+        with Dummy(value=(1,2)) as (a1, a2), \
+             Dummy(value=(10, 20)) as (b1, b2):
+            self.assertEquals(1, a1)
+            self.assertEquals(2, a2)
+            self.assertEquals(10, b1)
+            self.assertEquals(20, b2)