fix type inference for sliced builtins
authorStefan Behnel <scoder@users.berlios.de>
Wed, 8 Sep 2010 09:31:54 +0000 (11:31 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 8 Sep 2010 09:31:54 +0000 (11:31 +0200)
Cython/Compiler/ExprNodes.py
tests/run/type_inference.pyx

index d12f48a80ba2f0eb8cc5b4170c97a0aa8ccb8f4b..ab2671262af833ac6baf92daa7c5550e57a4003c 100755 (executable)
@@ -1942,29 +1942,39 @@ class IndexNode(ExprNode):
         return self.base.type_dependencies(env)
     
     def infer_type(self, env):
-        is_slice = isinstance(self.index, SliceNode)
-        if isinstance(self.base, BytesNode):
-            if is_slice:
+        base_type = self.base.infer_type(env)
+        if isinstance(self.index, SliceNode):
+            # slicing!
+            if base_type.is_string:
                 return bytes_type
+            elif base_type in (unicode_type, bytes_type, str_type, list_type, tuple_type):
+                # slicing these returns the same type
+                return base_type
             else:
-                return py_object_type # Py2/3 return different types
-        base_type = self.base.infer_type(env)
-        if base_type.is_ptr or base_type.is_array:
-            return base_type.base_type
-        elif base_type is unicode_type and self.index.infer_type(env).is_int:
-            # Py_UNICODE will automatically coerce to a unicode string
-            # if required, so this is safe. We only infer Py_UNICODE
-            # when the index is a C integer type. Otherwise, we may
-            # need to use normal Python item access, in which case
-            # it's faster to return the one-char unicode string than
-            # to receive it, throw it away, and potentially rebuild it
-            # on a subsequent PyObject coercion.
-            return PyrexTypes.c_py_unicode_type
-        elif base_type in (str_type, unicode_type):
-            # these types will always return their own type on Python indexing/slicing
-            return base_type
-        elif is_slice and base_type in (bytes_type, list_type, tuple_type):
-            # slicing these returns the same type
+                # TODO: Handle buffers (hopefully without too much redundancy).
+                return py_object_type
+
+        if isinstance(self.base, BytesNode):
+            # Py2/3 return different types on indexing bytes objects
+            # and we can't be sure if we are slicing, so we can't do
+            # any better than this:
+            return py_object_type
+
+        if self.index.infer_type(env).is_int or isinstance(self.index, (IntNode, LongNode)):
+            # indexing!
+            if base_type is unicode_type:
+                # Py_UNICODE will automatically coerce to a unicode string
+                # if required, so this is safe. We only infer Py_UNICODE
+                # when the index is a C integer type. Otherwise, we may
+                # need to use normal Python item access, in which case
+                # it's faster to return the one-char unicode string than
+                # to receive it, throw it away, and potentially rebuild it
+                # on a subsequent PyObject coercion.
+                return PyrexTypes.c_py_unicode_type
+            elif base_type.is_ptr or base_type.is_array:
+                return base_type.base_type
+        if base_type is unicode_type:
+            # this type always returns its own type on Python indexing/slicing
             return base_type
         else:
             # TODO: Handle buffers (hopefully without too much redundancy).
@@ -1993,11 +2003,12 @@ class IndexNode(ExprNode):
             self.type = PyrexTypes.error_type
             return
         
-        if isinstance(self.index, IntNode) and Utils.long_literal(self.index.value):
+        is_slice = isinstance(self.index, SliceNode)
+        if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value):
             self.index = self.index.coerce_to_pyobject(env)
-        
+
         # Handle the case where base is a literal char* (and we expect a string, not an int)
-        if isinstance(self.base, BytesNode):
+        if isinstance(self.base, BytesNode) or is_slice:
             self.base = self.base.coerce_to_pyobject(env)
 
         skip_child_analysis = False
@@ -2069,6 +2080,8 @@ class IndexNode(ExprNode):
                     # Py_UNICODE will automatically coerce to a unicode string
                     # if required, so this is fast and safe
                     self.type = PyrexTypes.c_py_unicode_type
+                elif is_slice and base_type in (bytes_type, str_type, unicode_type, list_type, tuple_type):
+                    self.type = base_type
                 else:
                     self.type = py_object_type
             else:
index 495498711ae056ff69fa4b427546d01d756e281a..b78f70e2bfef9357b550c7761aa240958f40c9f1 100644 (file)
@@ -60,10 +60,20 @@ def slicing():
     assert typeof(b) == "char *", typeof(b)
     b1 = b[1:2]
     assert typeof(b1) == "bytes object", typeof(b1)
+    b2 = b[1:2:2]
+    assert typeof(b2) == "bytes object", typeof(b2)
     u = u"xyz"
     assert typeof(u) == "unicode object", typeof(u)
     u1 = u[1:2]
     assert typeof(u1) == "unicode object", typeof(u1)
+    u2 = u[1:2:2]
+    assert typeof(u2) == "unicode object", typeof(u2)
+    s = "xyz"
+    assert typeof(s) == "str object", typeof(s)
+    s1 = s[1:2]
+    assert typeof(s1) == "str object", typeof(s1)
+    s2 = s[1:2:2]
+    assert typeof(s2) == "str object", typeof(s2)
     L = [1,2,3]
     assert typeof(L) == "list object", typeof(L)
     L1 = L[1:2]
@@ -84,11 +94,15 @@ def indexing():
     b = b"abc"
     assert typeof(b) == "char *", typeof(b)
     b1 = b[1]
-    assert typeof(b1) == "char", typeof(b1)  # FIXME: bytes object ??
+    assert typeof(b1) == "char", typeof(b1)  # FIXME: Python object ??
     u = u"xyz"
     assert typeof(u) == "unicode object", typeof(u)
     u1 = u[1]
     assert typeof(u1) == "Py_UNICODE", typeof(u1)
+    s = "xyz"
+    assert typeof(s) == "str object", typeof(s)
+    s1 = s[1]
+    assert typeof(s1) == "Python object", typeof(s1)
     L = [1,2,3]
     assert typeof(L) == "list object", typeof(L)
     L1 = L[1]
@@ -267,7 +281,7 @@ def loop_over_bytes():
 def loop_over_str():
     """
     >>> print( loop_over_str() )
-    str object
+    Python object
     """
     cdef str string = 'abcdefg'
     for c in string: