From cab342a247c6bcf1effddf703ca8447386c6f1be Mon Sep 17 00:00:00 2001 From: Lisandro Dalcin Date: Thu, 22 Oct 2009 18:42:30 -0200 Subject: [PATCH] Extension type cast should reject None (ticket #417) --HG-- extra : rebase_source : 37bb9de5574e1f7b4f288192eaa3c70a2ae350ca --- Cython/Compiler/ExprNodes.py | 19 ++++++------ Cython/Compiler/PyrexTypes.py | 32 +++++++++++++------- tests/run/typetest_T417.pyx | 55 +++++++++++++++++++++++++++++++++++ 3 files changed, 87 insertions(+), 19 deletions(-) create mode 100644 tests/run/typetest_T417.pyx diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 4c11d666..8130a56e 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -4245,7 +4245,7 @@ class TypecastNode(ExprNode): warning(self.pos, "No conversion from %s to %s, python object pointer used." % (self.type, self.operand.type)) elif from_py and to_py: if self.typecheck and self.type.is_extension_type: - self.operand = PyTypeTestNode(self.operand, self.type, env) + self.operand = PyTypeTestNode(self.operand, self.type, env, notnone=True) def nogil_check(self, env): if self.type and self.type.is_pyobject and self.is_temp: @@ -5563,13 +5563,14 @@ class PyTypeTestNode(CoercionNode): # object is an instance of a particular extension type. # This node borrows the result of its argument node. - def __init__(self, arg, dst_type, env): + def __init__(self, arg, dst_type, env, notnone=False): # The arg is know to be a Python object, and # the dst_type is known to be an extension type. assert dst_type.is_extension_type or dst_type.is_builtin_type, "PyTypeTest on non extension type" CoercionNode.__init__(self, arg) self.type = dst_type self.result_ctype = arg.ctype() + self.notnone = notnone nogil_check = Node.gil_error gil_message = "Python type test" @@ -5596,7 +5597,7 @@ class PyTypeTestNode(CoercionNode): code.globalstate.use_utility_code(type_test_utility_code) code.putln( "if (!(%s)) %s" % ( - self.type.type_test_code(self.arg.py_result()), + self.type.type_test_code(self.arg.py_result(), self.notnone), code.error_goto(self.pos))) else: error(self.pos, "Cannot test type of extern C class " @@ -6008,18 +6009,18 @@ bad: type_test_utility_code = UtilityCode( proto = """ -static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/ +static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/ """, impl = """ -static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { - if (!type) { +static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) { + if (unlikely(!type)) { PyErr_Format(PyExc_SystemError, "Missing type object"); return 0; } - if (obj == Py_None || PyObject_TypeCheck(obj, type)) + if (likely(PyObject_TypeCheck(obj, type))) return 1; - PyErr_Format(PyExc_TypeError, "Cannot convert %s to %s", - Py_TYPE(obj)->tp_name, type->tp_name); + PyErr_Format(PyExc_TypeError, "Cannot convert %.200s to %.200s", + Py_TYPE(obj)->tp_name, type->tp_name); return 0; } """) diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 574096e1..74cff0aa 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -408,19 +408,24 @@ class BuiltinObjectType(PyObjectType): def subtype_of(self, type): return type.is_pyobject and self.assignable_from(type) - def type_test_code(self, arg): + def type_test_code(self, arg, notnone=False): type_name = self.name if type_name == 'str': - check = 'PyString_CheckExact' + type_check = 'PyString_CheckExact' elif type_name == 'set': - check = 'PyAnySet_CheckExact' + type_check = 'PyAnySet_CheckExact' elif type_name == 'frozenset': - check = 'PyFrozenSet_CheckExact' + type_check = 'PyFrozenSet_CheckExact' elif type_name == 'bool': - check = 'PyBool_Check' + type_check = 'PyBool_Check' else: - check = 'Py%s_CheckExact' % type_name.capitalize() - return 'likely(%s(%s)) || (%s) == Py_None || (PyErr_Format(PyExc_TypeError, "Expected %s, got %%s", Py_TYPE(%s)->tp_name), 0)' % (check, arg, arg, self.name, arg) + type_check = 'Py%s_CheckExact' % type_name.capitalize() + + check = 'likely(%s(%s))' % (type_check, arg) + if not notnone: + check = check + ('||((%s) == Py_None)' % arg) + error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg) + return check + '||' + error def declaration_code(self, entity_code, for_display = 0, dll_linkage = None, pyrex = 0): @@ -504,9 +509,16 @@ class PyExtensionType(PyObjectType): else: return "%s *%s" % (base, entity_code) - def type_test_code(self, py_arg): - return "__Pyx_TypeTest(%s, %s)" % (py_arg, self.typeptr_cname) - + def type_test_code(self, py_arg, notnone=False): + + none_check = "((%s) == Py_None)" % py_arg + type_check = "likely(__Pyx_TypeTest(%s, %s))" % ( + py_arg, self.typeptr_cname) + if notnone: + return type_check + else: + return "likely(%s || %s)" % (none_check, type_check) + def attributes_known(self): return self.scope is not None diff --git a/tests/run/typetest_T417.pyx b/tests/run/typetest_T417.pyx new file mode 100644 index 00000000..fd844ed7 --- /dev/null +++ b/tests/run/typetest_T417.pyx @@ -0,0 +1,55 @@ +#cython: autotestdict=True + +cdef class Foo: + pass + +cdef class SubFoo(Foo): + pass + +cdef class Bar: + pass + +def foo1(arg): + """ + >>> foo1(Foo()) + >>> foo1(SubFoo()) + >>> foo1(None) + >>> foo1(123) + >>> foo1(Bar()) + """ + cdef Foo val = arg + +def foo2(arg): + """ + >>> foo2(Foo()) + >>> foo2(SubFoo()) + >>> foo2(None) + >>> foo2(123) + Traceback (most recent call last): + ... + TypeError: Cannot convert int to typetest_T417.Foo + >>> foo2(Bar()) + Traceback (most recent call last): + ... + TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo + """ + cdef Foo val = arg + +def foo3(arg): + """ + >>> foo3(Foo()) + >>> foo3(SubFoo()) + >>> foo3(None) + Traceback (most recent call last): + ... + TypeError: Cannot convert NoneType to typetest_T417.Foo + >>> foo3(123) + Traceback (most recent call last): + ... + TypeError: Cannot convert int to typetest_T417.Foo + >>> foo2(Bar()) + Traceback (most recent call last): + ... + TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo + """ + cdef val = arg -- 2.26.2