split builtin type call optimisations into pre and post type analysis phase
authorStefan Behnel <scoder@users.berlios.de>
Mon, 7 Dec 2009 23:59:22 +0000 (00:59 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Mon, 7 Dec 2009 23:59:22 +0000 (00:59 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py

index ae90a98e5ac8801d78ffce5aef381b71324fa4eb..ce9f383c51019d0c408379a10d9414cf79651e2f 100644 (file)
@@ -2641,28 +2641,32 @@ class SimpleCallNode(CallNode):
 
 class PythonCapiFunctionNode(ExprNode):
     subexprs = []
-    def __init__(self, pos, name, func_type, utility_code = None):
+    def __init__(self, pos, py_name, cname, func_type, utility_code = None):
         self.pos = pos
-        self.name = name
+        self.name = py_name
+        self.cname = cname
         self.type = func_type
         self.utility_code = utility_code
 
+    def analyse_types(self, env):
+        pass
+
     def generate_result_code(self, code):
         if self.utility_code:
             code.globalstate.use_utility_code(self.utility_code)
 
     def calculate_result_code(self):
-        return self.name
+        return self.cname
 
 class PythonCapiCallNode(SimpleCallNode):
     # Python C-API Function call (only created in transforms)
 
     def __init__(self, pos, function_name, func_type,
-                 utility_code = None, **kwargs):
+                 utility_code = None, py_name=None, **kwargs):
         self.type = func_type.return_type
         self.result_ctype = self.type
         self.function = PythonCapiFunctionNode(
-            pos, function_name, func_type,
+            pos, py_name, function_name, func_type,
             utility_code = utility_code)
         # call this last so that we can override the constructed
         # attributes above with explicit keyword arguments if required
index 1bc90028c078a5a69f789d5a446906628545acf2..e869e9816ecf90489ca5da23a82aa96a8591512a 100644 (file)
@@ -92,7 +92,8 @@ class Context(object):
         from AnalysedTreeTransforms import AutoTestDictTransform
         from AutoDocTransforms import EmbedSignature
         from Optimize import FlattenInListTransform, SwitchTransform, IterationTransform
-        from Optimize import OptimizeBuiltinCalls, ConstantFolding, FinalOptimizePhase
+        from Optimize import EarlyReplaceBuiltinCalls, OptimizeBuiltinCalls
+        from Optimize import ConstantFolding, FinalOptimizePhase
         from Optimize import DropRefcountingTransform
         from Buffer import IntroduceBufferAuxiliaryVars
         from ModuleNode import check_c_declarations, check_c_declarations_pxd
@@ -131,6 +132,7 @@ class Context(object):
             AnalyseDeclarationsTransform(self),
             AutoTestDictTransform(self),
             EmbedSignature(self),
+            EarlyReplaceBuiltinCalls(self),
             MarkAssignments(self),
             TransformBuiltinMethods(self),
             IntroduceBufferAuxiliaryVars(self),
index d602c102c7ec5787e889af402060e5d65514ab0d..9af0561ab7066c5ae9a7210caa4f5e2b4da8e6fd 100644 (file)
@@ -723,9 +723,190 @@ class DropRefcountingTransform(Visitor.VisitorTransform):
         return (base.name, index_val)
 
 
+class EarlyReplaceBuiltinCalls(Visitor.EnvTransform):
+    """Optimize some common calls to builtin types *before* the type
+    analysis phase and *after* the declarations analysis phase.
+
+    This transform cannot make use of any argument types, but it can
+    restructure the tree in a way that the type analysis phase can
+    respond to.
+    """
+    # only intercept on call nodes
+    visit_Node = Visitor.VisitorTransform.recurse_to_children
+
+    def visit_SimpleCallNode(self, node):
+        self.visitchildren(node)
+        function = node.function
+        if not self._function_is_builtin_name(function):
+            return node
+        return self._dispatch_to_handler(node, function, node.args)
+
+    def visit_GeneralCallNode(self, node):
+        self.visitchildren(node)
+        function = node.function
+        if not self._function_is_builtin_name(function):
+            return node
+        arg_tuple = node.positional_args
+        if not isinstance(arg_tuple, ExprNodes.TupleNode):
+            return node
+        args = arg_tuple.args
+        return self._dispatch_to_handler(
+            node, function, args, node.keyword_args)
+
+    def _function_is_builtin_name(self, function):
+        if not function.is_name:
+            return False
+        entry = self.env_stack[-1].lookup(function.name)
+        if not entry or getattr(entry, 'scope', None) is not Builtin.builtin_scope:
+            return False
+        return True
+
+    def _dispatch_to_handler(self, node, function, args, kwargs=None):
+        if kwargs is None:
+            handler_name = '_handle_simple_function_%s' % function.name
+        else:
+            handler_name = '_handle_general_function_%s' % function.name
+        handle_call = getattr(self, handler_name, None)
+        if handle_call is not None:
+            if kwargs is None:
+                return handle_call(node, args)
+            else:
+                return handle_call(node, args, kwargs)
+        return node
+
+    def _inject_capi_function(self, node, cname, func_type, utility_code=None):
+        node.function = ExprNodes.PythonCapiFunctionNode(
+            node.function.pos, node.function.name, cname, func_type,
+            utility_code = utility_code)
+
+    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
+        if not expected: # None or 0
+            arg_str = ''
+        elif isinstance(expected, basestring) or expected > 1:
+            arg_str = '...'
+        elif expected == 1:
+            arg_str = 'x'
+        else:
+            arg_str = ''
+        if expected is not None:
+            expected_str = 'expected %s, ' % expected
+        else:
+            expected_str = ''
+        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
+            function_name, arg_str, expected_str, len(args)))
+
+    # specific handlers for simple call nodes
+
+    def _handle_simple_function_set(self, node, pos_args):
+        """Replace set([a,b,...]) by a literal set {a,b,...} and
+        set([ x for ... ]) by a literal { x for ... }.
+        """
+        arg_count = len(pos_args)
+        if arg_count == 0:
+            return ExprNodes.SetNode(node.pos, args=[],
+                                     type=Builtin.set_type)
+        if arg_count > 1:
+            return node
+        iterable = pos_args[0]
+        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
+            return ExprNodes.SetNode(node.pos, args=iterable.args)
+        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
+                 isinstance(iterable.target, (ExprNodes.ListNode,
+                                              ExprNodes.SetNode)):
+            iterable.target = ExprNodes.SetNode(node.pos, args=[])
+            iterable.pos = node.pos
+            return iterable
+        else:
+            return node
+
+    def _handle_simple_function_dict(self, node, pos_args):
+        """Replace dict([ (a,b) for ... ]) by a literal { a:b for ... }.
+        """
+        if len(pos_args) != 1:
+            return node
+        arg = pos_args[0]
+        if isinstance(arg, ExprNodes.ComprehensionNode) and \
+               isinstance(arg.target, (ExprNodes.ListNode,
+                                       ExprNodes.SetNode)):
+            append_node = arg.append
+            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
+                   len(append_node.expr.args) == 2:
+                key_node, value_node = append_node.expr.args
+                target_node = ExprNodes.DictNode(
+                    pos=arg.target.pos, key_value_pairs=[])
+                new_append_node = ExprNodes.DictComprehensionAppendNode(
+                    append_node.pos, target=target_node,
+                    key_expr=key_node, value_expr=value_node)
+                arg.target = target_node
+                arg.type = target_node.type
+                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
+                return replace_in(arg)
+        return node
+
+    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.py_object_type, [
+            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
+            ])
+
+    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.py_object_type, [
+            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
+            ])
+
+    def _handle_simple_function_getattr(self, node, pos_args):
+        if len(pos_args) == 2:
+            self._inject_capi_function(
+                node, "PyObject_GetAttr",
+                self.PyObject_GetAttr2_func_type)
+        elif len(pos_args) == 3:
+            self._inject_capi_function(
+                node, "__Pyx_GetAttr3",
+                self.PyObject_GetAttr3_func_type,
+                Builtin.getattr3_utility_code)
+        else:
+            self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
+        return node
+
+    Pyx_Type_func_type = PyrexTypes.CFuncType(
+        Builtin.type_type, [
+            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
+            ])
+
+    def _handle_simple_function_type(self, node, pos_args):
+        if len(pos_args) != 1:
+            return node
+        self._inject_capi_function(
+            node, "__Pyx_Type",
+            self.Pyx_Type_func_type,
+            pytype_utility_code)
+        return node
+
+    # specific handlers for general call nodes
+
+    def _handle_general_function_dict(self, node, pos_args, kwargs):
+        """Replace dict(a=b,c=d,...) by the underlying keyword dict
+        construction which is done anyway.
+        """
+        if len(pos_args) > 0:
+            return node
+        if not isinstance(kwargs, ExprNodes.DictNode):
+            return node
+        if node.starstar_arg:
+            # we could optimize this by updating the kw dict instead
+            return node
+        return kwargs
+
+
 class OptimizeBuiltinCalls(Visitor.EnvTransform):
     """Optimize some common methods calls and instantiation patterns
-    for builtin types.
+    for builtin types *after* the type analysis phase.
+
+    Running after type analysis, this transform can only perform
+    function replacements that do not alter the function return type
+    in a way that was not anticipated by the type analysis.
     """
     # only intercept on call nodes
     visit_Node = Visitor.VisitorTransform.recurse_to_children
