infer type of loop variable when for-in-looping over pointers, C arrays, unicode...
author: Stefan Behnel <scoder@users.berlios.de>
Fri, 14 May 2010 20:03:16 +0000 (22:03 +0200)
committer: Stefan Behnel <scoder@users.berlios.de>
Fri, 14 May 2010 20:03:16 +0000 (22:03 +0200)
Cython/Compiler/TypeInference.py
tests/run/type_inference.pyx

index d58d375a050800f9a0dbaf1401bed041381a5861..b3ddeaf969e038e51c0816021fb44969f2bdbdc9 100644 (file)
@@ -70,7 +70,15 @@ class MarkAssignments(CythonTransform):
                                                  sequence.args[0],
                                                  sequence.args[2]))
         if not is_special:
-            self.mark_assignment(node.target, object_expr)
+            # A for-loop basically translates to subsequent calls to
+            # __getitem__(), so using an IndexNode here allows us to
+            # naturally infer the base type of pointers, C arrays,
+            # Python strings, etc., while correctly falling back to an
+            # object type when the base type cannot be handled.
+            self.mark_assignment(node.target, ExprNodes.IndexNode(
+                node.pos,
+                base = sequence,
+                index = ExprNodes.IntNode(node.pos, value = '0')))
         self.visitchildren(node)
         return node
 
index 7386a0bdb755cdea53b2da8e80b299b495405c25..0abc0526b2fb6672b073de2136c16754b8a7d1a0 100644 (file)
@@ -187,6 +187,46 @@ def loop():
         pass
     assert typeof(a) == "long"
 
+def loop_over_charptr():
+    """
+    >>> print( loop_over_charptr() )
+    char
+    """
+    cdef char* char_ptr_string = 'abcdefg'
+    for c in char_ptr_string:
+        pass
+    return typeof(c)
+
+def loop_over_bytes():
+    """
+    >>> print( loop_over_bytes() )
+    Python object
+    """
+    cdef bytes bytes_string = b'abcdefg'
+    for c in bytes_string:
+        pass
+    return typeof(c)
+
+def loop_over_unicode():
+    """
+    >>> print( loop_over_unicode() )
+    Py_UNICODE
+    """
+    cdef unicode ustring = u'abcdefg'
+    for uchar in ustring:
+        pass
+    return typeof(uchar)
+
+def loop_over_int_array():
+    """
+    >>> print( loop_over_int_array() )
+    int
+    """
+    cdef int[10] int_array
+    for i in int_array:
+        pass
+    return typeof(i)
+
 cdef unicode retu():
     return u"12345"