clean up comprehensions to bring them closer to generator expressions, make their...
authorStefan Behnel <scoder@users.berlios.de>
Thu, 27 May 2010 13:34:15 +0000 (15:34 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 27 May 2010 13:34:15 +0000 (15:34 +0200)
remove optimisations for set([...]) and dict([...]) as they do not take side-effects into account: unhashable items lead to pre-mature exit from the loop
instead, transform set(genexp), list(genexp) and dict(genexp) into inlined comprehensions that do not leak loop variables

Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/Parsing.py
Cython/Compiler/Visitor.py
tests/run/dictcomp.pyx
tests/run/listcomp.pyx
tests/run/setcomp.pyx

index 3267022f0db85d19d48fb5b1e6259e34d92c282b..7f2da37ea22d33d509c6eb18f64b8a2bb988ea5c 100755 (executable)
@@ -3919,27 +3919,45 @@ class ScopedExprNode(ExprNode):
         # this is called with the expr_scope as env
         pass
 
+    def init_scope(self, outer_scope, expr_scope=None):
+        self.expr_scope = expr_scope
 
-class ComprehensionNode(ExprNode):    # (ScopedExprNode)
+
+class ComprehensionNode(ScopedExprNode):
     subexprs = ["target"]
     child_attrs = ["loop", "append"]
 
+    # different behaviour in Py2 and Py3: leak loop variables or not?
+    has_local_scope = False # Py2 behaviour as default
+
     def infer_type(self, env):
         return self.target.infer_type(env)
 
     def analyse_declarations(self, env):
         self.append.target = self # this is used in the PyList_Append of the inner loop
-        self.loop.analyse_declarations(env)
-#        self.expr_scope = Symtab.GeneratorExpressionScope(env)
-#        self.loop.analyse_declarations(self.expr_scope)
+        self.init_scope(env)
+        if self.expr_scope is not None:
+            self.loop.analyse_declarations(self.expr_scope)
+        else:
+            self.loop.analyse_declarations(env)
+
+    def init_scope(self, outer_scope, expr_scope=None):
+        if expr_scope is not None:
+            self.expr_scope = expr_scope
+        elif self.has_local_scope:
+            self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
+        else:
+            self.expr_scope = None
 
     def analyse_types(self, env):
         self.target.analyse_expressions(env)
         self.type = self.target.type
-        self.loop.analyse_expressions(env)
+        if not self.has_local_scope:
+            self.loop.analyse_expressions(env)
 
-#    def analyse_scoped_expressions(self, env):
-#        self.loop.analyse_expressions(env)
+    def analyse_scoped_expressions(self, env):
+        if self.has_local_scope:
+            self.loop.analyse_expressions(env)
 
     def may_be_none(self):
         return False
@@ -3957,20 +3975,20 @@ class ComprehensionNode(ExprNode):    # (ScopedExprNode)
         self.loop.annotate(code)
 
 
-class ComprehensionAppendNode(ExprNode):
+class ComprehensionAppendNode(Node):
     # Need to be careful to avoid infinite recursion:
     # target must not be in child_attrs/subexprs
-    subexprs = ['expr']
+
+    child_attrs = ['expr']
 
     type = PyrexTypes.c_int_type
     
-    def analyse_types(self, env):
-        self.expr.analyse_types(env)
+    def analyse_expressions(self, env):
+        self.expr.analyse_expressions(env)
         if not self.expr.type.is_pyobject:
             self.expr = self.expr.coerce_to_pyobject(env)
-        self.is_temp = 1
 
-    def generate_result_code(self, code):
+    def generate_execution_code(self, code):
         if self.target.type is list_type:
             function = "PyList_Append"
         elif self.target.type is set_type:
@@ -3978,33 +3996,53 @@ class ComprehensionAppendNode(ExprNode):
         else:
             raise InternalError(
                 "Invalid type for comprehension node: %s" % self.target.type)
-            
-        code.putln("%s = %s(%s, (PyObject*)%s); %s" %
-            (self.result(),
-             function,
-             self.target.result(),
-             self.expr.result(),
-             code.error_goto_if(self.result(), self.pos)))
+
+        self.expr.generate_evaluation_code(code)
+        code.putln(code.error_goto_if("%s(%s, (PyObject*)%s)" % (
+            function,
+            self.target.result(),
+            self.expr.result()
+            ), self.pos))
+        self.expr.generate_disposal_code(code)
+        self.expr.free_temps(code)
+
+    def generate_function_definitions(self, env, code):
+        self.expr.generate_function_definitions(env, code)
+
+    def annotate(self, code):
+        self.expr.annotate(code)
 
 class DictComprehensionAppendNode(ComprehensionAppendNode):
-    subexprs = ['key_expr', 'value_expr']
+    child_attrs = ['key_expr', 'value_expr']
 
-    def analyse_types(self, env):
-        self.key_expr.analyse_types(env)
+    def analyse_expressions(self, env):
+        self.key_expr.analyse_expressions(env)
         if not self.key_expr.type.is_pyobject:
             self.key_expr = self.key_expr.coerce_to_pyobject(env)
-        self.value_expr.analyse_types(env)
+        self.value_expr.analyse_expressions(env)
         if not self.value_expr.type.is_pyobject:
             self.value_expr = self.value_expr.coerce_to_pyobject(env)
-        self.is_temp = 1
 
-    def generate_result_code(self, code):
-        code.putln("%s = PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s); %s" %
-            (self.result(),
-             self.target.result(),
-             self.key_expr.result(),
-             self.value_expr.result(),
-             code.error_goto_if(self.result(), self.pos)))
+    def generate_execution_code(self, code):
+        self.key_expr.generate_evaluation_code(code)
+        self.value_expr.generate_evaluation_code(code)
+        code.putln(code.error_goto_if("PyDict_SetItem(%s, (PyObject*)%s, (PyObject*)%s)" % (
+            self.target.result(),
+            self.key_expr.result(),
+            self.value_expr.result()
+            ), self.pos))
+        self.key_expr.generate_disposal_code(code)
+        self.key_expr.free_temps(code)
+        self.value_expr.generate_disposal_code(code)
+        self.value_expr.free_temps(code)
+
+    def generate_function_definitions(self, env, code):
+        self.key_expr.generate_function_definitions(env, code)
+        self.value_expr.generate_function_definitions(env, code)
+
+    def annotate(self, code):
+        self.key_expr.annotate(code)
+        self.value_expr.annotate(code)
 
 
 class GeneratorExpressionNode(ScopedExprNode):
@@ -4019,9 +4057,15 @@ class GeneratorExpressionNode(ScopedExprNode):
     type = py_object_type
 
     def analyse_declarations(self, env):
-        self.expr_scope = Symtab.GeneratorExpressionScope(env)
+        self.init_scope(env)
         self.loop.analyse_declarations(self.expr_scope)
 
+    def init_scope(self, outer_scope, expr_scope=None):
+        if expr_scope is not None:
+            self.expr_scope = expr_scope
+        else:
+            self.expr_scope = Symtab.GeneratorExpressionScope(outer_scope)
+
     def analyse_types(self, env):
         self.is_temp = True
 
index b8e9f74791dedbffbc32d3a13b083c8ce62206d6..051a0b05dab3dc555b0ee2a91f9dd029291c8e79 100644 (file)
@@ -1022,52 +1022,6 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
 
     # specific handlers for simple call nodes
 
-    def _handle_simple_function_set(self, node, pos_args):
-        """Replace set([a,b,...]) by a literal set {a,b,...} and
-        set([ x for ... ]) by a literal { x for ... }.
-        """
-        arg_count = len(pos_args)
-        if arg_count == 0:
-            return ExprNodes.SetNode(node.pos, args=[],
-                                     type=Builtin.set_type)
-        if arg_count > 1:
-            return node
-        iterable = pos_args[0]
-        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
-            return ExprNodes.SetNode(node.pos, args=iterable.args)
-        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
-                 isinstance(iterable.target, (ExprNodes.ListNode,
-                                              ExprNodes.SetNode)):
-            iterable.target = ExprNodes.SetNode(node.pos, args=[])
-            iterable.pos = node.pos
-            return iterable
-        else:
-            return node
-
-    def _handle_simple_function_dict(self, node, pos_args):
-        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
-        """
-        if len(pos_args) != 1:
-            return node
-        arg = pos_args[0]
-        if isinstance(arg, ExprNodes.ComprehensionNode) and \
-               isinstance(arg.target, (ExprNodes.ListNode,
-                                       ExprNodes.SetNode)):
-            append_node = arg.append
-            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
-                   len(append_node.expr.args) == 2:
-                key_node, value_node = append_node.expr.args
-                target_node = ExprNodes.DictNode(
-                    pos=arg.target.pos, key_value_pairs=[])
-                new_append_node = ExprNodes.DictComprehensionAppendNode(
-                    append_node.pos, target=target_node,
-                    key_expr=key_node, value_expr=value_node)
-                arg.target = target_node
-                arg.type = target_node.type
-                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
-                return replace_in(arg)
-        return node
-
     def _handle_simple_function_float(self, node, pos_args):
         if len(pos_args) == 0:
             return ExprNodes.FloatNode(node.pos, value='0.0')
@@ -1182,7 +1136,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             rhs = ExprNodes.BoolNode(yield_node.pos, value = not is_any,
                                      constant_result = not is_any))
 
-        Visitor.RecursiveNodeReplacer(yield_node, test_node).visitchildren(loop_node)
+        Visitor.recursively_replace_node(loop_node, yield_node, test_node)
 
         return ExprNodes.InlinedGeneratorExpressionNode(
             gen_expr_node.pos, loop = loop_node, result_node = result_ref,
@@ -1215,7 +1169,7 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             rhs = ExprNodes.binop_node(node.pos, '+', result_ref, yield_expression)
             )
 
-        Visitor.RecursiveNodeReplacer(yield_node, add_node).visitchildren(loop_node)
+        Visitor.recursively_replace_node(loop_node, yield_node, add_node)
 
         exec_code = Nodes.StatListNode(
             node.pos,
@@ -1232,6 +1186,113 @@ class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
             gen_expr_node.pos, loop = exec_code, result_node = result_ref,
             expr_scope = gen_expr_node.expr_scope, orig_func = 'sum')
 
+    def _handle_simple_function_list(self, node, pos_args):
+        if len(pos_args) == 0:
+            return ExprNodes.ListNode(node.pos, args=[], constant_result=[])
+        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.ListNode)
+
+    def _handle_simple_function_set(self, node, pos_args):
+        if len(pos_args) == 0:
+            return ExprNodes.SetNode(node.pos, args=[], constant_result=set())
+        return self._transform_list_set_genexpr(node, pos_args, ExprNodes.SetNode)
+
+    def _transform_list_set_genexpr(self, node, pos_args, container_node_class):
+        """Replace set(genexpr) and list(genexpr) by a literal comprehension.
+        """
+        if len(pos_args) > 1:
+            return node
+        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+            return node
+        gen_expr_node = pos_args[0]
+        loop_node = gen_expr_node.loop
+
+        yield_node = self._find_single_yield_node(loop_node)
+        if yield_node is None:
+            return node
+        yield_expression = yield_node.arg
+
+        target_node = container_node_class(node.pos, args=[])
+        append_node = ExprNodes.ComprehensionAppendNode(
+            yield_node.pos,
+            expr = yield_expression,
+            target = ExprNodes.CloneNode(target_node),
+            is_temp = 1) # FIXME: why is this an ExprNode?
+
+        Visitor.recursively_replace_node(loop_node, yield_node, append_node)
+
+        setcomp = ExprNodes.ComprehensionNode(
+            node.pos,
+            has_local_scope = True,
+            expr_scope = gen_expr_node.expr_scope,
+            loop = loop_node,
+            append = append_node,
+            target = target_node)
+        append_node.target = setcomp
+        return setcomp
+
+    def _handle_simple_function_dict(self, node, pos_args):
+        """Replace dict( (a,b) for ... ) by a literal { a:b for ... }.
+        """
+        if len(pos_args) == 0:
+            return ExprNodes.DictNode(node.pos, key_value_pairs=[], constant_result={})
+        if len(pos_args) > 1:
+            return node
+        if not isinstance(pos_args[0], ExprNodes.GeneratorExpressionNode):
+            return node
+        gen_expr_node = pos_args[0]
+        loop_node = gen_expr_node.loop
+
+        yield_node = self._find_single_yield_node(loop_node)
+        if yield_node is None:
+            return node
+        yield_expression = yield_node.arg
+
+        if not isinstance(yield_expression, ExprNodes.TupleNode):
+            return node
+        if len(yield_expression.args) != 2:
+            return node
+
+        target_node = ExprNodes.DictNode(node.pos, key_value_pairs=[])
+        append_node = ExprNodes.DictComprehensionAppendNode(
+            yield_node.pos,
+            key_expr = yield_expression.args[0],
+            value_expr = yield_expression.args[1],
+            target = ExprNodes.CloneNode(target_node),
+            is_temp = 1) # FIXME: why is this an ExprNode?
+
+        Visitor.recursively_replace_node(loop_node, yield_node, append_node)
+
+        dictcomp = ExprNodes.ComprehensionNode(
+            node.pos,
+            has_local_scope = True,
+            expr_scope = gen_expr_node.expr_scope,
+            loop = loop_node,
+            append = append_node,
+            target = target_node)
+        append_node.target = dictcomp
+        return dictcomp
+
+
+
+        arg = pos_args[0]
+        if isinstance(arg, ExprNodes.ComprehensionNode) and \
+               isinstance(arg.target, (ExprNodes.ListNode,
+                                       ExprNodes.SetNode)):
+            append_node = arg.append
+            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
+                   len(append_node.expr.args) == 2:
+                key_node, value_node = append_node.expr.args
+                target_node = ExprNodes.DictNode(
+                    pos=arg.target.pos, key_value_pairs=[])
+                new_append_node = ExprNodes.DictComprehensionAppendNode(
+                    append_node.pos, target=target_node,
+                    key_expr=key_node, value_expr=value_node)
+                arg.target = target_node
+                arg.type = target_node.type
+                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
+                return replace_in(arg)
+        return node
+
     # specific handlers for general call nodes
 
     def _handle_general_function_dict(self, node, pos_args, kwargs):
index 7b67d485fc2632631624f3b1d7c7e9cb3ddff13f..35ebc59ae56a349c4f24719df5b1d258d30f7bb0 100644 (file)
@@ -1142,8 +1142,9 @@ class AnalyseExpressionsTransform(CythonTransform):
         return node
 
     def visit_ScopedExprNode(self, node):
-        node.expr_scope.infer_types()
-        node.analyse_scoped_expressions(node.expr_scope)
+        if node.expr_scope is not None:
+            node.expr_scope.infer_types()
+            node.analyse_scoped_expressions(node.expr_scope)
         self.visitchildren(node)
         return node
         
index 30d907eb24bd23e0697f377d3b944601a11cc61a..149841f06a81e575a8c06f6acb95b1b95b67d414 100644 (file)
@@ -777,7 +777,7 @@ def p_list_maker(s):
         target = ExprNodes.ListNode(pos, args = [])
         append = ExprNodes.ComprehensionAppendNode(
             pos, expr=expr, target=ExprNodes.CloneNode(target))
-        loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+        loop = p_comp_for(s, append)
         s.expect(']')
         return ExprNodes.ComprehensionNode(
             pos, loop=loop, append=append, target=target)
@@ -843,7 +843,7 @@ def p_dict_or_set_maker(s):
         target = ExprNodes.SetNode(pos, args=[])
         append = ExprNodes.ComprehensionAppendNode(
             item.pos, expr=item, target=ExprNodes.CloneNode(target))
-        loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+        loop = p_comp_for(s, append)
         s.expect('}')
         return ExprNodes.ComprehensionNode(
             pos, loop=loop, append=append, target=target)
@@ -858,7 +858,7 @@ def p_dict_or_set_maker(s):
             append = ExprNodes.DictComprehensionAppendNode(
                 item.pos, key_expr=key, value_expr=value,
                 target=ExprNodes.CloneNode(target))
-            loop = p_comp_for(s, Nodes.ExprStatNode(append.pos, expr=append))
+            loop = p_comp_for(s, append)
             s.expect('}')
             return ExprNodes.ComprehensionNode(
                 pos, loop=loop, append=append, target=target)
index 8cd0f001da9fd88d340d5b560be689a696a05ab9..224dbc42e6f948e26599c9b631f8c729a78f526c 100644 (file)
@@ -352,7 +352,9 @@ class RecursiveNodeReplacer(VisitorTransform):
         else:
             return node
 
-
+def recursively_replace_node(tree, old_node, new_node):
+    replace_in = RecursiveNodeReplacer(old_node, new_node)
+    replace_in(tree)
 
 
 # Utils
index 02a2cbfeecdcaeaba3683d106a3abd13f25932e1..d53310f31ea860885727cb0bbeeb0b516bd67947 100644 (file)
@@ -1,48 +1,58 @@
-__doc__ = u"""
->>> type(smoketest_dict()) is dict
-True
->>> type(smoketest_list()) is dict
-True
-
->>> sorted(smoketest_dict().items())
-[(2, 0), (4, 4), (6, 8)]
->>> sorted(smoketest_list().items())
-[(2, 0), (4, 4), (6, 8)]
-
->>> list(typed().items())
-[(A, 1), (A, 1), (A, 1)]
->>> sorted(iterdict().items())
-[(1, 'a'), (2, 'b'), (3, 'c')]
-"""
 
 cimport cython
 
-def smoketest_dict():
-    return { x+2:x*2
-             for x in range(5)
-             if x % 2 == 0 }
+def dictcomp():
+    """
+    >>> sorted(dictcomp().items())
+    [(2, 0), (4, 4), (6, 8)]
+    >>> sorted(dictcomp().items())
+    [(2, 0), (4, 4), (6, 8)]
+    """
+    x = 'abc'
+    result = { x+2:x*2
+               for x in range(5)
+               if x % 2 == 0 }
+    assert x != 'abc'
+    return result
 
 @cython.test_fail_if_path_exists(
-    "//ComprehensionNode//ComprehensionAppendNode",
-    "//SimpleCallNode//ComprehensionNode")
+    "//GeneratorExpressionNode",
+    "//SimpleCallNode")
 @cython.test_assert_path_exists(
     "//ComprehensionNode",
     "//ComprehensionNode//DictComprehensionAppendNode")
-def smoketest_list():
-    return dict([ (x+2,x*2)
-                  for x in range(5)
-                  if x % 2 == 0 ])
+def genexpr():
+    """
+    >>> type(genexpr()) is dict
+    True
+    >>> type(genexpr()) is dict
+    True
+    """
+    x = 'abc'
+    result = dict( (x+2,x*2)
+                   for x in range(5)
+                   if x % 2 == 0 )
+    assert x == 'abc'
+    return result
 
 cdef class A:
     def __repr__(self): return u"A"
     def __richcmp__(one, other, op): return one is other
     def __hash__(self): return id(self) % 65536
 
-def typed():
+def typed_dictcomp():
+    """
+    >>> list(typed_dictcomp().items())
+    [(A, 1), (A, 1), (A, 1)]
+    """
     cdef A obj
     return {obj:1 for obj in [A(), A(), A()]}
 
-def iterdict():
+def iterdict_dictcomp():
+    """
+    >>> sorted(iterdict_dictcomp().items())
+    [(1, 'a'), (2, 'b'), (3, 'c')]
+    """
     cdef dict d = dict(a=1,b=2,c=3)
     return {d[key]:key for key in d}
 
index 471971219f5119466e86254d43f8749a8bdea366..0680aa2aa332dd924c626d26f4f21d58d770e647 100644 (file)
@@ -3,7 +3,20 @@ def smoketest():
     >>> smoketest()
     [0, 4, 8]
     """
-    print [x*2 for x in range(5) if x % 2 == 0]
+    x = 'abc'
+    result = [x*2 for x in range(5) if x % 2 == 0]
+    assert x != 'abc'
+    return result
+
+def list_genexp():
+    """
+    >>> list_genexp()
+    [0, 4, 8]
+    """
+    x = 'abc'
+    result = list(x*2 for x in range(5) if x % 2 == 0)
+    assert x == 'abc'
+    return result
 
 def int_runvar():
     """
index 38398c104938cff12f9540e72e1cef1a6f115138..a89ca5154ff45e4e1d85d80c5a44544f181d6a43 100644 (file)
@@ -1,39 +1,38 @@
-__doc__ = u"""
->>> type(smoketest_set()) is not list
-True
->>> type(smoketest_set()) is _set
-True
->>> type(smoketest_list()) is _set
-True
-
->>> sorted(smoketest_set())
-[0, 4, 8]
->>> sorted(smoketest_list())
-[0, 4, 8]
-
->>> list(typed())
-[A, A, A]
->>> sorted(iterdict())
-[1, 2, 3]
-"""
 
 cimport cython
 
 # Py2.3 doesn't have the set type, but Cython does :)
 _set = set
 
-def smoketest_set():
+def setcomp():
+    """
+    >>> type(setcomp()) is not list
+    True
+    >>> type(setcomp()) is _set
+    True
+    >>> sorted(setcomp())
+    [0, 4, 8]
+    """
     return { x*2
              for x in range(5)
              if x % 2 == 0 }
 
-@cython.test_fail_if_path_exists("//SimpleCallNode//ComprehensionNode")
-@cython.test_assert_path_exists("//ComprehensionNode",
-                                "//ComprehensionNode//ComprehensionAppendNode")
-def smoketest_list():
-    return set([ x*2
+@cython.test_fail_if_path_exists(
+    "//GeneratorExpressionNode",
+    "//SimpleCallNode")
+@cython.test_assert_path_exists(
+    "//ComprehensionNode",
+    "//ComprehensionNode//ComprehensionAppendNode")
+def genexp_set():
+    """
+    >>> type(genexp_set()) is _set
+    True
+    >>> sorted(genexp_set())
+    [0, 4, 8]
+    """
+    return set( x*2
                  for x in range(5)
-                 if x % 2 == 0 ])
+                 if x % 2 == 0 )
 
 cdef class A:
     def __repr__(self): return u"A"
@@ -41,10 +40,18 @@ cdef class A:
     def __hash__(self): return id(self) % 65536
 
 def typed():
+    """
+    >>> list(typed())
+    [A, A, A]
+    """
     cdef A obj
     return {obj for obj in {A(), A(), A()}}
 
 def iterdict():
+    """
+    >>> sorted(iterdict())
+    [1, 2, 3]
+    """
     cdef dict d = dict(a=1,b=2,c=3)
     return {d[key] for key in d}