in and not in operators for C arrays and sliced pointers
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 9 Sep 2010 07:41:41 +0000 (00:41 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 9 Sep 2010 07:41:41 +0000 (00:41 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/Optimize.py

index 94dd8cf72ffbd7fb265d6fea65ee94f68adc1ee6..1ab47bf9af7a62542599d3002e72eeb978272eba 100755 (executable)
@@ -2011,7 +2011,8 @@ class IndexNode(ExprNode):
 
         # Handle the case where base is a literal char* (and we expect a string, not an int)
         if isinstance(self.base, BytesNode) or is_slice:
-            self.base = self.base.coerce_to_pyobject(env)
+            if not (self.base.type.is_ptr or self.base.type.is_array):
+                self.base = self.base.coerce_to_pyobject(env)
 
         skip_child_analysis = False
         buffer_access = False
@@ -2092,7 +2093,7 @@ class IndexNode(ExprNode):
                     if self.index.type.is_pyobject:
                         self.index = self.index.coerce_to(
                             PyrexTypes.c_py_ssize_t_type, env)
-                    if not self.index.type.is_int:
+                    elif not self.index.type.is_int:
                         error(self.pos,
                             "Invalid index type '%s'" %
                                 self.index.type)
@@ -5995,10 +5996,11 @@ class CmpNode(object):
               (op, operand1.type, operand2.type))
 
     def is_python_comparison(self):
-        return not self.is_c_string_contains() and (
-            self.has_python_operands()
-            or (self.cascade and self.cascade.is_python_comparison())
-            or self.operator in ('in', 'not_in'))
+        return (not self.is_ptr_contains()
+            and not self.is_c_string_contains()
+            and (self.has_python_operands()
+                 or (self.cascade and self.cascade.is_python_comparison())
+                 or self.operator in ('in', 'not_in')))
 
     def coerce_operands_to(self, dst_type, env):
         operand2 = self.operand2
@@ -6010,7 +6012,8 @@ class CmpNode(object):
     def is_python_result(self):
         return ((self.has_python_operands() and
                  self.operator not in ('is', 'is_not', 'in', 'not_in') and
-                 not self.is_c_string_contains())
+                 not self.is_c_string_contains() and
+                 not self.is_ptr_contains())
             or (self.cascade and self.cascade.is_python_result()))
 
     def is_c_string_contains(self):
@@ -6019,6 +6022,16 @@ class CmpNode(object):
                  and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or
                 (self.operand1.type is PyrexTypes.c_py_unicode_type
                  and self.operand2.type is unicode_type))
+    
+    def is_ptr_contains(self):
+        if self.operator in ('in', 'not_in'):
+            iterator = self.operand2
+            if iterator.type.is_ptr or iterator.type.is_array:
+                return iterator.type.base_type is not PyrexTypes.c_char_type
+            if (isinstance(iterator, IndexNode) and
+                isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
+                (iterator.base.type.is_array or iterator.base.type.is_ptr)):
+                return iterator.base.type.base_type is not PyrexTypes.c_char_type
 
     def generate_operation_code(self, code, result_code, 
             operand1, op , operand2):
@@ -6214,6 +6227,12 @@ class PrimaryCmpNode(ExprNode, CmpNode):
                     env.use_utility_code(char_in_bytes_utility_code)
                 self.operand2 = self.operand2.as_none_safe_node(
                     "argument of type 'NoneType' is not iterable")
+            elif self.is_ptr_contains():
+                if self.cascade:
+                    error(self.pos, "Cascading comparison not yet supported for 'val in sliced pointer'.")
+                self.type = PyrexTypes.c_bint_type
+                # Will be transformed by IterationTransform
+                return
             else:
                 common_type = py_object_type
                 self.is_pycmp = True
index ecd4b8f7b9b83ba720d0b4509e25a5a1c8e4cb7e..e04c9f9349eaa5a84c465df2485abafea557bd55 100644 (file)
@@ -4295,6 +4295,9 @@ class ForInStatNode(LoopNode, StatNode):
         self.target.analyse_target_types(env)
         self.iterator.analyse_expressions(env)
         self.item = ExprNodes.NextNode(self.iterator, env)
+        if not self.target.type.assignable_from(self.item.type) and \
+            (self.iterator.sequence.type.is_ptr or self.iterator.sequence.type.is_array):
+            self.item.type = self.iterator.sequence.type.base_type
         self.item = self.item.coerce_to(self.target.type, env)
         self.body.analyse_expressions(env)
         if self.else_clause:
index 79370bda06ed11d760d2291fa3e954c9c8ace0e7..a329b96e5db1766866f05cfc991d042383c5b836 100644 (file)
@@ -77,11 +77,62 @@ class IterationTransform(Visitor.VisitorTransform):
         self.visitchildren(node)
         self.current_scope = oldscope
         return node
+    
+    def visit_PrimaryCmpNode(self, node):
+        if node.is_ptr_contains():
+        
+            # for t in operand2:
+            #     if operand1 == t:
+            #         res = True
+            #         break
+            # else:
+            #     res = False
+            
+            pos = node.pos
+            res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
+            res = res_handle.ref(pos)
+            result_ref = UtilNodes.ResultRefNode(node)
+            if isinstance(node.operand2, ExprNodes.IndexNode):
+                base_type = node.operand2.base.type.base_type
+            else:
+                base_type = node.operand2.type.base_type
+            target_handle = UtilNodes.TempHandle(base_type)
+            target = target_handle.ref(pos)
+            cmp_node = ExprNodes.PrimaryCmpNode(
+                pos, operator=u'==', operand1=node.operand1, operand2=target)
+            if_body = Nodes.StatListNode(
+                pos, 
+                stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
+                         Nodes.BreakStatNode(pos)])
+            if_node = Nodes.IfStatNode(
+                pos,
+                if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
+                else_clause=None)
+            for_loop = UtilNodes.TempsBlockNode(
+                pos,
+                temps = [target_handle],
+                body = Nodes.ForInStatNode(
+                    pos,
+                    target=target,
+                    iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
+                    body=if_node,
+                    else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
+            for_loop.analyse_expressions(self.current_scope)
+            for_loop = self(for_loop)
+            new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
+            
+            if node.operator == 'not_in':
+                new_node = ExprNodes.NotNode(pos, operand=new_node)
+            return new_node
+
+        else:
+            self.visitchildren(node)
+            return node
 
     def visit_ForInStatNode(self, node):
         self.visitchildren(node)
         return self._optimise_for_loop(node)
-
+    
     def _optimise_for_loop(self, node):
         iterator = node.iterator.sequence
         if iterator.type is Builtin.dict_type: