Optional arguments in cpdef functions
authorRobert Bradshaw <robertwb@math.washington.edu>
Wed, 6 Feb 2008 23:11:28 +0000 (15:11 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Wed, 6 Feb 2008 23:11:28 +0000 (15:11 -0800)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py
Cython/Compiler/PyrexTypes.py

index d8d5768cfaa4661375a4ce0fc5129472ea491a9e..15aedfdf5f8003da1b83fb1c77e9511db9ff59e7 100644 (file)
@@ -1514,13 +1514,19 @@ class SimpleCallNode(ExprNode):
             self.result_code = "<error>"
             return
         # Check no. of args
-        expected_nargs = len(func_type.args)
+        max_nargs = len(func_type.args)
+        expected_nargs = max_nargs - func_type.optional_arg_count
         actual_nargs = len(self.args)
         if actual_nargs < expected_nargs \
-            or (not func_type.has_varargs and actual_nargs > expected_nargs):
+            or (not func_type.has_varargs and actual_nargs > max_nargs):
                 expected_str = str(expected_nargs)
                 if func_type.has_varargs:
                     expected_str = "at least " + expected_str
+                elif func_type.optional_arg_count:
+                    if actual_nargs > max_nargs:
+                        expected_str = "at least " + expected_str
+                    else:
+                        expected_str = "at most " + str(max_nargs)
                 error(self.pos, 
                     "Call with wrong number of arguments (expected %s, got %s)"
                         % (expected_str, actual_nargs))
@@ -1529,10 +1535,10 @@ class SimpleCallNode(ExprNode):
                 self.result_code = "<error>"
                 return
         # Coerce arguments
-        for i in range(expected_nargs):
+        for i in range(min(max_nargs, actual_nargs)):
             formal_type = func_type.args[i].type
             self.args[i] = self.args[i].coerce_to(formal_type, env)
-        for i in range(expected_nargs, actual_nargs):
+        for i in range(max_nargs, actual_nargs):
             if self.args[i].type.is_pyobject:
                 error(self.args[i].pos, 
                     "Python object cannot be passed as a varargs parameter")
@@ -1558,10 +1564,14 @@ class SimpleCallNode(ExprNode):
             zip(formal_args, self.args):
                 arg_code = actual_arg.result_as(formal_arg.type)
                 arg_list_code.append(arg_code)
+        if func_type.optional_arg_count:
+            for formal_arg in formal_args[len(self.args):]:
+                arg_list_code.append(formal_arg.type.cast_code('0'))
+            arg_list_code.append(str(max(0, len(formal_args) - len(self.args))))
         for actual_arg in self.args[len(formal_args):]:
             arg_list_code.append(actual_arg.result_code)
         result = "%s(%s)" % (self.function.result_code,
-            join(arg_list_code, ","))
+            join(arg_list_code, ", "))
         if self.wrapper_call or \
                 self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable:
             result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result)
@@ -3539,6 +3549,9 @@ class PyTypeTestNode(CoercionNode):
         self.result_ctype = arg.ctype()
         env.use_utility_code(type_test_utility_code)
     
+    def analyse_types(self, env):
+        pass
+    
     def result_in_temp(self):
         return self.arg.result_in_temp()
     
@@ -3552,7 +3565,7 @@ class PyTypeTestNode(CoercionNode):
         if self.type.typeobj_is_available():
             code.putln(
                 "if (!__Pyx_TypeTest(%s, %s)) %s" % (
-                    self.arg.py_result(),
+                    self.arg.py_result(), 
                     self.type.typeptr_cname,
                     code.error_goto(self.pos)))
         else:
index c273d73f08a07bff10a85db3f74fd6f005ff20dc..59ab7f1c9040587ed35a86a796583f6a1f52a839 100644 (file)
@@ -60,6 +60,7 @@ gilstate_cname   = pyrex_prefix + "state"
 skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
 empty_tuple      = pyrex_prefix + "empty_tuple"
 cleanup_cname    = pyrex_prefix + "module_cleanup"
+optional_count_cname = pyrex_prefix + "optional_arg_count"
 
 
 extern_c_macro  = pyrex_prefix.upper() + "EXTERN_C"
