From: Robert Bradshaw Date: Wed, 11 Nov 2009 20:45:58 +0000 (-0800) Subject: Fix for in/not in cascading. X-Git-Tag: 0.12.rc0~2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=350f769244d64627f9202e254c75aef9b35e8c35;p=cython.git Fix for in/not in cascading. --- diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 353a78e0..1b6c2e82 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -5231,9 +5231,11 @@ class CmpNode(object): else: negation = "" if op == 'in' or op == 'not_in': - assert not coerce_result + code.globalstate.use_utility_code(contians_utility_code) + if self.type is PyrexTypes.py_object_type: + coerce_result = "__Pyx_PyBoolOrNull_FromLong" if op == 'not_in': - negation = "if (likely(%s != -1)) %s = !%s; " % ((result_code,)*3) + negation = "__Pyx_NegateNonNeg" if operand2.type is dict_type: code.globalstate.use_utility_code( raise_none_iter_error_utility_code) @@ -5241,22 +5243,27 @@ class CmpNode(object): code.putln("__Pyx_RaiseNoneNotIterableError(); %s" % code.error_goto(self.pos)) code.putln("} else {") - code.putln( - "%s = PyDict_Contains(%s, %s); %s%s" % ( - result_code, - operand2.py_result(), - operand1.py_result(), - negation, - code.error_goto_if_neg(result_code, self.pos))) - code.putln("}") + method = "PyDict_Contains" else: - code.putln( - "%s = PySequence_Contains(%s, %s); %s%s" % ( - result_code, - operand2.py_result(), - operand1.py_result(), - negation, - code.error_goto_if_neg(result_code, self.pos))) + method = "PySequence_Contains" + if self.type is PyrexTypes.py_object_type: + error_clause = code.error_goto_if_null + got_ref = "__Pyx_XGOTREF(%s); " % result_code + else: + error_clause = code.error_goto_if_neg + got_ref = "" + code.putln( + "%s = %s(%s(%s(%s, %s))); %s%s" % ( + result_code, + coerce_result, + negation, + method, + operand2.py_result(), + operand1.py_result(), + got_ref, + error_clause(result_code, self.pos))) + if operand2.type is dict_type: + code.putln("}") elif (operand1.type.is_pyobject and op not in ('is', 'is_not')): @@ -5306,6 +5313,14 @@ class CmpNode(object): else: return op +contians_utility_code = UtilityCode( +proto=""" +static INLINE long __Pyx_NegateNonNeg(long b) { return unlikely(b < 0) ? b : !b; } +static INLINE PyObject* __Pyx_PyBoolOrNull_FromLong(long b) { + return unlikely(b < 0) ? NULL : __Pyx_PyBool_FromLong(b); +} +""") + class PrimaryCmpNode(ExprNode, CmpNode): # Non-cascaded comparison or first comparison of diff --git a/tests/run/contains_T455.pyx b/tests/run/contains_T455.pyx index 669f89c1..1fa01821 100644 --- a/tests/run/contains_T455.pyx +++ b/tests/run/contains_T455.pyx @@ -80,3 +80,18 @@ def not_in_dict(k, dict dct): TypeError: 'NoneType' object is not iterable """ return k not in dct + +def cascaded(a, b, c): + """ + >>> cascaded(1, 2, 3) + Traceback (most recent call last): + ... + TypeError: argument of type 'int' is not iterable + >>> cascaded(-1, (1,2), (1,3)) + True + >>> cascaded(1, (1,2), (1,3)) + False + >>> cascaded(-1, (1,2), (1,0)) + False + """ + return a not in b < c