Cleanup slice iteration code.
authorRobert Bradshaw <robertwb@math.washington.edu>
Sat, 11 Sep 2010 21:46:28 +0000 (14:46 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sat, 11 Sep 2010 21:46:28 +0000 (14:46 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/Optimize.py
tests/run/slice_ptr.pyx

index b16f86bb9b16904f71cf5d128357018d079612a9..fd30a8ddaeb2cd1c9ed9fc3104e43967d8cef92f 100755 (executable)
@@ -1667,11 +1667,10 @@ class IteratorNode(ExprNode):
     
     def analyse_types(self, env):
         self.sequence.analyse_types(env)
-        if isinstance(self.sequence, SliceIndexNode) and \
-               (self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \
-               or self.sequence.type.is_array and self.sequence.type.size is not None:
+        if (self.sequence.type.is_array or self.sequence.type.is_ptr) and \
+                not self.sequence.type.is_string:
             # C array iteration will be transformed later on
-            pass
+            self.type = self.sequence.type
         else:
             self.sequence = self.sequence.coerce_to_pyobject(env)
         self.is_temp = 1
@@ -1686,6 +1685,8 @@ class IteratorNode(ExprNode):
         code.funcstate.release_temp(self.counter_cname)
 
     def generate_result_code(self, code):
+        if self.sequence.type.is_array or self.sequence.type.is_ptr:
+            raise InternalError("for in carray slice not transformed")
         is_builtin_sequence = self.sequence.type is list_type or \
                               self.sequence.type is tuple_type
         may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type
@@ -1733,6 +1734,8 @@ class NextNode(AtomicExprNode):
     def __init__(self, iterator, env):
         self.pos = iterator.pos
         self.iterator = iterator
+        if iterator.type.is_ptr or iterator.type.is_array:
+            self.type = iterator.type.base_type
         self.is_temp = 1
     
     def generate_result_code(self, code):
@@ -2008,6 +2011,7 @@ class IndexNode(ExprNode):
             return
         
         is_slice = isinstance(self.index, SliceNode)
+        # Potentially overflowing index value.
         if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value):
             self.index = self.index.coerce_to_pyobject(env)
 
@@ -2092,7 +2096,9 @@ class IndexNode(ExprNode):
             else:
                 if base_type.is_ptr or base_type.is_array:
                     self.type = base_type.base_type
-                    if self.index.type.is_pyobject:
+                    if is_slice:
+                        self.type = base_type
+                    elif self.index.type.is_pyobject:
                         self.index = self.index.coerce_to(
                             PyrexTypes.c_py_ssize_t_type, env)
                     elif not self.index.type.is_int:
@@ -2147,6 +2153,8 @@ class IndexNode(ExprNode):
             return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result())
         elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type:
             return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result())
+        elif (self.type.is_ptr or self.type.is_array) and self.type == self.base.type:
+            error(self.pos, "Invalid use of pointer slice")
         else:
             return "(%s[%s])" % (
                 self.base.result(), self.index.result())
@@ -2401,7 +2409,9 @@ class SliceIndexNode(ExprNode):
         base_type = self.base.type
         if base_type.is_string:
             self.type = bytes_type
-        elif base_type.is_array or base_type.is_ptr:
+        elif base_type.is_ptr:
+            self.type = base_type
+        elif base_type.is_array:
             # we need a ptr type here instead of an array type, as
             # array types can result in invalid type casts in the C
             # code
@@ -6027,13 +6037,9 @@ class CmpNode(object):
     
     def is_ptr_contains(self):
         if self.operator in ('in', 'not_in'):
-            iterator = self.operand2
-            if iterator.type.is_ptr or iterator.type.is_array:
-                return iterator.type.base_type is not PyrexTypes.c_char_type
-            if (isinstance(iterator, IndexNode) and
-                isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
-                (iterator.base.type.is_array or iterator.base.type.is_ptr)):
-                return iterator.base.type.base_type is not PyrexTypes.c_char_type
+            container_type = self.operand2.type
+            return (container_type.is_ptr or container_type.is_array) \
+                and not container_type.is_string
 
     def generate_operation_code(self, code, result_code, 
             operand1, op , operand2):
index e04c9f9349eaa5a84c465df2485abafea557bd55..ecd4b8f7b9b83ba720d0b4509e25a5a1c8e4cb7e 100644 (file)
@@ -4295,9 +4295,6 @@ class ForInStatNode(LoopNode, StatNode):
         self.target.analyse_target_types(env)
         self.iterator.analyse_expressions(env)
         self.item = ExprNodes.NextNode(self.iterator, env)
