minor refactoring, use 'notnone=True' in type tests
authorStefan Behnel <scoder@users.berlios.de>
Sat, 24 Oct 2009 08:59:41 +0000 (10:59 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 24 Oct 2009 08:59:41 +0000 (10:59 +0200)
Cython/Compiler/Optimize.py

index 83583fe9bc5f9651438390a3e51418a1c8c7cc61..a08399ff1c05b99e430cf1f704ffe7b8d115916a 100644 (file)
@@ -25,6 +25,10 @@ try:
 except NameError:
     from sets import Set as set
 
+class FakePythonEnv(object):
+    "A fake environment for creating type test nodes etc."
+    nogil = False
+
 def unwrap_node(node):
     while isinstance(node, UtilNodes.ResultRefNode):
         node = node.expression
@@ -297,11 +301,11 @@ class IterationTransform(Visitor.VisitorTransform):
                 tuple_target = node.target
 
         def coerce_object_to(obj_node, dest_type):
-            class FakeEnv(object):
-                nogil = False
             if dest_type.is_pyobject:
-                if dest_type.is_extension_type or dest_type.is_builtin_type:
-                    obj_node = ExprNodes.PyTypeTestNode(obj_node, dest_type, FakeEnv())
+                if dest_type != obj_node.type:
+                    if dest_type.is_extension_type or dest_type.is_builtin_type:
+                        obj_node = ExprNodes.PyTypeTestNode(
+                            obj_node, dest_type, FakePythonEnv(), notnone=True)
                 result = ExprNodes.TypecastNode(
                     obj_node.pos,
                     operand = obj_node,
@@ -316,7 +320,7 @@ class IterationTransform(Visitor.VisitorTransform):
                         return temp_result.result()
                     def generate_execution_code(self, code):
                         self.generate_result_code(code)
-                return (temp_result, CoercedTempNode(dest_type, obj_node, FakeEnv()))
+                return (temp_result, CoercedTempNode(dest_type, obj_node, FakePythonEnv()))
 
         if isinstance(node.body, Nodes.StatListNode):
             body = node.body
@@ -715,6 +719,22 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
         else:
             return node
 
+    def _error_wrong_arg_count(self, function_name, node, args, expected=None):
+        if not expected: # None or 0
+            arg_str = ''
+        elif isinstance(expected, basestring) or expected > 1:
+            arg_str = '...'
+        elif expected == 1:
+            arg_str = 'x'
+        else:
+            arg_str = ''
+        if expected is not None:
+            expected_str = 'expected %s, ' % expected
+        else:
+            expected_str = ''
+        error(node.pos, "%s(%s) called with wrong number of args, %sfound %d" % (
+            function_name, arg_str, expected_str, len(args)))
+
     ### builtin types
 
     def _handle_general_function_dict(self, node, pos_args, kwargs):
@@ -848,8 +868,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
                 is_temp = node.is_temp
                 )
         else:
-            error(node.pos, "getattr() called with wrong number of args, "
-                  "expected 2 or 3, found %d" % len(args))
+            self._error_wrong_arg_count('getattr', node, args, '2 or 3')
         return node
 
     Pyx_Type_func_type = PyrexTypes.CFuncType(
@@ -898,8 +917,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
 
     def _handle_simple_method_list_append(self, node, args, is_unbound_method):
         if len(args) != 2:
-            error(node.pos, "list.append(x) called with wrong number of args, found %d" %
-                  len(args))
+            self._error_wrong_arg_count('list.append', node, args, 2)
             return node
         return self._substitute_method_call(
             node, "PyList_Append", self.PyList_Append_func_type,
@@ -920,8 +938,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
 
     def _handle_simple_method_list_reverse(self, node, args, is_unbound_method):
         if len(args) != 1:
-            error(node.pos, "list.reverse(x) called with wrong number of args, found %d" %
-                  len(args))
+            self._error_wrong_arg_count('list.reverse', node, args, 1)
             return node
         return self._substitute_method_call(
             node, "PyList_Reverse", self.single_param_func_type,
@@ -949,8 +966,7 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
 
     def _handle_simple_method_unicode_encode(self, node, args, is_unbound_method):
         if len(args) < 1 or len(args) > 3:
-            error(node.pos, "unicode.encode(...) called with wrong number of args, found %d" %
-                  len(args))
+            self._error_wrong_arg_count('unicode.encode', node, args, '1-3')
             return node
 
         null_node = ExprNodes.NullNode(node.pos)