implement "for int_var in bytes_string" and "for int_var in unicode_string"
author: Stefan Behnel <scoder@users.berlios.de>
Sun, 18 Apr 2010 21:06:11 +0000 (23:06 +0200)
committer: Stefan Behnel <scoder@users.berlios.de>
Sun, 18 Apr 2010 21:06:11 +0000 (23:06 +0200)
Cython/Compiler/Optimize.py
tests/run/for_in_string.pyx [new file with mode: 0644]

index 69a8ebf9ad11556f187b95874491902895e4fb81..72f5f701cceab7736bf89a115e496bcdbc23694c 100644 (file)
@@ -95,7 +95,11 @@ class IterationTransform(Visitor.VisitorTransform):
             return self._transform_carray_iteration(node, iterator)
         elif iterator.type.is_array:
             return self._transform_carray_iteration(node, iterator)
-        elif not isinstance(iterator, ExprNodes.SimpleCallNode):
+        elif iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
+            return self._transform_string_iteration(node, iterator)
+
+        # the rest is based on function calls
+        if not isinstance(iterator, ExprNodes.SimpleCallNode):
             return node
 
         function = iterator.function
@@ -132,6 +136,71 @@ class IterationTransform(Visitor.VisitorTransform):
 
         return node
 
+    PyUnicode_AS_UNICODE_func_type = PyrexTypes.CFuncType(  # signature for PyUnicode_AS_UNICODE(s)
+        PyrexTypes.c_int_ptr_type, [ # FIXME: return type is actually Py_UNICODE*, declared as int* here
+            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
+            ])
+
+    PyUnicode_GET_SIZE_func_type = PyrexTypes.CFuncType(  # signature for PyUnicode_GET_SIZE(s)
+        PyrexTypes.c_py_ssize_t_type, [  # returns Py_ssize_t
+            PyrexTypes.CFuncTypeArg("s", Builtin.unicode_type, None)
+            ])
+
+    PyBytes_AS_STRING_func_type = PyrexTypes.CFuncType(  # signature for PyBytes_AS_STRING(s)
+        PyrexTypes.c_char_ptr_type, [  # returns char*
+            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
+            ])
+
+    PyBytes_GET_SIZE_func_type = PyrexTypes.CFuncType(  # signature for PyBytes_GET_SIZE(s)
+        PyrexTypes.c_py_ssize_t_type, [  # returns Py_ssize_t
+            PyrexTypes.CFuncTypeArg("s", Builtin.bytes_type, None)
+            ])
+
+    def _transform_string_iteration(self, node, slice_node):  # for-loop over a bytes/unicode value
+        if not node.target.type.is_int:  # only integer-typed loop targets are handled
+            return node  # leave the loop untransformed
+        if slice_node.type is Builtin.unicode_type:  # unicode: use the PyUnicode_* accessors
+            unpack_func = "PyUnicode_AS_UNICODE"
+            len_func = "PyUnicode_GET_SIZE"
+            unpack_func_type = self.PyUnicode_AS_UNICODE_func_type
+            len_func_type = self.PyUnicode_GET_SIZE_func_type
+        elif slice_node.type is Builtin.bytes_type:  # bytes: use the PyBytes_* accessors
+            unpack_func = "PyBytes_AS_STRING"
+            unpack_func_type = self.PyBytes_AS_STRING_func_type
+            len_func = "PyBytes_GET_SIZE"
+            len_func_type = self.PyBytes_GET_SIZE_func_type
+        else:  # any other iterable type
+            return node  # leave the loop untransformed
+
+        unpack_temp_node = UtilNodes.LetRefNode(  # one node, passed to both C-API calls below
+            ExprNodes.NoneCheckNode(  # presumably raises TypeError when the string is None - confirm
+                slice_node, "PyExc_TypeError", "'NoneType' is not iterable"))
+
+        slice_base_node = ExprNodes.PythonCapiCallNode(  # call returning the data pointer
+            slice_node.pos, unpack_func, unpack_func_type,
+            args = [unpack_temp_node],
+            is_temp = 0,
+            )
+        len_node = ExprNodes.PythonCapiCallNode(  # call returning the length as Py_ssize_t
+            slice_node.pos, len_func, len_func_type,
+            args = [unpack_temp_node],
+            is_temp = 0,
+            )
+
+        return UtilNodes.LetNode(  # wraps the rewritten loop together with the shared temp node
+            unpack_temp_node,
+            self._transform_carray_iteration(  # reuse the C array iteration transform
+                node,
+                ExprNodes.SliceIndexNode(  # slice of the data pointer: no start, stop at the length
+                    slice_node.pos,
+                    base = slice_base_node,
+                    start = None,
+                    step = None,
+                    stop = len_node,
+                    type = slice_base_node.type,  # slice keeps the pointer type of its base
+                    is_temp = 1,
+                    )))
+
     def _transform_carray_iteration(self, node, slice_node):
         if isinstance(slice_node, ExprNodes.SliceIndexNode):
             slice_base = slice_node.base
@@ -166,7 +235,7 @@ class IterationTransform(Visitor.VisitorTransform):
 
         stop_ptr_node = ExprNodes.AddNode(
             stop.pos,
-            operand1=carray_ptr,
+            operand1=ExprNodes.CloneNode(carray_ptr),
             operator='+',
             operand2=stop,
             type=ptr_type
diff --git a/tests/run/for_in_string.pyx b/tests/run/for_in_string.pyx
new file mode 100644 (file)
index 0000000..9e920df
--- /dev/null
@@ -0,0 +1,48 @@
+
+bytes_abc = b'abc'  # lowercase fixture: the doctests expect 'X' for it
+bytes_ABC = b'ABC'  # uppercase fixture: the doctests expect 'C' for it
+
+unicode_abc = u'abc'  # lowercase fixture: the doctest expects 'X' for it
+unicode_ABC = u'ABC'  # uppercase fixture: the doctest expects 'C' for it
+
+
+def for_in_bytes(bytes s):  # loop variable c is left undeclared here (no cdef)
+    """
+    >>> for_in_bytes(bytes_abc)
+    'X'
+    >>> for_in_bytes(bytes_ABC)
+    'C'
+    """
+    for c in s:
+        if c == 'C':  # NOTE(review): Py3 bytes iteration yields ints - confirm this comparison can match
+            return 'C'
+    else:  # for-else: reached only when the loop finishes without returning
+        return 'X'
+
+def for_char_in_bytes(bytes s):  # same check as for_in_bytes, with a C-typed loop variable
+    """
+    >>> for_char_in_bytes(bytes_abc)
+    'X'
+    >>> for_char_in_bytes(bytes_ABC)
+    'C'
+    """
+    cdef char c  # integer-typed target, as the string iteration transform requires
+    for c in s:
+        if c == 'C':  # compares the C char loop variable against the literal 'C'
+            return 'C'
+    else:  # for-else: reached only when the loop finishes without returning
+        return 'X'
+
+def for_int_in_unicode(unicode s):  # unicode input iterated with a C int loop variable
+    """
+    >>> for_int_in_unicode(unicode_abc)
+    'X'
+    >>> for_int_in_unicode(unicode_ABC)
+    'C'
+    """
+    cdef int c  # integer-typed target, as the string iteration transform requires
+    for c in s:
+        if c == 'C':  # compares the C int loop variable against the literal 'C'
+            return 'C'
+    else:  # for-else: reached only when the loop finishes without returning
+        return 'X'