# Handle the case where base is a literal char* (and we expect a string, not an int)
if isinstance(self.base, BytesNode) or is_slice:
- self.base = self.base.coerce_to_pyobject(env)
+ if not (self.base.type.is_ptr or self.base.type.is_array):
+ self.base = self.base.coerce_to_pyobject(env)
skip_child_analysis = False
buffer_access = False
if self.index.type.is_pyobject:
self.index = self.index.coerce_to(
PyrexTypes.c_py_ssize_t_type, env)
- if not self.index.type.is_int:
+ elif not self.index.type.is_int:
error(self.pos,
"Invalid index type '%s'" %
self.index.type)
(op, operand1.type, operand2.type))
def is_python_comparison(self):
- return not self.is_c_string_contains() and (
- self.has_python_operands()
- or (self.cascade and self.cascade.is_python_comparison())
- or self.operator in ('in', 'not_in'))
+ return (not self.is_ptr_contains()
+ and not self.is_c_string_contains()
+ and (self.has_python_operands()
+ or (self.cascade and self.cascade.is_python_comparison())
+ or self.operator in ('in', 'not_in')))
def coerce_operands_to(self, dst_type, env):
operand2 = self.operand2
def is_python_result(self):
return ((self.has_python_operands() and
self.operator not in ('is', 'is_not', 'in', 'not_in') and
- not self.is_c_string_contains())
+ not self.is_c_string_contains() and
+ not self.is_ptr_contains())
or (self.cascade and self.cascade.is_python_result()))
def is_c_string_contains(self):
and (self.operand2.type.is_string or self.operand2.type is bytes_type)) or
(self.operand1.type is PyrexTypes.c_py_unicode_type
and self.operand2.type is unicode_type))
+
+ def is_ptr_contains(self):
+ if self.operator in ('in', 'not_in'):
+ iterator = self.operand2
+ if iterator.type.is_ptr or iterator.type.is_array:
+ return iterator.type.base_type is not PyrexTypes.c_char_type
+ if (isinstance(iterator, IndexNode) and
+ isinstance(iterator.index, (SliceNode, CoerceFromPyTypeNode)) and
+ (iterator.base.type.is_array or iterator.base.type.is_ptr)):
+ return iterator.base.type.base_type is not PyrexTypes.c_char_type
def generate_operation_code(self, code, result_code,
operand1, op , operand2):
env.use_utility_code(char_in_bytes_utility_code)
self.operand2 = self.operand2.as_none_safe_node(
"argument of type 'NoneType' is not iterable")
+ elif self.is_ptr_contains():
+ if self.cascade:
+ error(self.pos, "Cascading comparison not yet supported for 'val in sliced pointer'.")
+ self.type = PyrexTypes.c_bint_type
+ # Will be transformed by IterationTransform
+ return
else:
common_type = py_object_type
self.is_pycmp = True
self.visitchildren(node)
self.current_scope = oldscope
return node
+
+ def visit_PrimaryCmpNode(self, node):
+ if node.is_ptr_contains():
+
+ # for t in operand2:
+ # if operand1 == t:
+ # res = True
+ # break
+ # else:
+ # res = False
+
+ pos = node.pos
+ res_handle = UtilNodes.TempHandle(PyrexTypes.c_bint_type)
+ res = res_handle.ref(pos)
+ result_ref = UtilNodes.ResultRefNode(node)
+ if isinstance(node.operand2, ExprNodes.IndexNode):
+ base_type = node.operand2.base.type.base_type
+ else:
+ base_type = node.operand2.type.base_type
+ target_handle = UtilNodes.TempHandle(base_type)
+ target = target_handle.ref(pos)
+ cmp_node = ExprNodes.PrimaryCmpNode(
+ pos, operator=u'==', operand1=node.operand1, operand2=target)
+ if_body = Nodes.StatListNode(
+ pos,
+ stats = [Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=1)),
+ Nodes.BreakStatNode(pos)])
+ if_node = Nodes.IfStatNode(
+ pos,
+ if_clauses=[Nodes.IfClauseNode(pos, condition=cmp_node, body=if_body)],
+ else_clause=None)
+ for_loop = UtilNodes.TempsBlockNode(
+ pos,
+ temps = [target_handle],
+ body = Nodes.ForInStatNode(
+ pos,
+ target=target,
+ iterator=ExprNodes.IteratorNode(node.operand2.pos, sequence=node.operand2),
+ body=if_node,
+ else_clause=Nodes.SingleAssignmentNode(pos, lhs=result_ref, rhs=ExprNodes.BoolNode(pos, value=0))))
+ for_loop.analyse_expressions(self.current_scope)
+ for_loop = self(for_loop)
+ new_node = UtilNodes.TempResultFromStatNode(result_ref, for_loop)
+
+ if node.operator == 'not_in':
+ new_node = ExprNodes.NotNode(pos, operand=new_node)
+ return new_node
+
+ else:
+ self.visitchildren(node)
+ return node
def visit_ForInStatNode(self, node):
self.visitchildren(node)
return self._optimise_for_loop(node)
-
+
def _optimise_for_loop(self, node):
iterator = node.iterator.sequence
if iterator.type is Builtin.dict_type: