replace set([...]) by a literal set {...}
authorStefan Behnel <scoder@users.berlios.de>
Sun, 14 Dec 2008 21:15:02 +0000 (22:15 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 14 Dec 2008 21:15:02 +0000 (22:15 +0100)
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py
tests/run/set.pyx

index 1df959b5a56dd0b9bed04661610865880e8ed1bf..2a0729cb04c84de3ad31c4d3ec45708f4d17b2ed 100644 (file)
@@ -123,8 +123,8 @@ class Context:
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
             AnalyseExpressionsTransform(self),
-            ConstantFolding(),
             FlattenBuiltinTypeCreation(),
+            ConstantFolding(),
             DictIterTransform(),
             SwitchTransform(),
             FinalOptimizePhase(self),
index df97cd3ec11bc9349e871b62453ca06865d11479..15b2d4ac60878e3c0d1cd349baa3b43756289349 100644 (file)
@@ -354,24 +354,64 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
     """Optimise some common instantiation patterns for builtin types.
     """
     def visit_GeneralCallNode(self, node):
+        self.visitchildren(node)
+        handler = self._find_handler('general', node.function)
+        if handler is not None:
+            node = handler(node, node.positional_args, node.keyword_args)
+        return node
+
+    def visit_SimpleCallNode(self, node):
+        self.visitchildren(node)
+        handler = self._find_handler('simple', node.function)
+        if handler is not None:
+            node = handler(node, node.arg_tuple, None)
+        return node
+
+    def _find_handler(self, call_type, function):
+        if not function.type.is_builtin_type:
+            return None
+        handler = getattr(self, '_handle_%s_%s' % (call_type, function.name), None)
+        if handler is None:
+            handler = getattr(self, '_handle_any_%s' % function.name, None)
+        return handler
+
+    def _handle_general_dict(self, node, pos_args, kwargs):
         """Replace dict(a=b,c=d,...) by the underlying keyword dict
         construction which is done anyway.
         """
-        self.visitchildren(node)
-        if not node.function.type.is_builtin_type:
+        if not isinstance(pos_args, ExprNodes.TupleNode):
             return node
-        if node.function.name != 'dict':
+        if len(pos_args.args) > 0:
             return node
-        if not isinstance(node.positional_args, ExprNodes.TupleNode):
-            return node
-        if len(node.positional_args.args) > 0:
-            return node
-        if not isinstance(node.keyword_args, ExprNodes.DictNode):
+        if not isinstance(kwargs, ExprNodes.DictNode):
             return node
         if node.starstar_arg:
             # we could optimise this by updating the kw dict instead
             return node
-        return node.keyword_args
+        return kwargs
+
+    def _handle_simple_set(self, node, pos_args, kwargs):
+        """Replace set([a,b,...]) by a literal set {a,b,...}.
+        """
+        if not isinstance(pos_args, ExprNodes.TupleNode):
+            return node
+        arg_count = len(pos_args.args)
+        if arg_count == 0:
+            return ExprNodes.SetNode(node.pos, args=[],
+                                     type=Builtin.set_type, is_temp=1)
+        if arg_count > 1:
+            return node
+        iterable = pos_args.args[0]
+        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
+            return ExprNodes.SetNode(node.pos, args=iterable.args,
+                                     type=Builtin.set_type, is_temp=1)
+        elif isinstance(iterable, ExprNodes.ListComprehensionNode):
+            iterable.__class__ = ExprNodes.SetComprehensionNode
+            iterable.append.__class__ = ExprNodes.SetComprehensionAppendNode
+            iterable.pos = node.pos
+            return iterable
+        else:
+            return node
 
     def visit_PyTypeTestNode(self, node):
         """Flatten redundant type checks after tree changes.
index 8b78330f9a45423e1beeeeec16f0fd47265cd643..24e853ee549337af0c0059c6fb0bcced2090838a 100644 (file)
@@ -9,7 +9,12 @@ True
 >>> sorted(test_set_add())
 ['a', 1]
 
->>> type(test_set_add()) is _set
+>>> type(test_set_list_comp()) is _set
+True
+>>> sorted(test_set_list_comp())
+[0, 1, 2]
+
+>>> type(test_set_clear()) is _set
 True
 >>> list(test_set_clear())
 []
@@ -46,6 +51,11 @@ def test_set_clear():
     s1.clear()
     return s1
 
+def test_set_list_comp():
+    cdef set s1
+    s1 = set([i%3 for i in range(5)])
+    return s1
+
 def test_set_pop():
     cdef set s1
     s1 = set()