new transform that converts for-in-dict.iter*() into a while-loop over PyDict_Next...
author Stefan Behnel <scoder@users.berlios.de>
Sun, 16 Nov 2008 21:45:12 +0000 (22:45 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Sun, 16 Nov 2008 21:45:12 +0000 (22:45 +0100)
Cython/Compiler/Optimize.py

index e02f5b78753546d203636fc45821d3da3491e0be..b5f2efa244a2639abafd191f8844eeee1cf70f22 100644 (file)
@@ -2,6 +2,11 @@ import Nodes
 import ExprNodes
 import PyrexTypes
 import Visitor
+import Builtin
+import UtilNodes
+import TypeSlots
+import Symtab
+from StringEncoding import EncodedString
 
 def unwrap_node(node):
     while isinstance(node, ExprNodes.PersistentNode):
@@ -18,6 +23,173 @@ def is_common_value(a, b):
     return False
 
 
+class DictIterTransform(Visitor.VisitorTransform):
+    """Transform a for-in-dict loop into a while loop calling PyDict_Next().
+
+    Only ``for`` loops whose iterable is a call to ``iterkeys()``,
+    ``itervalues()`` or ``iteritems()`` on an expression typed as the
+    builtin dict are rewritten.  Any other loop is returned unchanged.
+
+    The transform keeps a stack of scopes in ``self.env_stack``: it is
+    created in visit_ModuleNode() and pushed/popped in visit_FuncDefNode(),
+    so the innermost enclosing scope is always ``self.env_stack[-1]``.
+    """
+    # C-level signature used for the generated call node: a function
+    # returning a C bint and taking the arguments (dict, pos, key, value).
+    PyDict_Next_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.c_bint_type, [
+            PyrexTypes.CFuncTypeArg("dict",  PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("pos",   PyrexTypes.c_py_ssize_t_ptr_type, None),
+            PyrexTypes.CFuncTypeArg("key",   PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
+            PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
+            ])
+
+    # Name of the function, used both as the entry name and its C name below.
+    PyDict_Next_name = EncodedString("PyDict_Next")
+
+    # Symbol table entry shared by every generated NameNode that refers to
+    # the function.
+    PyDict_Next_entry = Symtab.Entry(
+        PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
+
+    def visit_ForInStatNode(self, node):
+        """Rewrite a dict.iter*() for-loop into a PyDict_Next() while-loop.
+
+        Returns ``node`` itself when the loop does not match the supported
+        pattern.  Otherwise returns a UtilNodes.TempsBlockNode wrapping a
+        statement list that resets the position counter to 0 and then runs
+        a WhileStatNode whose condition is the PyDict_Next() call.
+
+        Note that when the original body already is a StatListNode, the
+        target assignments are inserted into its ``stats`` list in place.
+        """
+        # Process the child nodes before looking at this loop.
+        self.visitchildren(node)
+        iterator = node.iterator.sequence
+        # The iterable must be a plain call of the form <obj>.<attribute>().
+        if not isinstance(iterator, ExprNodes.SimpleCallNode):
+            return node
+        function = iterator.function
+        if not isinstance(function, ExprNodes.AttributeNode):
+            return node
+        # Only objects whose type is the builtin dict type are handled.
+        if function.obj.type != Builtin.dict_type:
+            return node
+        dict_obj = function.obj
+        method = function.attribute
+        # Innermost scope, as maintained by visit_ModuleNode/visit_FuncDefNode.
+        env = self.env_stack[-1]
+
+        # Decide which of the key/value out-arguments are actually needed.
+        keys = values = False
+        if method == 'iterkeys':
+            keys = True
+        elif method == 'itervalues':
+            values = True
+        elif method == 'iteritems':
+            keys = values = True
+        else:
+            # Any other method call on a dict is left alone.
+            return node
+
+        # The key/value temps are declared with the C void pointer type.
+        py_object_ptr = PyrexTypes.c_void_ptr_type
+
+        temps = []
+        # The loop's existing iteration counter is reused as the position
+        # variable and passed to the call by address.
+        pos_temp = node.iterator.counter
+        pos_temp_addr = ExprNodes.AmpersandNode(
+            node.pos, operand=pos_temp,
+            type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
+        if keys:
+            # Allocate a temp for the key and take its address.
+            temp = UtilNodes.TempHandle(py_object_ptr)
+            temps.append(temp)
+            key_temp = temp.ref(node.target.pos)
+            key_temp_addr = ExprNodes.AmpersandNode(
+                node.target.pos, operand=key_temp,
+                type=PyrexTypes.c_ptr_type(py_object_ptr))
+        else:
+            # Keys are not needed: pass a NullNode instead of an address.
+            key_temp_addr = key_temp = ExprNodes.NullNode(
+                pos=node.target.pos)
+        if values:
+            # Allocate a temp for the value and take its address.
+            temp = UtilNodes.TempHandle(py_object_ptr)
+            temps.append(temp)
+            value_temp = temp.ref(node.target.pos)
+            value_temp_addr = ExprNodes.AmpersandNode(
+                node.target.pos, operand=value_temp,
+                type=PyrexTypes.c_ptr_type(py_object_ptr))
+        else:
+            # Values are not needed: pass a NullNode instead of an address.
+            value_temp_addr = value_temp = ExprNodes.NullNode(
+                pos=node.target.pos)
+
+        # Work out where key and value get assigned.  For iterkeys() and
+        # itervalues() the loop target receives the single item directly.
+        key_target = value_target = node.target
+        tuple_target = None
+        if keys and values:
+            if node.target.is_sequence_constructor:
+                if len(node.target.args) == 2:
+                    # "for k, v in d.iteritems()": assign both parts directly.
+                    key_target, value_target = node.target.args
+                else:
+                    # A sequence target of any other length is not handled;
+                    # keep the original loop.
+                    # FIXME ...
+                    return node
+            else:
+                # Single non-sequence target: it receives a (key, value) tuple.
+                tuple_target = node.target
+
+        # Cast the temps to the type of the target they are assigned to.
+        # NOTE(review): when tuple_target is set, key_target and value_target
+        # are both still node.target, so both casts use the tuple target's
+        # type rather than a per-item type -- confirm this is intended.
+        if keys:
+            key_cast = ExprNodes.TypecastNode(
+                pos = key_target.pos,
+                operand = key_temp,
+                type = key_target.type)
+        if values:
+            value_cast = ExprNodes.TypecastNode(
+                pos = value_target.pos,
+                operand = value_temp,
+                type = value_target.type)
+
+        # Make sure the body is a statement list so that the target
+        # assignments can be prepended to it.
+        if isinstance(node.body, Nodes.StatListNode):
+            body = node.body
+        else:
+            body = Nodes.StatListNode(pos = node.body.pos,
+                                      stats = [node.body])
+
+        if tuple_target:
+            # Build the (key, value) tuple, run type analysis and temp
+            # allocation on it against the current scope, and assign it to
+            # the target as the first statement of the body.
+            tuple_result = ExprNodes.TupleNode(
+                pos = tuple_target.pos,
+                args = [key_cast, value_cast]
+                )
+            tuple_result.analyse_types(env)
+            tuple_result.allocate_temps(env)
+            body.stats.insert(0, Nodes.SingleAssignmentNode(
+                    pos = tuple_target.pos,
+                    lhs = tuple_target,
+                    rhs = tuple_result))
+        else:
+            # Both assignments are inserted at index 0, value first and key
+            # second, so the key assignment ends up before the value
+            # assignment in the final body.
+            if values:
+                body.stats.insert(
+                    0, Nodes.SingleAssignmentNode(
+                        pos = value_target.pos,
+                        lhs = value_target,
+                        rhs = value_cast))
+            if keys:
+                body.stats.insert(
+                    0, Nodes.SingleAssignmentNode(
+                        pos = key_target.pos,
+                        lhs = key_target,
+                        rhs = key_cast))
+
+        # Replacement statements: "pos = 0", followed by
+        # "while PyDict_Next(dict, &pos, <key addr>, <value addr>): body".
+        # The else clause of the original loop is carried over unchanged.
+        result_code = [
+            Nodes.SingleAssignmentNode(
+                pos = node.pos,
+                lhs = pos_temp,
+                rhs = ExprNodes.IntNode(node.pos, value=0)),
+            Nodes.WhileStatNode(
+                pos = node.pos,
+                condition = ExprNodes.SimpleCallNode(
+                    pos = dict_obj.pos,
+                    type = PyrexTypes.c_bint_type,
+                    function = ExprNodes.NameNode(
+                        pos=dict_obj.pos, name=self.PyDict_Next_name,
+                        type = self.PyDict_Next_func_type,
+                        entry = self.PyDict_Next_entry),
+                    args = [dict_obj, pos_temp_addr,
+                            key_temp_addr, value_temp_addr]
+                    ),
+                body = body,
+                else_clause = node.else_clause
+                )
+            ]
+
+        # Wrap everything in a block that owns the key/value temps.
+        return UtilNodes.TempsBlockNode(
+            node.pos, temps=temps,
+            body=Nodes.StatListNode(
+                pos = node.pos,
+                stats = result_code
+                ))
+
+    def visit_ModuleNode(self, node):
+        """Start the scope stack with the module scope, then visit children."""
+        self.env_stack = [node.scope]
+        self.visitchildren(node)
+        return node
+
+    def visit_FuncDefNode(self, node):
+        """Visit a function's children with its local scope on the stack."""
+        self.env_stack.append(node.local_scope)
+        self.visitchildren(node)
+        self.env_stack.pop()
+        return node
+
+    def visit_Node(self, node):
+        """Fallback for all other node types: visit children, keep the node."""
+        self.visitchildren(node)
+        return node
+
+
 class SwitchTransform(Visitor.VisitorTransform):
     """
     This transformation tries to turn long if statements into C switch statements.