Fix for in/not in cascading.
authorRobert Bradshaw <robertwb@math.washington.edu>
Wed, 11 Nov 2009 20:45:58 +0000 (12:45 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Wed, 11 Nov 2009 20:45:58 +0000 (12:45 -0800)
Cython/Compiler/ExprNodes.py
tests/run/contains_T455.pyx

index 353a78e09f05b9c84129d4bc3dc7560bc9ca5ac0..1b6c2e82849b62b1b0eea6e8b7d8df70678f3f94 100644 (file)
@@ -5231,9 +5231,11 @@ class CmpNode(object):
         else: 
             negation = ""
         if op == 'in' or op == 'not_in':
-            assert not coerce_result
+            code.globalstate.use_utility_code(contians_utility_code)
+            if self.type is PyrexTypes.py_object_type:
+                coerce_result = "__Pyx_PyBoolOrNull_FromLong"
             if op == 'not_in':
-                negation = "if (likely(%s != -1)) %s = !%s; " % ((result_code,)*3)
+                negation = "__Pyx_NegateNonNeg"
             if operand2.type is dict_type:
                 code.globalstate.use_utility_code(
                     raise_none_iter_error_utility_code)
@@ -5241,22 +5243,27 @@ class CmpNode(object):
                 code.putln("__Pyx_RaiseNoneNotIterableError(); %s" %
                            code.error_goto(self.pos))
                 code.putln("} else {")
-                code.putln(
-                    "%s = PyDict_Contains(%s, %s); %s%s" % (
-                        result_code, 
-                        operand2.py_result(), 
-                        operand1.py_result(), 
-                        negation,
-                        code.error_goto_if_neg(result_code, self.pos)))
-                code.putln("}")
+                method = "PyDict_Contains"
             else:
-                code.putln(
-                    "%s = PySequence_Contains(%s, %s); %s%s" % (
-                        result_code, 
-                        operand2.py_result(), 
-                        operand1.py_result(), 
-                        negation,
-                        code.error_goto_if_neg(result_code, self.pos)))
+                method = "PySequence_Contains"
+            if self.type is PyrexTypes.py_object_type:
+                error_clause = code.error_goto_if_null
+                got_ref = "__Pyx_XGOTREF(%s); " % result_code
+            else:
+                error_clause = code.error_goto_if_neg
+                got_ref = ""
+            code.putln(
+                "%s = %s(%s(%s(%s, %s))); %s%s" % (
+                    result_code,
+                    coerce_result,
+                    negation,
+                    method,
+                    operand2.py_result(), 
+                    operand1.py_result(), 
+                    got_ref,
+                    error_clause(result_code, self.pos)))
+            if operand2.type is dict_type:
+                code.putln("}")
                     
         elif (operand1.type.is_pyobject
             and op not in ('is', 'is_not')):
@@ -5306,6 +5313,14 @@ class CmpNode(object):
         else:
             return op
     
+contians_utility_code = UtilityCode(
+proto="""
+static INLINE long __Pyx_NegateNonNeg(long b) { return unlikely(b < 0) ? b : !b; }
+static INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) {
+    return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b);
+}
+""")
+
 
 class PrimaryCmpNode(ExprNode, CmpNode):
     #  Non-cascaded comparison or first comparison of
index 669f89c1612dd486c1791a05898b39c26ba45573..1fa018212f003eca3dd0383eb9cad6c209de633d 100644 (file)
@@ -80,3 +80,18 @@ def not_in_dict(k, dict dct):
     TypeError: 'NoneType' object is not iterable
     """
     return k not in dct
+
+def cascaded(a, b, c):
+    """
+    >>> cascaded(1, 2, 3)
+    Traceback (most recent call last):
+    ...
+    TypeError: argument of type 'int' is not iterable
+    >>> cascaded(-1, (1,2), (1,3))
+    True
+    >>> cascaded(1, (1,2), (1,3))
+    False
+    >>> cascaded(-1, (1,2), (1,0))
+    False
+    """
+    return a not in b < c