Extension type cast should reject None (ticket #417)
authorLisandro Dalcin <dalcinl@gmail.com>
Thu, 22 Oct 2009 20:42:30 +0000 (18:42 -0200)
committerLisandro Dalcin <dalcinl@gmail.com>
Thu, 22 Oct 2009 20:42:30 +0000 (18:42 -0200)
--HG--
extra : rebase_source : 37bb9de5574e1f7b4f288192eaa3c70a2ae350ca

Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
tests/run/typetest_T417.pyx [new file with mode: 0644]

index 4c11d6664f68ad9c34ab13b9a019075bbb203378..8130a56eebdfbb4ec5642ebdb37e15ce5a0924c9 100644 (file)
@@ -4245,7 +4245,7 @@ class TypecastNode(ExprNode):
                 warning(self.pos, "No conversion from %s to %s, python object pointer used." % (self.type, self.operand.type))
         elif from_py and to_py:
             if self.typecheck and self.type.is_extension_type:
-                self.operand = PyTypeTestNode(self.operand, self.type, env)
+                self.operand = PyTypeTestNode(self.operand, self.type, env, notnone=True)
 
     def nogil_check(self, env):
         if self.type and self.type.is_pyobject and self.is_temp:
@@ -5563,13 +5563,14 @@ class PyTypeTestNode(CoercionNode):
     #  object is an instance of a particular extension type.
     #  This node borrows the result of its argument node.
 
-    def __init__(self, arg, dst_type, env):
+    def __init__(self, arg, dst_type, env, notnone=False):
         #  The arg is know to be a Python object, and
         #  the dst_type is known to be an extension type.
         assert dst_type.is_extension_type or dst_type.is_builtin_type, "PyTypeTest on non extension type"
         CoercionNode.__init__(self, arg)
         self.type = dst_type
         self.result_ctype = arg.ctype()
+        self.notnone = notnone
 
     nogil_check = Node.gil_error
     gil_message = "Python type test"
@@ -5596,7 +5597,7 @@ class PyTypeTestNode(CoercionNode):
                 code.globalstate.use_utility_code(type_test_utility_code)
             code.putln(
                 "if (!(%s)) %s" % (
-                    self.type.type_test_code(self.arg.py_result()),
+                    self.type.type_test_code(self.arg.py_result(), self.notnone),
                     code.error_goto(self.pos)))
         else:
             error(self.pos, "Cannot test type of extern C class "
@@ -6008,18 +6009,18 @@ bad:
 
 type_test_utility_code = UtilityCode(
 proto = """
-static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/
+static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type); /*proto*/
 """,
 impl = """
-static int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
-    if (!type) {
+static INLINE int __Pyx_TypeTest(PyObject *obj, PyTypeObject *type) {
+    if (unlikely(!type)) {
         PyErr_Format(PyExc_SystemError, "Missing type object");
         return 0;
     }
-    if (obj == Py_None || PyObject_TypeCheck(obj, type))
+    if (likely(PyObject_TypeCheck(obj, type)))
         return 1;
-    PyErr_Format(PyExc_TypeError, "Cannot convert %s to %s",
-        Py_TYPE(obj)->tp_name, type->tp_name);
+    PyErr_Format(PyExc_TypeError, "Cannot convert %.200s to %.200s",
+                 Py_TYPE(obj)->tp_name, type->tp_name);
     return 0;
 }
 """)
index 574096e11ded8ad046bb9cc92d28cbadf1e28cb2..74cff0aacdebee975b15ff2ad84deddc9faf6aea 100644 (file)
@@ -408,19 +408,24 @@ class BuiltinObjectType(PyObjectType):
     def subtype_of(self, type):
         return type.is_pyobject and self.assignable_from(type)
         
-    def type_test_code(self, arg):
+    def type_test_code(self, arg, notnone=False):
         type_name = self.name
         if type_name == 'str':
-            check = 'PyString_CheckExact'
+            type_check = 'PyString_CheckExact'
         elif type_name == 'set':
-            check = 'PyAnySet_CheckExact'
+            type_check = 'PyAnySet_CheckExact'
         elif type_name == 'frozenset':
-            check = 'PyFrozenSet_CheckExact'
+            type_check = 'PyFrozenSet_CheckExact'
         elif type_name == 'bool':
-            check = 'PyBool_Check'
+            type_check = 'PyBool_Check'
         else:
-            check = 'Py%s_CheckExact' % type_name.capitalize()
-        return 'likely(%s(%s)) || (%s) == Py_None || (PyErr_Format(PyExc_TypeError, "Expected %s, got %%s", Py_TYPE(%s)->tp_name), 0)' % (check, arg, arg, self.name, arg)
+            type_check = 'Py%s_CheckExact' % type_name.capitalize()
+
+        check = 'likely(%s(%s))' % (type_check, arg)
+        if not notnone:
+            check = check + ('||((%s) == Py_None)' % arg)
+        error = '(PyErr_Format(PyExc_TypeError, "Expected %s, got %%.200s", Py_TYPE(%s)->tp_name), 0)' % (self.name, arg)
+        return check + '||' + error
 
     def declaration_code(self, entity_code, 
             for_display = 0, dll_linkage = None, pyrex = 0):
@@ -504,9 +509,16 @@ class PyExtensionType(PyObjectType):
             else:
                 return "%s *%s" % (base,  entity_code)
 
-    def type_test_code(self, py_arg):
-        return "__Pyx_TypeTest(%s, %s)" % (py_arg, self.typeptr_cname)
-    
+    def type_test_code(self, py_arg, notnone=False):
+
+        none_check = "((%s) == Py_None)" % py_arg
+        type_check = "likely(__Pyx_TypeTest(%s, %s))" % (
+            py_arg, self.typeptr_cname)
+        if notnone:
+            return type_check
+        else:
+            return "likely(%s || %s)" % (none_check, type_check)
+
     def attributes_known(self):
         return self.scope is not None
     
diff --git a/tests/run/typetest_T417.pyx b/tests/run/typetest_T417.pyx
new file mode 100644 (file)
index 0000000..fd844ed
--- /dev/null
@@ -0,0 +1,55 @@
+#cython: autotestdict=True
+
+cdef class Foo:
+    pass
+
+cdef class SubFoo(Foo):
+    pass
+
+cdef class Bar:
+    pass
+
+def foo1(arg):
+    """
+    >>> foo1(Foo())
+    >>> foo1(SubFoo())
+    >>> foo1(None)
+    >>> foo1(123)
+    >>> foo1(Bar())
+    """
+    cdef Foo val = <Foo>arg
+
+def foo2(arg):
+    """
+    >>> foo2(Foo())
+    >>> foo2(SubFoo())
+    >>> foo2(None)
+    >>> foo2(123)
+    Traceback (most recent call last):
+       ...
+    TypeError: Cannot convert int to typetest_T417.Foo
+    >>> foo2(Bar())
+    Traceback (most recent call last):
+       ...
+    TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo
+    """
+    cdef Foo val = arg
+
+def foo3(arg):
+    """
+    >>> foo3(Foo())
+    >>> foo3(SubFoo())
+    >>> foo3(None)
+    Traceback (most recent call last):
+       ...
+    TypeError: Cannot convert NoneType to typetest_T417.Foo
+    >>> foo3(123)
+    Traceback (most recent call last):
+       ...
+    TypeError: Cannot convert int to typetest_T417.Foo
+    >>> foo2(Bar())
+    Traceback (most recent call last):
+       ...
+    TypeError: Cannot convert typetest_T417.Bar to typetest_T417.Foo
+    """
+    cdef val = <Foo?>arg