NoneCheckNode to enforce runtime None checks for object references
authorStefan Behnel <scoder@users.berlios.de>
Sun, 29 Mar 2009 18:55:51 +0000 (20:55 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 29 Mar 2009 18:55:51 +0000 (20:55 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Optimize.py

index f464e9feee39978e10c8a4244b92b63bc21f7cdb..ead7f838ff3f30e0cc6e3cfc2342744588f1e565 100644 (file)
@@ -5091,7 +5091,44 @@ class PyTypeTestNode(CoercionNode):
 
     def free_temps(self, code):
         self.arg.free_temps(code)
-        
+
+
+class NoneCheckNode(CoercionNode):
+    # This node is used to check that a Python object is not None and
+    # raises an appropriate exception (as specified by the creating
+    # transform).
+
+    def __init__(self, arg, exception_type_cname, exception_message):
+        CoercionNode.__init__(self, arg)
+        self.type = arg.type
+        self.result_ctype = arg.ctype()
+        self.exception_type_cname = exception_type_cname
+        self.exception_message = exception_message
+
+    def analyse_types(self, env):
+        pass
+
+    def result_in_temp(self):
+        return self.arg.result_in_temp()
+
+    def calculate_result_code(self):
+        return self.arg.result()
+    
+    def generate_result_code(self, code):
+        code.putln(
+            "if (unlikely(%s == Py_None)) {" % self.arg.result())
+        code.putln('PyErr_SetString(%s, "%s"); %s ' % (
+            self.exception_type_cname,
+            StringEncoding.escape_byte_string(self.exception_message),
+            code.error_goto(self.pos)))
+        code.putln("}")
+
+    def generate_post_assignment_code(self, code):
+        self.arg.generate_post_assignment_code(code)
+
+    def free_temps(self, code):
+        self.arg.free_temps(code)
+
 
 class CoerceToPyTypeNode(CoercionNode):
     #  This node is used to convert a C data type
index 7f4972a5f1abcaa65729dcee2a173ef89b14e91f..ebbf497b70fc765bd1c53e611a756a4349aeb0b8 100644 (file)
@@ -476,6 +476,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             arg_list = arg_tuple.args
             self_arg = function.obj
             obj_type = self_arg.type
+            is_unbound_method = False
             if obj_type.is_builtin_type:
                 if obj_type is Builtin.type_type and arg_list and \
                          arg_list[0].type.is_pyobject:
@@ -483,6 +484,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
                     # (ignoring 'type.mro()' here ...)
                     type_name = function.obj.name
                     self_arg = None
+                    is_unbound_method = True
                 else:
                     type_name = obj_type.name
             else:
@@ -494,9 +496,9 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             if self_arg is not None:
                 arg_list = [self_arg] + list(arg_list)
             if kwargs:
-                return method_handler(node, arg_list, kwargs)
+                return method_handler(node, arg_list, kwargs, is_unbound_method)
             else:
-                return method_handler(node, arg_list)
+                return method_handler(node, arg_list, is_unbound_method)
         else:
             return node
 
@@ -625,7 +627,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             PyrexTypes.CFuncTypeArg("item", PyrexTypes.py_object_type, None),
             ])
 
-    def _handle_simple_method_object_append(self, node, args):
+    def _handle_simple_method_object_append(self, node, args, is_unbound_method):
         # X.append() is almost always referring to a list
         if len(args) != 2:
             return node
@@ -644,13 +646,14 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             ],
         exception_value = "-1")
 
-    def _handle_simple_method_list_append(self, node, args):
+    def _handle_simple_method_list_append(self, node, args, is_unbound_method):
         if len(args) != 2:
             error(node.pos, "list.append(x) called with wrong number of args, found %d" %
                   len(args))
             return node
         return self._substitute_method_call(
-            node, "PyList_Append", self.PyList_Append_func_type, args)
+            node, "PyList_Append", self.PyList_Append_func_type,
+            'append', is_unbound_method, args)
 
     single_param_func_type = PyrexTypes.CFuncType(
         PyrexTypes.c_int_type, [
@@ -658,21 +661,37 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
             ],
         exception_value = "-1")
 
-    def _handle_simple_method_list_sort(self, node, args):
+    def _handle_simple_method_list_sort(self, node, args, is_unbound_method):
         if len(args) != 1:
             return node
         return self._substitute_method_call(
-            node, "PyList_Sort", self.single_param_func_type, args)
+            node, "PyList_Sort", self.single_param_func_type,
+            'sort', is_unbound_method, args)
 
-    def _handle_simple_method_list_reverse(self, node, args):
+    def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
         if len(args) != 1:
             error(node.pos, "list.reverse(x) called with wrong number of args, found %d" %
                   len(args))
+            return node
         return self._substitute_method_call(
-            node, "PyList_Reverse", self.single_param_func_type, args)
+            node, "PyList_Reverse", self.single_param_func_type,
+            'reverse', is_unbound_method, args)
 
-    def _substitute_method_call(self, node, name, func_type, args=()):
+    def _substitute_method_call(self, node, name, func_type,
+                                attr_name, is_unbound_method, args=()):
         args = list(args)
+        if args:
+            self_arg = args[0]
+            if is_unbound_method:
+                self_arg = ExprNodes.NoneCheckNode(
+                    self_arg, "PyExc_TypeError",
+                    "descriptor '%s' requires a '%s' object but received a 'NoneType'" % (
+                    attr_name, node.function.obj.name))
+            else:
+                self_arg = ExprNodes.NoneCheckNode(
+                    self_arg, "PyExc_AttributeError",
+                    "'NoneType' object has no attribute '%s'" % attr_name)
+            args[0] = self_arg
         # FIXME: args[0] may need a runtime None check (ticket #166)
         return ExprNodes.PythonCapiCallNode(
             node.pos, name, func_type,