@@ -896,27 +1077,13 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
 
     ### builtin types
 
-    def _handle_general_function_dict(self, node, pos_args, kwargs):
-        """Replace dict(a=b,c=d,...) by the underlying keyword dict
-        construction which is done anyway.
-        """
-        if len(pos_args) > 0:
-            return node
-        if not isinstance(kwargs, ExprNodes.DictNode):
-            return node
-        if node.starstar_arg:
-            # we could optimize this by updating the kw dict instead
-            return node
-        return kwargs
-
     PyDict_Copy_func_type = PyrexTypes.CFuncType(
         Builtin.dict_type, [
             PyrexTypes.CFuncTypeArg("dict", Builtin.dict_type, None)
             ])
 
     def _handle_simple_function_dict(self, node, pos_args):
-        """Replace dict(some_dict) by PyDict_Copy(some_dict) and
-        dict([ (a,b) for ... ]) by a literal { a:b for ... }.
+        """Replace dict(some_dict) by PyDict_Copy(some_dict).
         """
         if len(pos_args) != 1:
             return node
@@ -929,48 +1096,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
                 args = [arg],
                 is_temp = node.is_temp
                 )
-        elif isinstance(arg, ExprNodes.ComprehensionNode) and \
-                 arg.type is Builtin.list_type:
-            append_node = arg.append
-            if isinstance(append_node.expr, (ExprNodes.TupleNode, ExprNodes.ListNode)) and \
-                   len(append_node.expr.args) == 2:
-                key_node, value_node = append_node.expr.args
-                target_node = ExprNodes.DictNode(
-                    pos=arg.target.pos, key_value_pairs=[], is_temp=1)
-                new_append_node = ExprNodes.DictComprehensionAppendNode(
-                    append_node.pos, target=target_node,
-                    key_expr=key_node, value_expr=value_node,
-                    is_temp=1)
-                arg.target = target_node
-                arg.type = target_node.type
-                replace_in = Visitor.RecursiveNodeReplacer(append_node, new_append_node)
-                return replace_in(arg)
         return node
 
-    def _handle_simple_function_set(self, node, pos_args):
-        """Replace set([a,b,...]) by a literal set {a,b,...} and
-        set([ x for ... ]) by a literal { x for ... }.
-        """
-        arg_count = len(pos_args)
-        if arg_count == 0:
-            return ExprNodes.SetNode(node.pos, args=[],
-                                     type=Builtin.set_type, is_temp=1)
-        if arg_count > 1:
-            return node
-        iterable = pos_args[0]
-        if isinstance(iterable, (ExprNodes.ListNode, ExprNodes.TupleNode)):
-            return ExprNodes.SetNode(node.pos, args=iterable.args,
-                                     type=Builtin.set_type, is_temp=1)
-        elif isinstance(iterable, ExprNodes.ComprehensionNode) and \
-                iterable.type is Builtin.list_type:
-            iterable.target = ExprNodes.SetNode(
-                node.pos, args=[], type=Builtin.set_type, is_temp=1)
-            iterable.type = Builtin.set_type
-            iterable.pos = node.pos
-            return iterable
-        else:
-            return node
-
     PyList_AsTuple_func_type = PyrexTypes.CFuncType(
         Builtin.tuple_type, [
             PyrexTypes.CFuncTypeArg("list", Builtin.list_type, None)
@@ -998,53 +1125,6 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
 
     ### builtin functions
 
-    PyObject_GetAttr2_func_type = PyrexTypes.CFuncType(
-        PyrexTypes.py_object_type, [
-            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
-            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
-            ])
-
-    PyObject_GetAttr3_func_type = PyrexTypes.CFuncType(
-        PyrexTypes.py_object_type, [
-            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None),
-            PyrexTypes.CFuncTypeArg("attr_name", PyrexTypes.py_object_type, None),
-            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
-            ])
-
-    def _handle_simple_function_getattr(self, node, pos_args):
-        if len(pos_args) == 2:
-            node = ExprNodes.PythonCapiCallNode(
-                node.pos, "PyObject_GetAttr", self.PyObject_GetAttr2_func_type,
-                args = pos_args,
-                is_temp = node.is_temp
-                )
-        elif len(pos_args) == 3:
-            node = ExprNodes.PythonCapiCallNode(
-                node.pos, "__Pyx_GetAttr3", self.PyObject_GetAttr3_func_type,
-                utility_code = Builtin.getattr3_utility_code,
-                args = pos_args,
-                is_temp = node.is_temp
-                )
-        else:
-            self._error_wrong_arg_count('getattr', node, pos_args, '2 or 3')
-        return node
-
-    Pyx_Type_func_type = PyrexTypes.CFuncType(
-        Builtin.type_type, [
-            PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
-            ])
-
-    def _handle_simple_function_type(self, node, pos_args):
-        if len(pos_args) != 1:
-            return node
-        node = ExprNodes.PythonCapiCallNode(
-            node.pos, "__Pyx_Type", self.Pyx_Type_func_type,
-                args = pos_args,
-                is_temp = node.is_temp,
-                utility_code = pytype_utility_code,
-                )
-        return node
-
     Pyx_strlen_func_type = PyrexTypes.CFuncType(
         PyrexTypes.c_size_t_type, [
             PyrexTypes.CFuncTypeArg("bytes", PyrexTypes.c_char_ptr_type, None)
@@ -1065,10 +1145,10 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             return node
         node = ExprNodes.PythonCapiCallNode(
             node.pos, "strlen", self.Pyx_strlen_func_type,
-                args = [arg],
-                is_temp = node.is_temp,
-                utility_code = include_string_h_utility_code,
-                )
+            args = [arg],
+            is_temp = node.is_temp,
+            utility_code = include_string_h_utility_code
+            )
         return node
 
     ### special methods