From 87c3455ce3aec3160a10762f347ca52fb0eeb4b8 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Wed, 6 Feb 2008 15:11:28 -0800 Subject: [PATCH] Optional arguments in cpdef functions --- Cython/Compiler/ExprNodes.py | 25 ++++++--- Cython/Compiler/Naming.py | 1 + Cython/Compiler/Nodes.py | 97 ++++++++++++++++++++--------------- Cython/Compiler/Parsing.py | 4 +- Cython/Compiler/PyrexTypes.py | 9 +++- 5 files changed, 87 insertions(+), 49 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index d8d5768c..15aedfdf 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1514,13 +1514,19 @@ class SimpleCallNode(ExprNode): self.result_code = "" return # Check no. of args - expected_nargs = len(func_type.args) + max_nargs = len(func_type.args) + expected_nargs = max_nargs - func_type.optional_arg_count actual_nargs = len(self.args) if actual_nargs < expected_nargs \ - or (not func_type.has_varargs and actual_nargs > expected_nargs): + or (not func_type.has_varargs and actual_nargs > max_nargs): expected_str = str(expected_nargs) if func_type.has_varargs: expected_str = "at least " + expected_str + elif func_type.optional_arg_count: + if actual_nargs > max_nargs: + expected_str = "at least " + expected_str + else: + expected_str = "at most " + str(max_nargs) error(self.pos, "Call with wrong number of arguments (expected %s, got %s)" % (expected_str, actual_nargs)) @@ -1529,10 +1535,10 @@ class SimpleCallNode(ExprNode): self.result_code = "" return # Coerce arguments - for i in range(expected_nargs): + for i in range(min(max_nargs, actual_nargs)): formal_type = func_type.args[i].type self.args[i] = self.args[i].coerce_to(formal_type, env) - for i in range(expected_nargs, actual_nargs): + for i in range(max_nargs, actual_nargs): if self.args[i].type.is_pyobject: error(self.args[i].pos, "Python object cannot be passed as a varargs parameter") @@ -1558,10 +1564,14 @@ class SimpleCallNode(ExprNode): zip(formal_args, self.args): arg_code = actual_arg.result_as(formal_arg.type) arg_list_code.append(arg_code) + if func_type.optional_arg_count: + for formal_arg in formal_args[len(self.args):]: + arg_list_code.append(formal_arg.type.cast_code('0')) + arg_list_code.append(str(max(0, len(formal_args) - len(self.args)))) for actual_arg in self.args[len(formal_args):]: arg_list_code.append(actual_arg.result_code) result = "%s(%s)" % (self.function.result_code, - join(arg_list_code, ",")) + join(arg_list_code, ", ")) if self.wrapper_call or \ self.function.entry.is_unbound_cmethod and self.function.entry.type.is_overridable: result = "(%s = 1, %s)" % (Naming.skip_dispatch_cname, result) @@ -3539,6 +3549,9 @@ class PyTypeTestNode(CoercionNode): self.result_ctype = arg.ctype() env.use_utility_code(type_test_utility_code) + def analyse_types(self, env): + pass + def result_in_temp(self): return self.arg.result_in_temp() @@ -3552,7 +3565,7 @@ class PyTypeTestNode(CoercionNode): if self.type.typeobj_is_available(): code.putln( "if (!__Pyx_TypeTest(%s, %s)) %s" % ( - self.arg.py_result(), + self.arg.py_result(), self.type.typeptr_cname, code.error_goto(self.pos))) else: diff --git a/Cython/Compiler/Naming.py b/Cython/Compiler/Naming.py index c273d73f..59ab7f1c 100644 --- a/Cython/Compiler/Naming.py +++ b/Cython/Compiler/Naming.py @@ -60,6 +60,7 @@ gilstate_cname = pyrex_prefix + "state" skip_dispatch_cname = pyrex_prefix + "skip_dispatch" empty_tuple = pyrex_prefix + "empty_tuple" cleanup_cname = pyrex_prefix + "module_cleanup" +optional_count_cname = pyrex_prefix + "optional_arg_count" extern_c_macro = pyrex_prefix.upper() + "EXTERN_C" diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index f13df2ab..ae7b8994 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -316,6 +316,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): # with_gil boolean Acquire gil around function body overridable = 0 + optional_arg_count = 0 def analyse(self, return_type, env): func_type_args = [] @@ -337,7 +338,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): func_type_args.append( PyrexTypes.CFuncTypeArg(name, type, arg_node.pos)) if arg_node.default: - error(arg_node.pos, "C function argument cannot have default value") + self.optional_arg_count += 1 exc_val = None exc_check = 0 if return_type.is_pyobject \ @@ -363,6 +364,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): "Function cannot return a function") func_type = PyrexTypes.CFuncType( return_type, func_type_args, self.has_varargs, + optional_arg_count = self.optional_arg_count, exception_value = exc_val, exception_check = exc_check, calling_convention = self.base.calling_convention, nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) @@ -609,10 +611,24 @@ class FuncDefNode(StatNode, BlockNode): # entry Symtab.Entry py_func = None + assmt = None + + def analyse_default_values(self, env): + genv = env.global_scope() + for arg in self.args: + if arg.default: + if arg.is_generic: + if not hasattr(arg, 'default_entry'): + arg.default.analyse_types(genv) + arg.default = arg.default.coerce_to(arg.type, genv) + arg.default.allocate_temps(genv) + arg.default_entry = genv.add_default_value(arg.type) + arg.default_entry.used = 1 + else: + error(arg.pos, + "This argument cannot have a default value") + arg.default = None - def analyse_expressions(self, env): - pass - def need_gil_acquisition(self, lenv): return 0 @@ -749,7 +765,24 @@ class FuncDefNode(StatNode, BlockNode): code.put_var_incref(entry) def generate_execution_code(self, code): - pass + # Evaluate and store argument default values + for arg in self.args: + default = arg.default + if default: + default.generate_evaluation_code(code) + default.make_owned_reference(code) + code.putln( + "%s = %s;" % ( + arg.default_entry.cname, + default.result_as(arg.default_entry.type))) + if default.is_temp and default.type.is_pyobject: + code.putln( + "%s = 0;" % + default.result_code) + # For Python class methods, create and store function object + if self.assmt: + self.assmt.generate_execution_code(code) + class CFuncDefNode(FuncDefNode): @@ -826,12 +859,20 @@ class CFuncDefNode(FuncDefNode): error(self.pos, "Function declared nogil has Python locals or temporaries") return with_gil + def analyse_expressions(self, env): + self.args = self.declarator.args + self.analyse_default_values(env) + if self.overridable: + self.py_func.analyse_expressions(env) + def generate_function_header(self, code, with_pymethdef): arg_decls = [] type = self.type visibility = self.entry.visibility for arg in type.args: arg_decls.append(arg.declaration_code()) + if type.optional_arg_count: + arg_decls.append("int %s" % Naming.optional_count_cname) if type.has_varargs: arg_decls.append("...") if not arg_decls: @@ -861,7 +902,17 @@ class CFuncDefNode(FuncDefNode): pass def generate_argument_parsing_code(self, code): - pass + rev_args = zip(self.declarator.args, self.type.args) + rev_args.reverse() + i = 0 + for darg, targ in rev_args: + if darg.default: + code.putln('if (%s > %s) {' % (Naming.optional_count_cname, i)) + code.putln('%s = %s;' % (targ.cname, darg.default_entry.cname)) + i += 1 + for _ in range(i): + code.putln('}') + code.putln('/* defaults */') def generate_argument_conversion_code(self, code): pass @@ -1102,21 +1153,6 @@ class DefNode(FuncDefNode): if env.is_py_class_scope: self.synthesize_assignment_node(env) - def analyse_default_values(self, env): - genv = env.global_scope() - for arg in self.args: - if arg.default: - if arg.is_generic: - arg.default.analyse_types(genv) - arg.default = arg.default.coerce_to(arg.type, genv) - arg.default.allocate_temps(genv) - arg.default_entry = genv.add_default_value(arg.type) - arg.default_entry.used = 1 - else: - error(arg.pos, - "This argument cannot have a default value") - arg.default = None - def synthesize_assignment_node(self, env): import ExprNodes self.assmt = SingleAssignmentNode(self.pos, @@ -1485,25 +1521,6 @@ class DefNode(FuncDefNode): error(arg.pos, "Cannot test type of extern C class " "without type object name specification") - def generate_execution_code(self, code): - # Evaluate and store argument default values - for arg in self.args: - default = arg.default - if default: - default.generate_evaluation_code(code) - default.make_owned_reference(code) - code.putln( - "%s = %s;" % ( - arg.default_entry.cname, - default.result_as(arg.default_entry.type))) - if default.is_temp and default.type.is_pyobject: - code.putln( - "%s = 0;" % - default.result_code) - # For Python class methods, create and store function object - if self.assmt: - self.assmt.generate_execution_code(code) - def error_value(self): return self.entry.signature.error_value diff --git a/Cython/Compiler/Parsing.py b/Cython/Compiler/Parsing.py index 52729e1a..a3d1c6e2 100644 --- a/Cython/Compiler/Parsing.py +++ b/Cython/Compiler/Parsing.py @@ -1710,8 +1710,8 @@ def p_api(s): def p_cdef_statement(s, level, visibility = 'private', api = 0, overridable = False): pos = s.position() - if overridable and level not in ('c_class', 'c_class_pxd'): - error(pos, "Overridable cdef function not allowed here") +# if overridable and level not in ('c_class', 'c_class_pxd'): +# error(pos, "Overridable cdef function not allowed here") visibility = p_visibility(s, visibility) api = api or p_api(s) if api: diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 476a2336..444e89df 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -587,10 +587,11 @@ class CFuncType(CType): def __init__(self, return_type, args, has_varargs = 0, exception_value = None, exception_check = 0, calling_convention = "", - nogil = 0, with_gil = 0, is_overridable = 0): + nogil = 0, with_gil = 0, is_overridable = 0, optional_arg_count = 0): self.return_type = return_type self.args = args self.has_varargs = has_varargs + self.optional_arg_count = optional_arg_count self.exception_value = exception_value self.exception_check = exception_check self.calling_convention = calling_convention @@ -639,6 +640,8 @@ class CFuncType(CType): return 0 if self.has_varargs <> other_type.has_varargs: return 0 + if self.optional_arg_count <> other_type.optional_arg_count: + return 0 if not self.return_type.same_as(other_type.return_type): return 0 if not self.same_calling_convention_as(other_type): @@ -664,6 +667,8 @@ class CFuncType(CType): or not self.args[i].type.same_as(other_type.args[i].type) if self.has_varargs <> other_type.has_varargs: return 0 + if self.optional_arg_count <> other_type.optional_arg_count: + return 0 if not self.return_type.subtype_of_resolved_type(other_type.return_type): return 0 return 1 @@ -691,6 +696,8 @@ class CFuncType(CType): for arg in self.args: arg_decl_list.append( arg.type.declaration_code("", for_display, pyrex = pyrex)) + if self.optional_arg_count: + arg_decl_list.append("int %s" % Naming.optional_count_cname) if self.has_varargs: arg_decl_list.append("...") arg_decl_code = string.join(arg_decl_list, ",") -- 2.26.2