From 1e0ab7ca198716ab4fb6c4327ae54ed0046ac5a7 Mon Sep 17 00:00:00 2001
From: Stefan Behnel <scoder@users.berlios.de>
Date: Tue, 25 Nov 2008 18:24:52 +0100
Subject: [PATCH] handle value coercion correctly in dict iteration

---
 Cython/Compiler/ExprNodes.py |  4 +--
 Cython/Compiler/Optimize.py  | 67 +++++++++++++++++++++++-------------
 tests/run/iterdict.pyx       | 40 +++++++++++++++++++++
 3 files changed, 86 insertions(+), 25 deletions(-)

diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py
index 9612dd86..9b4eb84c 100644
--- a/Cython/Compiler/ExprNodes.py
+++ b/Cython/Compiler/ExprNodes.py
@@ -4504,8 +4504,6 @@ class PyTypeTestNode(CoercionNode):
         self.type = dst_type
         self.gil_check(env)
         self.result_ctype = arg.ctype()
-        if not dst_type.is_builtin_type:
-            env.use_utility_code(type_test_utility_code)
 
     gil_message = "Python type test"
     
@@ -4523,6 +4521,8 @@ class PyTypeTestNode(CoercionNode):
     
     def generate_result_code(self, code):
         if self.type.typeobj_is_available():
+            if not dst_type.is_builtin_type:
+                code.globalstate.use_utility_code(type_test_utility_code)
             code.putln(
                 "if (!(%s)) %s" % (
                     self.type.type_test_code(self.arg.py_result()),
diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py
index 3f4c5cf2..fcf07ff4 100644
--- a/Cython/Compiler/Optimize.py
+++ b/Cython/Compiler/Optimize.py
@@ -111,16 +111,25 @@ class DictIterTransform(Visitor.VisitorTransform):
             else:
                 tuple_target = node.target
 
-        if keys:
-            key_cast = ExprNodes.TypecastNode(
-                pos = key_target.pos,
-                operand = key_temp,
-                type = key_target.type)
-        if values:
-            value_cast = ExprNodes.TypecastNode(
-                pos = value_target.pos,
-                operand = value_temp,
-                type = value_target.type)
+        def coerce_object_to(obj_node, dest_type):
+            class FakeEnv(object):
+                nogil = False
+            if dest_type.is_pyobject:
+                if dest_type.is_extension_type or dest_type.is_builtin_type:
+                    return (obj_node, ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv()))
+                else:
+                    return (obj_node, None)
+            else:
+                temp = UtilNodes.TempHandle(dest_type)
+                temps.append(temp)
+                temp_result = temp.ref(obj_node.pos)
+                class CoercedTempNode(ExprNodes.CoerceFromPyTypeNode):
+                    # FIXME: remove this after result-code refactoring
+                    def result(self):
+                        return temp_result.result()
+                    def generate_execution_code(self, code):
+                        self.generate_result_code(code)
+                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))
 
         if isinstance(node.body, Nodes.StatListNode):
             body = node.body
@@ -129,7 +138,7 @@ class DictIterTransform(Visitor.VisitorTransform):
                                       stats = [node.body])
 
         if tuple_target:
-            temp = UtilNodes.TempHandle(py_object_ptr)
+            temp = UtilNodes.TempHandle(PyrexTypes.py_object_type)
             temps.append(temp)
             temp_tuple = temp.ref(tuple_target.pos)
             class TempTupleNode(ExprNodes.TupleNode):
@@ -139,7 +148,7 @@ class DictIterTransform(Visitor.VisitorTransform):
 
             tuple_result = TempTupleNode(
                 pos = tuple_target.pos,
-                args = [key_cast, value_cast],
+                args = [key_temp, value_temp],
                 is_temp = 1,
                 type = Builtin.tuple_type,
                 )
@@ -148,18 +157,30 @@ class DictIterTransform(Visitor.VisitorTransform):
                     lhs = tuple_target,
                     rhs = tuple_result))
         else:
-            if values:
-                body.stats.insert(
-                    0, Nodes.SingleAssignmentNode(
-                        pos = value_target.pos,
-                        lhs = value_target,
-                        rhs = value_cast))
+            # execute all coercions before the assignments
+            coercion_stats = []
+            assign_stats = []
             if keys:
-                body.stats.insert(
-                    0, Nodes.SingleAssignmentNode(
-                        pos = key_target.pos,
-                        lhs = key_target,
-                        rhs = key_cast))
+                temp_result, coercion = coerce_object_to(
+                    key_temp, key_target.type)
+                if coercion:
+                    coercion_stats.append(coercion)
+                assign_stats.append(
+                    Nodes.SingleAssignmentNode(
+                        pos = key_temp.pos,
+                        rhs = temp_result,
+                        lhs = key_target))
+            if values:
+                temp_result, coercion = coerce_object_to(
+                    value_temp, value_target.type)
+                if coercion:
+                    coercion_stats.append(coercion)
+                assign_stats.append(
+                    Nodes.SingleAssignmentNode(
+                        pos = value_temp.pos,
+                        rhs = temp_result,
+                        lhs = value_target))
+            body.stats[0:0] = coercion_stats + assign_stats
 
         result_code = [
             Nodes.SingleAssignmentNode(
diff --git a/tests/run/iterdict.pyx b/tests/run/iterdict.pyx
index 5c253bc6..14d3ed7d 100644
--- a/tests/run/iterdict.pyx
+++ b/tests/run/iterdict.pyx
@@ -6,14 +6,22 @@ __doc__ = u"""
 [(10, 0), (11, 1), (12, 2), (13, 3)]
 >>> iteritems(d)
 [(10, 0), (11, 1), (12, 2), (13, 3)]
+>>> iteritems_int(d)
+[(10, 0), (11, 1), (12, 2), (13, 3)]
 >>> iteritems_tuple(d)
 [(10, 0), (11, 1), (12, 2), (13, 3)]
 >>> iterkeys(d)
 [10, 11, 12, 13]
+>>> iterkeys_int(d)
+[10, 11, 12, 13]
 >>> iterdict(d)
 [10, 11, 12, 13]
+>>> iterdict_int(d)
+[10, 11, 12, 13]
 >>> itervalues(d)
 [0, 1, 2, 3]
+>>> itervalues_int(d)
+[0, 1, 2, 3]
 """
 
 def items(dict d):
@@ -30,6 +38,14 @@ def iteritems(dict d):
     l.sort()
     return l
 
+def iteritems_int(dict d):
+    cdef int k,v
+    l = []
+    for k,v in d.iteritems():
+        l.append((k,v))
+    l.sort()
+    return l
+
 def iteritems_tuple(dict d):
     l = []
     for t in d.iteritems():
@@ -44,6 +60,14 @@ def iterkeys(dict d):
     l.sort()
     return l
 
+def iterkeys_int(dict d):
+    cdef int k
+    l = []
+    for k in d.iterkeys():
+        l.append(k)
+    l.sort()
+    return l
+
 def iterdict(dict d):
     l = []
     for k in d:
@@ -51,9 +75,25 @@ def iterdict(dict d):
     l.sort()
     return l
 
+def iterdict_int(dict d):
+    cdef int k
+    l = []
+    for k in d:
+        l.append(k)
+    l.sort()
+    return l
+
 def itervalues(dict d):
     l = []
     for v in d.itervalues():
         l.append(v)
     l.sort()
     return l
+
+def itervalues_int(dict d):
+    cdef int v
+    l = []
+    for v in d.itervalues():
+        l.append(v)
+    l.sort()
+    return l
-- 
2.26.2