From: Robert Bradshaw Date: Sat, 11 Sep 2010 21:46:28 +0000 (-0700) Subject: Cleanup slice iteration code. X-Git-Tag: 0.14.alpha0~327 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=082aee713f52ef4b49f028f385a4d56905a62821;p=cython.git Cleanup slice iteration code. --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index b16f86bb..fd30a8dd 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1667,11 +1667,10 @@ class IteratorNode(ExprNode): def analyse_types(self, env): self.sequence.analyse_types(env) - if isinstance(self.sequence, SliceIndexNode) and \ - (self.sequence.base.type.is_array or self.sequence.base.type.is_ptr) \ - or self.sequence.type.is_array and self.sequence.type.size is not None: + if (self.sequence.type.is_array or self.sequence.type.is_ptr) and \ + not self.sequence.type.is_string: # C array iteration will be transformed later on - pass + self.type = self.sequence.type else: self.sequence = self.sequence.coerce_to_pyobject(env) self.is_temp = 1 @@ -1686,6 +1685,8 @@ class IteratorNode(ExprNode): code.funcstate.release_temp(self.counter_cname) def generate_result_code(self, code): + if self.sequence.type.is_array or self.sequence.type.is_ptr: + raise InternalError("for in carray slice not transformed") is_builtin_sequence = self.sequence.type is list_type or \ self.sequence.type is tuple_type may_be_a_sequence = is_builtin_sequence or not self.sequence.type.is_builtin_type @@ -1733,6 +1734,8 @@ class NextNode(AtomicExprNode): def __init__(self, iterator, env): self.pos = iterator.pos self.iterator = iterator + if iterator.type.is_ptr or iterator.type.is_array: + self.type = iterator.type.base_type self.is_temp = 1 def generate_result_code(self, code): @@ -2008,6 +2011,7 @@ class IndexNode(ExprNode): return is_slice = isinstance(self.index, SliceNode) + # Potentially overflowing index value. if not is_slice and isinstance(self.index, IntNode) and Utils.long_literal(self.index.value): self.index = self.index.coerce_to_pyobject(env) @@ -2092,7 +2096,9 @@ class IndexNode(ExprNode): else: if base_type.is_ptr or base_type.is_array: self.type = base_type.base_type - if self.index.type.is_pyobject: + if is_slice: + self.type = base_type + elif self.index.type.is_pyobject: self.index = self.index.coerce_to( PyrexTypes.c_py_ssize_t_type, env) elif not self.index.type.is_int: @@ -2147,6 +2153,8 @@ class IndexNode(ExprNode): return "PyTuple_GET_ITEM(%s, %s)" % (self.base.result(), self.index.result()) elif self.base.type is unicode_type and self.type is PyrexTypes.c_py_unicode_type: return "PyUnicode_AS_UNICODE(%s)[%s]" % (self.base.result(), self.index.result()) + elif (self.type.is_ptr or self.type.is_array) and self.type == self.base.type: + error(self.pos, "Invalid use of pointer slice") else: return "(%s[%s])" % ( self.base.result(), self.index.result()) @@ -2401,7 +2409,9 @@ class SliceIndexNode(ExprNode): base_type = self.base.type if base_type.is_string: self.type = bytes_type - elif base_type.is_array or base_type.is_ptr: + elif base_type.is_ptr: + self.type = base_type + elif base_type.is_array: # we need a ptr type here instead of an array type, as # array types can result in invalid type casts in the C # code @@ -6027,13 +6037,9 @@ class CmpNode(object): def is_ptr_contains(self): if self.operator in ('in', 'not_in'): - iterator = self.operand2 - if iterator.type.is_ptr or iterator.type.is_array: - return iterator.type.base_type is not PyrexTypes.c_char_type - if (isinstance(iterator, IndexNode) and - isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and - (iterator.base.type.is_array or iterator.base.type.is_ptr)): - return iterator.base.type.base_type is not PyrexTypes.c_char_type + container_type = self.operand2.type + return (container_type.is_ptr or container_type.is_array) \ + and not container_type.is_string def generate_operation_code(self, code, result_code, operand1, op , operand2): diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index e04c9f93..ecd4b8f7 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -4295,9 +4295,6 @@ class ForInStatNode(LoopNode, StatNode): self.target.analyse_target_types(env) self.iterator.analyse_expressions(env) self.item = ExprNodes.NextNode(self.iterator, env) - if not self.target.type.assignable_from(self.item.type) and \ - (self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array): - self.item.type = self.iterator.sequence.type.base_type self.item = self.item.coerce_to(self.target.type, env) self.body.analyse_expressions(env) if self.else_clause: diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 4fab6178..87cf35a9 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -146,16 +146,13 @@ class IterationTransform(Visitor.VisitorTransform): node, dict_obj=iterator, keys=True, values=False) # C array (slice) iteration? - plain_iterator = unwrap_coerced_node(iterator) - if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \ - (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): - return self._transform_carray_iteration(node, plain_iterator) - if isinstance(plain_iterator, ExprNodes.IndexNode) and \ - isinstance(plain_iterator.index, (ExprNodes.SliceNode, ExprNodes.CoerceFromPyTypeNode)): - iterator_base = unwrap_coerced_node(plain_iterator.base) - if iterator_base.type.is_array or iterator_base.type.is_ptr: + if False: + plain_iterator = unwrap_coerced_node(iterator) + if isinstance(plain_iterator, ExprNodes.SliceIndexNode) and \ + (plain_iterator.base.type.is_array or plain_iterator.base.type.is_ptr): return self._transform_carray_iteration(node, plain_iterator) - if iterator.type.is_array: + + if iterator.type.is_ptr or iterator.type.is_array: return self._transform_carray_iteration(node, iterator) if iterator.type in (Builtin.bytes_type, Builtin.unicode_type): return self._transform_string_iteration(node, iterator) @@ -220,7 +217,7 @@ class IterationTransform(Visitor.VisitorTransform): def _transform_string_iteration(self, node, slice_node): if not node.target.type.is_int: - return node + return self._transform_carray_iteration(node, slice_node) if slice_node.type is Builtin.unicode_type: unpack_func = "PyUnicode_AS_UNICODE" len_func = "PyUnicode_GET_SIZE" @@ -270,11 +267,13 @@ class IterationTransform(Visitor.VisitorTransform): stop = slice_node.stop step = None if not stop: + if not slice_base.type.is_pyobject: + error(slice_node.pos, "C array iteration requires known end index") return node elif isinstance(slice_node, ExprNodes.IndexNode): # slice_node.index must be a SliceNode - slice_base = unwrap_coerced_node(slice_node.base) - index = unwrap_coerced_node(slice_node.index) + slice_base = slice_node.base + index = slice_node.index start = index.start stop = index.stop step = index.step @@ -285,7 +284,8 @@ class IterationTransform(Visitor.VisitorTransform): or step.constant_result == 0 \ or step.constant_result > 0 and not stop \ or step.constant_result < 0 and not start: - error(step.pos, "C array iteration requires known step size and end index") + if not slice_base.type.is_pyobject: + error(step.pos, "C array iteration requires known step size and end index") return node else: # step sign is handled internally by ForFromStatNode @@ -293,14 +293,20 @@ class IterationTransform(Visitor.VisitorTransform): step = ExprNodes.IntNode(step.pos, type=PyrexTypes.c_py_ssize_t_type, value=abs(step.constant_result), constant_result=abs(step.constant_result)) - elif slice_node.type.is_array and slice_node.type.size is not None: + elif slice_node.type.is_array: + if slice_node.type.size is None: + error(step.pos, "C array iteration requires known end index") + return node slice_base = slice_node start = None stop = ExprNodes.IntNode( slice_node.pos, value=str(slice_node.type.size), type=PyrexTypes.c_py_ssize_t_type, constant_result=slice_node.type.size) step = None + else: + if not slice_node.type.is_pyobject: + error(slice_node.pos, "Invalid C array iteration") return node if start: diff --git a/tests/run/slice_ptr.pyx b/tests/run/slice_ptr.pyx index e46bd73d..d682f8b9 100644 --- a/tests/run/slice_ptr.pyx +++ b/tests/run/slice_ptr.pyx @@ -44,7 +44,7 @@ def void_ptr_slice(py_x, L, int a, int b): L_c[i] = L[i] assert (x in L_c[:b]) == (py_x in L[:b]) assert (x in L_c[a:b]) == (py_x in L[a:b]) -# assert (x in L_c[a:b:2]) == (py_x in L[a:b:2]) + assert (x in L_c[a:b:2]) == (py_x in L[a:b:2]) finally: free(L_c)