From 36f3f6d0d353e26f5906613ecf87a52a946a51b0 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 7 Feb 2008 15:59:58 -0800 Subject: [PATCH] Add optional args to any cdef overridden function --- Cython/Compiler/ExprNodes.py | 14 ++++------ Cython/Compiler/Naming.py | 1 + Cython/Compiler/Nodes.py | 49 ++++++++++++++++++++++++++--------- Cython/Compiler/PyrexTypes.py | 13 +++------- Cython/Compiler/Symtab.py | 20 +++++++++++--- 5 files changed, 62 insertions(+), 35 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 83f57e47..4219f400 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1567,13 +1567,10 @@ class SimpleCallNode(ExprNode): for formal_arg, actual_arg in args[:expected_nargs]: arg_code = actual_arg.result_as(formal_arg.type) arg_list_code.append(arg_code) + if func_type.optional_arg_count: if expected_nargs == actual_nargs: - if func_type.old_signature: - struct_type = func_type.old_signature.op_args - else: - struct_type = func_type.op_args - optional_args = struct_type.cast_code('NULL') + optional_args = 'NULL' else: optional_arg_code = [str(actual_nargs - expected_nargs)] for formal_arg, actual_arg in args[expected_nargs:actual_nargs]: @@ -1582,11 +1579,10 @@ class SimpleCallNode(ExprNode): # for formal_arg in formal_args[actual_nargs:max_nargs]: # optional_arg_code.append(formal_arg.type.cast_code('0')) optional_arg_struct = '{%s}' % ','.join(optional_arg_code) - optional_args = '&' + func_type.op_args.base_type.cast_code(optional_arg_struct) - if func_type.old_signature and \ - func_type.old_signature.op_args != func_type.op_args: - optional_args = func_type.old_signature.op_args.cast_code(optional_args) + optional_args = PyrexTypes.c_void_ptr_type.cast_code( + '&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct)) arg_list_code.append(optional_args) + for actual_arg in self.args[len(formal_args):]: arg_list_code.append(actual_arg.result_code) result = "%s(%s)" % (self.function.result_code, diff --git a/Cython/Compiler/Naming.py b/Cython/Compiler/Naming.py index 64665684..165cb297 100644 --- a/Cython/Compiler/Naming.py +++ b/Cython/Compiler/Naming.py @@ -62,6 +62,7 @@ skip_dispatch_cname = pyrex_prefix + "skip_dispatch" empty_tuple = pyrex_prefix + "empty_tuple" cleanup_cname = pyrex_prefix + "module_cleanup" optional_args_cname = pyrex_prefix + "optional_args" +no_opt_args = pyrex_prefix + "no_opt_args" extern_c_macro = pyrex_prefix.upper() + "EXTERN_C" diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index fcc6acf2..9181a23d 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -384,7 +384,7 @@ class CFuncDeclaratorNode(CDeclaratorNode): calling_convention = self.base.calling_convention, nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable) if self.optional_arg_count: - func_type.op_args = PyrexTypes.c_ptr_type(self.op_args_struct.type) + func_type.op_arg_struct = PyrexTypes.c_ptr_type(self.op_args_struct.type) return self.base.analyse(func_type, env) @@ -763,6 +763,7 @@ class FuncDefNode(StatNode, BlockNode): # ----- Python version if self.py_func: self.py_func.generate_function_definitions(env, code) + self.generate_optarg_wrapper_function(env, code) def put_stararg_decrefs(self, code): pass @@ -782,6 +783,9 @@ class FuncDefNode(StatNode, BlockNode): for entry in env.arg_entries: code.put_var_incref(entry) + def generate_optarg_wrapper_function(self, env, code): + pass + def generate_execution_code(self, code): # Evaluate and store argument default values for arg in self.args: @@ -845,11 +849,7 @@ class CFuncDefNode(FuncDefNode): if self.overridable: import ExprNodes - arg_names = [arg.name for arg in self.type.args] - self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0]) - cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name) - c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True) - py_func_body = ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call) + py_func_body = self.call_self_node() self.py_func = DefNode(pos = self.pos, name = self.declarator.base.name, args = self.declarator.args, @@ -864,7 +864,17 @@ class CFuncDefNode(FuncDefNode): self.py_func.interned_attr_cname = env.intern(self.py_func.entry.name) self.override = OverrideCheckNode(self.pos, py_func = self.py_func) self.body = StatListNode(self.pos, stats=[self.override, self.body]) - + + def call_self_node(self, omit_optional_args=0): + import ExprNodes + args = self.type.args + if omit_optional_args: + args = args[:len(args) - self.type.optional_arg_count] + arg_names = [arg.name for arg in args] + self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0]) + cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name) + c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True) + return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call) def declare_arguments(self, env): for arg in self.type.args: @@ -886,20 +896,22 @@ class CFuncDefNode(FuncDefNode): if self.overridable: self.py_func.analyse_expressions(env) - def generate_function_header(self, code, with_pymethdef): + def generate_function_header(self, code, with_pymethdef, with_opt_args = 1): arg_decls = [] type = self.type visibility = self.entry.visibility for arg in type.args[:len(type.args)-type.optional_arg_count]: arg_decls.append(arg.declaration_code()) - if type.optional_arg_count: - arg_decls.append(type.op_args.declaration_code(Naming.optional_args_cname)) + if type.optional_arg_count and with_opt_args: + arg_decls.append(type.op_arg_struct.declaration_code(Naming.optional_args_cname)) if type.has_varargs: arg_decls.append("...") if not arg_decls: arg_decls = ["void"] - entity = type.function_header_code(self.entry.func_cname, - string.join(arg_decls, ", ")) + cname = self.entry.func_cname + if not with_opt_args: + cname += Naming.no_opt_args + entity = type.function_header_code(cname, string.join(arg_decls, ", ")) if visibility == 'public': dll_linkage = "DL_EXPORT" else: @@ -973,6 +985,19 @@ class CFuncDefNode(FuncDefNode): def caller_will_check_exceptions(self): return self.entry.type.exception_check + def generate_optarg_wrapper_function(self, env, code): + if self.type.optional_arg_count and \ + self.type.original_sig and not self.type.original_sig.optional_arg_count: + code.putln() + self.generate_function_header(code, 0, with_opt_args = 0) + if not self.return_type.is_void: + code.put('return ') + args = self.type.args + arglist = [arg.cname for arg in args[:len(args)-self.type.optional_arg_count]] + arglist.append('NULL') + code.putln('%s(%s);' % (self.entry.func_cname, ', '.join(arglist))) + code.putln('}') + class PyArgDeclNode(Node): # Argument which must be a Python object (used diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index ff3148cd..7ee44433 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -584,7 +584,7 @@ class CFuncType(CType): # with_gil boolean Acquire gil around function body is_cfunction = 1 - old_signature = None + original_sig = None def __init__(self, return_type, args, has_varargs = 0, exception_value = None, exception_check = 0, calling_convention = "", @@ -680,15 +680,9 @@ class CFuncType(CType): return 0 if not self.same_calling_convention_as(other_type): return 0 - self.old_signature = other_type + self.original_sig = other_type.original_sig or other_type if as_cmethod: self.args[0] = other_type.args[0] - if self.optional_arg_count and \ - self.optional_arg_count == other_type.optional_arg_count: - self.op_args = other_type.op_args - print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count - elif self.optional_arg_count: - print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count return 1 @@ -741,8 +735,7 @@ class CFuncType(CType): arg_decl_list.append( arg.type.declaration_code("", for_display, pyrex = pyrex)) if self.optional_arg_count: - arg_decl_list.append(self.op_args.declaration_code(Naming.optional_args_cname)) -# arg_decl_list.append(c_void_ptr_type.declaration_code(Naming.optional_args_cname)) + arg_decl_list.append(self.op_arg_struct.declaration_code(Naming.optional_args_cname)) if self.has_varargs: arg_decl_list.append("...") arg_decl_code = string.join(arg_decl_list, ", ") diff --git a/Cython/Compiler/Symtab.py b/Cython/Compiler/Symtab.py index 6b13e540..af770fc7 100644 --- a/Cython/Compiler/Symtab.py +++ b/Cython/Compiler/Symtab.py @@ -1273,14 +1273,26 @@ class CClassScope(ClassScope): if defining and entry.func_cname: error(pos, "'%s' already defined" % name) #print "CClassScope.declare_cfunction: checking signature" ### - if type.compatible_signature_with(entry.type, as_cmethod = 1): + if type.same_c_signature_as(entry.type, as_cmethod = 1): + pass + elif type.compatible_signature_with(entry.type, as_cmethod = 1): + if type.optional_arg_count and not type.original_sig.optional_arg_count: + # Need to put a wrapper taking no optional arguments + # into the method table. + wrapper_func_cname = self.mangle(Naming.func_prefix, name) + Naming.no_opt_args + wrapper_func_name = name + Naming.no_opt_args + if entry.type.optional_arg_count: + old_entry = self.lookup_here(wrapper_func_name) + old_entry.func_cname = wrapper_func_cname + else: + entry.func_cname = wrapper_func_cname + entry.name = wrapper_func_name + entry = self.add_cfunction(name, type, pos, cname or name, visibility) + defining = 1 entry.type = type - elif type.same_c_signature_as(entry.type, as_cmethod = 1): - print "not compatible", name # if type.narrower_c_signature_than(entry.type, as_cmethod = 1): # entry.type = type else: - print "here" error(pos, "Signature not compatible with previous declaration") else: if self.defined: -- 2.26.2