-        if not self.target.type.assignable_from(self.item.type) and \
-            (self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array):
-            self.item.type = self.iterator.sequence.type.base_type
         self.item = self.item.coerce_to(self.target.type, env)
         self.body.analyse_expressions(env)
         if self.else_clause:
index 4fab617825ad4bf947d0f60514f1efb9033b1d4b..87cf35a94acf51924bcc6cbdd48dc34ac1796eb8 100644 (file)
@@ -146,16 +146,13 @@ class IterationTransform(Visitor.VisitorTransform):
                 node, dict_obj=iterator, keys=True, values=False)
 
         # C array (slice) iteration?
-        plain_iterator = unwrap_coerced_node(iterator)
-        if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
-               (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
-            return self._transform_carray_iteration(node, plain_iterator)
-        if isinstance(plain_iterator, ExprNodes.IndexNode) and \
-               isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)):
-            iterator_base = unwrap_coerced_node(plain_iterator.base)
-            if iterator_base.type.is_array or iterator_base.type.is_ptr:
+        if False:
+            plain_iterator = unwrap_coerced_node(iterator)
+            if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \
+                   (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr):
                 return self._transform_carray_iteration(node, plain_iterator)
-        if iterator.type.is_array:
+
+        if iterator.type.is_ptr or iterator.type.is_array:
             return self._transform_carray_iteration(node, iterator)
         if iterator.type in (Builtin.bytes_type, Builtin.unicode_type):
             return self._transform_string_iteration(node, iterator)
@@ -220,7 +217,7 @@ class IterationTransform(Visitor.VisitorTransform):
 
     def _transform_string_iteration(self, node, slice_node):
         if not node.target.type.is_int:
-            return node
+            return self._transform_carray_iteration(node, slice_node)
         if slice_node.type is Builtin.unicode_type:
             unpack_func = "PyUnicode_AS_UNICODE"
             len_func = "PyUnicode_GET_SIZE"
@@ -270,11 +267,13 @@ class IterationTransform(Visitor.VisitorTransform):
             stop = slice_node.stop
             step = None
             if not stop:
+                if not slice_base.type.is_pyobject:
+                    error(slice_node.pos, "C array iteration requires known end index")
                 return node
         elif isinstance(slice_node, ExprNodes.IndexNode):
             # slice_node.index must be a SliceNode
-            slice_base = unwrap_coerced_node(slice_node.base)
-            index = unwrap_coerced_node(slice_node.index)
+            slice_base = slice_node.base
+            index = slice_node.index
             start = index.start
             stop = index.stop
             step = index.step
@@ -285,7 +284,8 @@ class IterationTransform(Visitor.VisitorTransform):
                        or step.constant_result == 0 \
                        or step.constant_result > 0 and not stop \
                        or step.constant_result < 0 and not start:
-                    error(step.pos, "C array iteration requires known step size and end index")
+                    if not slice_base.type.is_pyobject:
+                        error(step.pos, "C array iteration requires known step size and end index")
                     return node
                 else:
                     # step sign is handled internally by ForFromStatNode
@@ -293,14 +293,20 @@ class IterationTransform(Visitor.VisitorTransform):
                     step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type,
                                              value=abs(step.constant_result),
                                              constant_result=abs(step.constant_result))
-        elif slice_node.type.is_array and slice_node.type.size is not None:
+        elif slice_node.type.is_array:
+            if slice_node.type.size is None:
+                error(step.pos, "C array iteration requires known end index")
+                return node
             slice_base = slice_node
             start = None
             stop = ExprNodes.IntNode(
                 slice_node.pos, value=str(slice_node.type.size),
                 type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size)
             step = None
+
         else:
+            if not slice_node.type.is_pyobject:
+                error(slice_node.pos, "Invalid C array iteration")
             return node
 
         if start:
index e46bd73dc65bc7017f47d518d4ef9bcd51bae587..d682f8b93dd0204f5b160dc61ee02eadc0348415 100644 (file)
@@ -44,7 +44,7 @@ def void_ptr_slice(py_x, L, int a, int b):
             L_c[i] = <void*>L[i]
         assert (x in L_c[:b]) == (py_x in L[:b])
         assert (x in L_c[a:b]) == (py_x in L[a:b])
-#        assert (x in L_c[a:b:2]) == (py_x in L[a:b:2])
+        assert (x in L_c[a:b:2]) == (py_x in L[a:b:2])
     finally:
         free(L_c)