index f13df2abdf498b1672dcd47aa450a65278a77ce8..ae7b89949581ffa2c660eebcde63eb58918520ac 100644 (file)
@@ -316,6 +316,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
     # with_gil         boolean    Acquire gil around function body
     
     overridable = 0
+    optional_arg_count = 0
 
     def analyse(self, return_type, env):
         func_type_args = []
@@ -337,7 +338,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
             func_type_args.append(
                 PyrexTypes.CFuncTypeArg(name, type, arg_node.pos))
             if arg_node.default:
-                error(arg_node.pos, "C function argument cannot have default value")
+                self.optional_arg_count += 1
         exc_val = None
         exc_check = 0
         if return_type.is_pyobject \
@@ -363,6 +364,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
                 "Function cannot return a function")
         func_type = PyrexTypes.CFuncType(
             return_type, func_type_args, self.has_varargs, 
+            optional_arg_count = self.optional_arg_count,
             exception_value = exc_val, exception_check = exc_check,
             calling_convention = self.base.calling_convention,
             nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
@@ -609,10 +611,24 @@ class FuncDefNode(StatNode, BlockNode):
     #  entry           Symtab.Entry
     
     py_func = None
+    assmt = None
+    
+    def analyse_default_values(self, env):
+        genv = env.global_scope()
+        for arg in self.args:
+            if arg.default:
+                if arg.is_generic:
+                    if not hasattr(arg, 'default_entry'):
+                        arg.default.analyse_types(genv)
+                        arg.default = arg.default.coerce_to(arg.type, genv)
+                        arg.default.allocate_temps(genv)
+                        arg.default_entry = genv.add_default_value(arg.type)
+                        arg.default_entry.used = 1
+                else:
+                    error(arg.pos,
+                        "This argument cannot have a default value")
+                    arg.default = None
     
-    def analyse_expressions(self, env):
-        pass
-
     def need_gil_acquisition(self, lenv):
         return 0
                 
@@ -749,7 +765,24 @@ class FuncDefNode(StatNode, BlockNode):
             code.put_var_incref(entry)
 
     def generate_execution_code(self, code):
-        pass
+        # Evaluate and store argument default values
+        for arg in self.args:
+            default = arg.default
+            if default:
+                default.generate_evaluation_code(code)
+                default.make_owned_reference(code)
+                code.putln(
+                    "%s = %s;" % (
+                        arg.default_entry.cname,
+                        default.result_as(arg.default_entry.type)))
+                if default.is_temp and default.type.is_pyobject:
+                    code.putln(
+                        "%s = 0;" %
+                            default.result_code)
+        # For Python class methods, create and store function object
+        if self.assmt:
+            self.assmt.generate_execution_code(code)
+    
 
 
 class CFuncDefNode(FuncDefNode):
@@ -826,12 +859,20 @@ class CFuncDefNode(FuncDefNode):
                     error(self.pos, "Function declared nogil has Python locals or temporaries")
         return with_gil
 
+    def analyse_expressions(self, env):
+        self.args = self.declarator.args
+        self.analyse_default_values(env)
+        if self.overridable:
+            self.py_func.analyse_expressions(env)
+
     def generate_function_header(self, code, with_pymethdef):
         arg_decls = []
         type = self.type
         visibility = self.entry.visibility
         for arg in type.args:
             arg_decls.append(arg.declaration_code())
+        if type.optional_arg_count:
+            arg_decls.append("int %s" % Naming.optional_count_cname)
         if type.has_varargs:
             arg_decls.append("...")
         if not arg_decls:
@@ -861,7 +902,17 @@ class CFuncDefNode(FuncDefNode):
         pass
         
     def generate_argument_parsing_code(self, code):
-        pass
+        rev_args = zip(self.declarator.args, self.type.args)
+        rev_args.reverse()
+        i = 0
+        for darg, targ in rev_args:
+            if darg.default:
+                code.putln('if (%s > %s) {' % (Naming.optional_count_cname, i))
+                code.putln('%s = %s;' % (targ.cname, darg.default_entry.cname))
+                i += 1
+        for _ in range(i):
+            code.putln('}')
+        code.putln('/* defaults */')
     
     def generate_argument_conversion_code(self, code):
         pass
@@ -1102,21 +1153,6 @@ class DefNode(FuncDefNode):
         if env.is_py_class_scope:
             self.synthesize_assignment_node(env)
     
-    def analyse_default_values(self, env):
-        genv = env.global_scope()
-        for arg in self.args:
-            if arg.default:
-                if arg.is_generic:
-                    arg.default.analyse_types(genv)
-                    arg.default = arg.default.coerce_to(arg.type, genv)
-                    arg.default.allocate_temps(genv)
-                    arg.default_entry = genv.add_default_value(arg.type)
-                    arg.default_entry.used = 1
-                else:
-                    error(arg.pos,
-                        "This argument cannot have a default value")
-                    arg.default = None
-    
     def synthesize_assignment_node(self, env):
         import ExprNodes
         self.assmt = SingleAssignmentNode(self.pos,
@@ -1485,25 +1521,6 @@ class DefNode(FuncDefNode):
             error(arg.pos, "Cannot test type of extern C class "
                 "without type object name specification")
     
-    def generate_execution_code(self, code):
-        # Evaluate and store argument default values
-        for arg in self.args:
-            default = arg.default
-            if default:
-                default.generate_evaluation_code(code)
-                default.make_owned_reference(code)
-                code.putln(
-                    "%s = %s;" % (
-                        arg.default_entry.cname,
-                        default.result_as(arg.default_entry.type)))
-                if default.is_temp and default.type.is_pyobject:
-                    code.putln(
-                        "%s = 0;" %
-                            default.result_code)
-        # For Python class methods, create and store function object
-        if self.assmt:
-            self.assmt.generate_execution_code(code)
-    
     def error_value(self):
         return self.entry.signature.error_value
     
index 52729e1a2aa5a230f81c1b46a8f54391cf20c7de..a3d1c6e248ec761750434103490a603132185ac7 100644 (file)
@@ -1710,8 +1710,8 @@ def p_api(s):
 def p_cdef_statement(s, level, visibility = 'private', api = 0,
                      overridable = False):
     pos = s.position()
-    if overridable and level not in ('c_class', 'c_class_pxd'):
-            error(pos, "Overridable cdef function not allowed here")
+#    if overridable and level not in ('c_class', 'c_class_pxd'):
+#            error(pos, "Overridable cdef function not allowed here")
     visibility = p_visibility(s, visibility)
     api = api or p_api(s)
     if api:
index 476a2336d186d926e17e7111262c5bb7eb4bffbd..444e89df3fcfca5bd684a8bf1a68fcfc2cdf35a6 100644 (file)
@@ -587,10 +587,11 @@ class CFuncType(CType):
     
     def __init__(self, return_type, args, has_varargs = 0,
             exception_value = None, exception_check = 0, calling_convention = "",
-            nogil = 0, with_gil = 0, is_overridable = 0):
+            nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0):
         self.return_type = return_type
         self.args = args
         self.has_varargs = has_varargs
+        self.optional_arg_count = optional_arg_count
         self.exception_value = exception_value
         self.exception_check = exception_check
         self.calling_convention = calling_convention
@@ -639,6 +640,8 @@ class CFuncType(CType):
                     return 0
         if self.has_varargs <> other_type.has_varargs:
             return 0
+        if self.optional_arg_count <> other_type.optional_arg_count:
+            return 0
         if not self.return_type.same_as(other_type.return_type):
             return 0
         if not self.same_calling_convention_as(other_type):
@@ -664,6 +667,8 @@ class CFuncType(CType):
                         or not self.args[i].type.same_as(other_type.args[i].type)
         if self.has_varargs <> other_type.has_varargs:
             return 0
+        if self.optional_arg_count <> other_type.optional_arg_count:
+            return 0
         if not self.return_type.subtype_of_resolved_type(other_type.return_type):
             return 0
         return 1
@@ -691,6 +696,8 @@ class CFuncType(CType):
         for arg in self.args:
             arg_decl_list.append(
                 arg.type.declaration_code("", for_display, pyrex = pyrex))
+        if self.optional_arg_count:
+            arg_decl_list.append("int %s" % Naming.optional_count_cname)
         if self.has_varargs:
             arg_decl_list.append("...")
         arg_decl_code = string.join(arg_decl_list, ",")