From 15e5567faca2749aee795beec5bc32ca576740d2 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 9 Sep 2010 00:41:41 -0700 Subject: [PATCH] in and not in operators for C arrays and sliced pointers --- Cython/Compiler/ExprNodes.py | 33 +++++++++++++++++----- Cython/Compiler/Nodes.py | 3 ++ Cython/Compiler/Optimize.py | 53 +++++++++++++++++++++++++++++++++++- 3 files changed, 81 insertions(+), 8 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 94dd8cf7..1ab47bf9 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -2011,7 +2011,8 @@ class IndexNode(ExprNode): # Handle the case where base is a literal char* (and we expect a string, not an int) if isinstance(self.base, BytesNode) or is_slice: - self.base = self.base.coerce_to_pyobject(env) + if not (self.base.type.is_ptr or self.base.type.is_array): + self.base = self.base.coerce_to_pyobject(env) skip_child_analysis = False buffer_access = False @@ -2092,7 +2093,7 @@ class IndexNode(ExprNode): if self.index.type.is_pyobject: self.index = self.index.coerce_to( PyrexTypes.c_py_ssize_t_type, env) - if not self.index.type.is_int: + elif not self.index.type.is_int: error(self.pos, "Invalid index type '%s'" % self.index.type) @@ -5995,10 +5996,11 @@ class CmpNode(object): (op, operand1.type, operand2.type)) def is_python_comparison(self): - return not self.is_c_string_contains() and ( - self.has_python_operands() - or (self.cascade and self.cascade.is_python_comparison()) - or self.operator in ('in', 'not_in')) + return (not self.is_ptr_contains() + and not self.is_c_string_contains() + and (self.has_python_operands() + or (self.cascade and self.cascade.is_python_comparison()) + or self.operator in ('in', 'not_in'))) def coerce_operands_to(self, dst_type, env): operand2 = self.operand2 @@ -6010,7 +6012,8 @@ class CmpNode(object): def is_python_result(self): return ((self.has_python_operands() and self.operator not in ('is', 'is_not', 'in', 'not_in') and - not self.is_c_string_contains()) + not self.is_c_string_contains() and + not self.is_ptr_contains()) or (self.cascade and self.cascade.is_python_result())) def is_c_string_contains(self): @@ -6019,6 +6022,16 @@ class CmpNode(object): and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or (self.operand1.type is PyrexTypes.c_py_unicode_type and self.operand2.type is unicode_type)) + + def is_ptr_contains(self): + if self.operator in ('in', 'not_in'): + iterator = self.operand2 + if iterator.type.is_ptr or iterator.type.is_array: + return iterator.type.base_type is not PyrexTypes.c_char_type + if (isinstance(iterator, IndexNode) and + isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and + (iterator.base.type.is_array or iterator.base.type.is_ptr)): + return iterator.base.type.base_type is not PyrexTypes.c_char_type def generate_operation_code(self, code, result_code, operand1, op , operand2): @@ -6214,6 +6227,12 @@ class PrimaryCmpNode(ExprNode, CmpNode): env.use_utility_code(char_in_bytes_utility_code) self.operand2 = self.operand2.as_none_safe_node( "argument of type 'NoneType' is not iterable") + elif self.is_ptr_contains(): + if self.cascade: + error(self.pos, "Cascading comparison not yet supported for 'val in sliced pointer'.") + self.type = PyrexTypes.c_bint_type + # Will be transformed by IterationTransform + return else: common_type = py_object_type self.is_pycmp = True diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index ecd4b8f7..e04c9f93 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -4295,6 +4295,9 @@ class ForInStatNode(LoopNode, StatNode): self.target.analyse_target_types(env) self.iterator.analyse_expressions(env) self.item = ExprNodes.NextNode(self.iterator, env) + if not self.target.type.assignable_from(self.item.type) and \ + (self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array): + self.item.type = self.iterator.sequence.type.base_type self.item = self.item.coerce_to(self.target.type, env) self.body.analyse_expressions(env) if self.else_clause: diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 79370bda..a329b96e 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -77,11 +77,62 @@ class IterationTransform(Visitor.VisitorTransform): self.visitchildren(node) self.current_scope = oldscope return node + + def visit_PrimaryCmpNode(self, node): + if node.is_ptr_contains(): + + # for t in operand2: + # if operand1 == t: + # res = True + # break + # else: + # res = False + + pos = node.pos + res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type) + res = res_handle.ref(pos) + result_ref = UtilNodes.ResultRefNode(node) + if isinstance(node.operand2, ExprNodes.IndexNode): + base_type = node.operand2.base.type.base_type + else: + base_type = node.operand2.type.base_type + target_handle = UtilNodes.TempHandle(base_type) + target = target_handle.ref(pos) + cmp_node = ExprNodes.PrimaryCmpNode( + pos, operator=u'==', operand1=node.operand1, operand2=target) + if_body = Nodes.StatListNode( + pos, + stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)), + Nodes.BreakStatNode(pos)]) + if_node = Nodes.IfStatNode( + pos, + if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)], + else_clause=None) + for_loop = UtilNodes.TempsBlockNode( + pos, + temps = [target_handle], + body = Nodes.ForInStatNode( + pos, + target=target, + iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2), + body=if_node, + else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0)))) + for_loop.analyse_expressions(self.current_scope) + for_loop = self(for_loop) + new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop) + + if node.operator == 'not_in': + new_node = ExprNodes.NotNode(pos, operand=new_node) + return new_node + + else: + self.visitchildren(node) + return node def visit_ForInStatNode(self, node): self.visitchildren(node) return self._optimise_for_loop(node) - + def _optimise_for_loop(self, node): iterator = node.iterator.sequence if iterator.type is Builtin.dict_type: -- 2.26.2