import ExprNodes
import PyrexTypes
import Visitor
+import Builtin
+import UtilNodes
+import TypeSlots
+import Symtab
+from StringEncoding import EncodedString
def unwrap_node(node):
while isinstance(node, ExprNodes.PersistentNode):
return False
+class DictIterTransform(Visitor.VisitorTransform):
+ """Transform a for-in-dict loop into a while loop calling PyDict_Next().
+ """
+ PyDict_Next_func_type = PyrexTypes.CFuncType(
+ PyrexTypes.c_bint_type, [
+ PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
+ PyrexTypes.CFuncTypeArg("pos", PyrexTypes.c_py_ssize_t_ptr_type, None),
+ PyrexTypes.CFuncTypeArg("key", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None),
+ PyrexTypes.CFuncTypeArg("value", PyrexTypes.CPtrType(PyrexTypes.py_object_type), None)
+ ])
+
+ PyDict_Next_name = EncodedString("PyDict_Next")
+
+ PyDict_Next_entry = Symtab.Entry(
+ PyDict_Next_name, PyDict_Next_name, PyDict_Next_func_type)
+
+ def visit_ForInStatNode(self, node):
+ self.visitchildren(node)
+ iterator = node.iterator.sequence
+ if not isinstance(iterator, ExprNodes.SimpleCallNode):
+ return node
+ function = iterator.function
+ if not isinstance(function, ExprNodes.AttributeNode):
+ return node
+ if function.obj.type != Builtin.dict_type:
+ return node
+ dict_obj = function.obj
+ method = function.attribute
+ env = self.env_stack[-1]
+
+ keys = values = False
+ if method == 'iterkeys':
+ keys = True
+ elif method == 'itervalues':
+ values = True
+ elif method == 'iteritems':
+ keys = values = True
+ else:
+ return node
+
+ py_object_ptr = PyrexTypes.c_void_ptr_type
+
+ temps = []
+ pos_temp = node.iterator.counter
+ pos_temp_addr = ExprNodes.AmpersandNode(
+ node.pos, operand=pos_temp,
+ type=PyrexTypes.c_ptr_type(PyrexTypes.c_py_ssize_t_type))
+ if keys:
+ temp = UtilNodes.TempHandle(py_object_ptr)
+ temps.append(temp)
+ key_temp = temp.ref(node.target.pos)
+ key_temp_addr = ExprNodes.AmpersandNode(
+ node.target.pos, operand=key_temp,
+ type=PyrexTypes.c_ptr_type(py_object_ptr))
+ else:
+ key_temp_addr = key_temp = ExprNodes.NullNode(
+ pos=node.target.pos)
+ if values:
+ temp = UtilNodes.TempHandle(py_object_ptr)
+ temps.append(temp)
+ value_temp = temp.ref(node.target.pos)
+ value_temp_addr = ExprNodes.AmpersandNode(
+ node.target.pos, operand=value_temp,
+ type=PyrexTypes.c_ptr_type(py_object_ptr))
+ else:
+ value_temp_addr = value_temp = ExprNodes.NullNode(
+ pos=node.target.pos)
+
+ key_target = value_target = node.target
+ tuple_target = None
+ if keys and values:
+ if node.target.is_sequence_constructor:
+ if len(node.target.args) == 2:
+ key_target, value_target = node.target.args
+ else:
+ # FIXME ...
+ return node
+ else:
+ tuple_target = node.target
+
+ if keys:
+ key_cast = ExprNodes.TypecastNode(
+ pos = key_target.pos,
+ operand = key_temp,
+ type = key_target.type)
+ if values:
+ value_cast = ExprNodes.TypecastNode(
+ pos = value_target.pos,
+ operand = value_temp,
+ type = value_target.type)
+
+ if isinstance(node.body, Nodes.StatListNode):
+ body = node.body
+ else:
+ body = Nodes.StatListNode(pos = node.body.pos,
+ stats = [node.body])
+
+ if tuple_target:
+ tuple_result = ExprNodes.TupleNode(
+ pos = tuple_target.pos,
+ args = [key_cast, value_cast]
+ )
+ tuple_result.analyse_types(env)
+ tuple_result.allocate_temps(env)
+ body.stats.insert(0, Nodes.SingleAssignmentNode(
+ pos = tuple_target.pos,
+ lhs = tuple_target,
+ rhs = tuple_result))
+ else:
+ if values:
+ body.stats.insert(
+ 0, Nodes.SingleAssignmentNode(
+ pos = value_target.pos,
+ lhs = value_target,
+ rhs = value_cast))
+ if keys:
+ body.stats.insert(
+ 0, Nodes.SingleAssignmentNode(
+ pos = key_target.pos,
+ lhs = key_target,
+ rhs = key_cast))
+
+ result_code = [
+ Nodes.SingleAssignmentNode(
+ pos = node.pos,
+ lhs = pos_temp,
+ rhs = ExprNodes.IntNode(node.pos, value=0)),
+ Nodes.WhileStatNode(
+ pos = node.pos,
+ condition = ExprNodes.SimpleCallNode(
+ pos = dict_obj.pos,
+ type = PyrexTypes.c_bint_type,
+ function = ExprNodes.NameNode(
+ pos=dict_obj.pos, name=self.PyDict_Next_name,
+ type = self.PyDict_Next_func_type,
+ entry = self.PyDict_Next_entry),
+ args = [dict_obj, pos_temp_addr,
+ key_temp_addr, value_temp_addr]
+ ),
+ body = body,
+ else_clause = node.else_clause
+ )
+ ]
+
+ return UtilNodes.TempsBlockNode(
+ node.pos, temps=temps,
+ body=Nodes.StatListNode(
+ pos = node.pos,
+ stats = result_code
+ ))
+
+ def visit_ModuleNode(self, node):
+ self.env_stack = [node.scope]
+ self.visitchildren(node)
+ return node
+
+ def visit_FuncDefNode(self, node):
+ self.env_stack.append(node.local_scope)
+ self.visitchildren(node)
+ self.env_stack.pop()
+ return node
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+ return node
+
+
class SwitchTransform(Visitor.VisitorTransform):
"""
This transformation tries to turn long if statements into C switch statements.