Add optional args to any cdef overridden function
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 7 Feb 2008 23:59:58 +0000 (15:59 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 7 Feb 2008 23:59:58 +0000 (15:59 -0800)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index 83f57e4747170badda1e7bc8ada8565c9f39e872..4219f400cbf647b741e532c8bf6c498c13888e8e 100644 (file)
@@ -1567,13 +1567,10 @@ class SimpleCallNode(ExprNode):
         for formal_arg, actual_arg in args[:expected_nargs]:
                 arg_code = actual_arg.result_as(formal_arg.type)
                 arg_list_code.append(arg_code)
+                
         if func_type.optional_arg_count:
             if expected_nargs == actual_nargs:
-                if func_type.old_signature:
-                    struct_type = func_type.old_signature.op_args
-                else:
-                    struct_type = func_type.op_args
-                optional_args = struct_type.cast_code('NULL')
+                optional_args = 'NULL'
             else:
                 optional_arg_code = [str(actual_nargs - expected_nargs)]
                 for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
@@ -1582,11 +1579,10 @@ class SimpleCallNode(ExprNode):
 #                for formal_arg in formal_args[actual_nargs:max_nargs]:
 #                    optional_arg_code.append(formal_arg.type.cast_code('0'))
                 optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
-                optional_args = '&' + func_type.op_args.base_type.cast_code(optional_arg_struct)
-                if func_type.old_signature and \
-                        func_type.old_signature.op_args != func_type.op_args:
-                        optional_args = func_type.old_signature.op_args.cast_code(optional_args)
+                optional_args = PyrexTypes.c_void_ptr_type.cast_code(
+                    '&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct))
             arg_list_code.append(optional_args)
+            
         for actual_arg in self.args[len(formal_args):]:
             arg_list_code.append(actual_arg.result_code)
         result = "%s(%s)" % (self.function.result_code,
index 64665684d4972127d5589d598539a26e85bc7a58..165cb297d2d2eae546729a73245ee5004a0dc6b7 100644 (file)
@@ -62,6 +62,7 @@ skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
 empty_tuple      = pyrex_prefix + "empty_tuple"
 cleanup_cname    = pyrex_prefix + "module_cleanup"
 optional_args_cname = pyrex_prefix + "optional_args"
+no_opt_args      = pyrex_prefix + "no_opt_args"
 
 
 extern_c_macro  = pyrex_prefix.upper() + "EXTERN_C"
index fcc6acf2a680607d3a33a72ad039fb77ae559b5d..9181a23d6f765398a1f2258f9b1ced9805cafbad 100644 (file)
@@ -384,7 +384,7 @@ class CFuncDeclaratorNode(CDeclaratorNode):
             calling_convention = self.base.calling_convention,
             nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
         if self.optional_arg_count:
-            func_type.op_args = PyrexTypes.c_ptr_type(self.op_args_struct.type)
+            func_type.op_arg_struct = PyrexTypes.c_ptr_type(self.op_args_struct.type)
         return self.base.analyse(func_type, env)
 
 
@@ -763,6 +763,7 @@ class FuncDefNode(StatNode, BlockNode):
         # ----- Python version
         if self.py_func:
             self.py_func.generate_function_definitions(env, code)
+        self.generate_optarg_wrapper_function(env, code)
 
     def put_stararg_decrefs(self, code):
         pass
@@ -782,6 +783,9 @@ class FuncDefNode(StatNode, BlockNode):
         for entry in env.arg_entries:
             code.put_var_incref(entry)
 
+    def generate_optarg_wrapper_function(self, env, code):
+        pass
+
     def generate_execution_code(self, code):
         # Evaluate and store argument default values
         for arg in self.args:
@@ -845,11 +849,7 @@ class CFuncDefNode(FuncDefNode):
         
         if self.overridable:
             import ExprNodes
-            arg_names = [arg.name for arg in self.type.args]
-            self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
-            cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
-            c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
-            py_func_body = ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
+            py_func_body = self.call_self_node()
             self.py_func = DefNode(pos = self.pos, 
                                    name = self.declarator.base.name,
                                    args = self.declarator.args,
@@ -864,7 +864,17 @@ class CFuncDefNode(FuncDefNode):
                 self.py_func.interned_attr_cname = env.intern(self.py_func.entry.name)
             self.override = OverrideCheckNode(self.pos, py_func = self.py_func)
             self.body = StatListNode(self.pos, stats=[self.override, self.body])
-            
+    
+    def call_self_node(self, omit_optional_args=0):
+        import ExprNodes
+        args = self.type.args
+        if omit_optional_args:
+            args = args[:len(args) - self.type.optional_arg_count]
+        arg_names = [arg.name for arg in args]
+        self_arg = ExprNodes.NameNode(self.pos, name=arg_names[0])
+        cfunc = ExprNodes.AttributeNode(self.pos, obj=self_arg, attribute=self.declarator.base.name)
+        c_call = ExprNodes.SimpleCallNode(self.pos, function=cfunc, args=[ExprNodes.NameNode(self.pos, name=n) for n in arg_names[1:]], wrapper_call=True)
+        return ReturnStatNode(pos=self.pos, return_type=PyrexTypes.py_object_type, value=c_call)
                     
     def declare_arguments(self, env):
         for arg in self.type.args:
@@ -886,20 +896,22 @@ class CFuncDefNode(FuncDefNode):
         if self.overridable:
             self.py_func.analyse_expressions(env)
 
-    def generate_function_header(self, code, with_pymethdef):
+    def generate_function_header(self, code, with_pymethdef, with_opt_args = 1):
         arg_decls = []
         type = self.type
         visibility = self.entry.visibility
         for arg in type.args[:len(type.args)-type.optional_arg_count]:
             arg_decls.append(arg.declaration_code())
-        if type.optional_arg_count:
-            arg_decls.append(type.op_args.declaration_code(Naming.optional_args_cname))
+        if type.optional_arg_count and with_opt_args:
+            arg_decls.append(type.op_arg_struct.declaration_code(Naming.optional_args_cname))
         if type.has_varargs:
             arg_decls.append("...")
         if not arg_decls:
             arg_decls = ["void"]
-        entity = type.function_header_code(self.entry.func_cname,
-            string.join(arg_decls, ", "))
+        cname = self.entry.func_cname
+        if not with_opt_args:
+            cname += Naming.no_opt_args
+        entity = type.function_header_code(cname, string.join(arg_decls, ", "))
         if visibility == 'public':
             dll_linkage = "DL_EXPORT"
         else:
@@ -973,6 +985,19 @@ class CFuncDefNode(FuncDefNode):
     def caller_will_check_exceptions(self):
         return self.entry.type.exception_check
                     
+    def generate_optarg_wrapper_function(self, env, code):
+        if self.type.optional_arg_count and \
+                self.type.original_sig and not self.type.original_sig.optional_arg_count:
+            code.putln()
+            self.generate_function_header(code, 0, with_opt_args = 0)
+            if not self.return_type.is_void:
+                code.put('return ')
+            args = self.type.args
+            arglist = [arg.cname for arg in args[:len(args)-self.type.optional_arg_count]]
+            arglist.append('NULL')
+            code.putln('%s(%s);' % (self.entry.func_cname, ', '.join(arglist)))
+            code.putln('}')
+
 
 class PyArgDeclNode(Node):
     # Argument which must be a Python object (used
index ff3148cd27548c42ae97654cc71f9358e0cc570f..7ee44433fd300122f877a97480f740090db88eb0 100644 (file)
@@ -584,7 +584,7 @@ class CFuncType(CType):
     #  with_gil         boolean    Acquire gil around function body
     
     is_cfunction = 1
-    old_signature = None
+    original_sig = None
     
     def __init__(self, return_type, args, has_varargs = 0,
             exception_value = None, exception_check = 0, calling_convention = "",
@@ -680,15 +680,9 @@ class CFuncType(CType):
             return 0
         if not self.same_calling_convention_as(other_type):
             return 0
-        self.old_signature = other_type
+        self.original_sig = other_type.original_sig or other_type
         if as_cmethod:
             self.args[0] = other_type.args[0]
-        if self.optional_arg_count and \
-                self.optional_arg_count == other_type.optional_arg_count:
-            self.op_args = other_type.op_args
-            print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count
-        elif self.optional_arg_count:
-            print self.op_args, other_type.op_args, self.optional_arg_count, other_type.optional_arg_count
         return 1
         
         
@@ -741,8 +735,7 @@ class CFuncType(CType):
             arg_decl_list.append(
                 arg.type.declaration_code("", for_display, pyrex = pyrex))
         if self.optional_arg_count:
-            arg_decl_list.append(self.op_args.declaration_code(Naming.optional_args_cname))
-#            arg_decl_list.append(c_void_ptr_type.declaration_code(Naming.optional_args_cname))
+            arg_decl_list.append(self.op_arg_struct.declaration_code(Naming.optional_args_cname))
         if self.has_varargs:
             arg_decl_list.append("...")
         arg_decl_code = string.join(arg_decl_list, ", ")
index 6b13e5407fe73501500a2618d25e4f70e0e51ff8..af770fc77b4927649581b08325573180cbc0e790 100644 (file)
@@ -1273,14 +1273,26 @@ class CClassScope(ClassScope):
                 if defining and entry.func_cname:
                     error(pos, "'%s' already defined" % name)
                 #print "CClassScope.declare_cfunction: checking signature" ###
-                if type.compatible_signature_with(entry.type, as_cmethod = 1):
+                if type.same_c_signature_as(entry.type, as_cmethod = 1):
+                    pass
+                elif type.compatible_signature_with(entry.type, as_cmethod = 1):
+                    if type.optional_arg_count and not type.original_sig.optional_arg_count:
+                        # Need to put a wrapper taking no optional arguments 
+                        # into the method table.
+                        wrapper_func_cname = self.mangle(Naming.func_prefix, name) + Naming.no_opt_args
+                        wrapper_func_name = name + Naming.no_opt_args
+                        if entry.type.optional_arg_count:
+                            old_entry = self.lookup_here(wrapper_func_name)
+                            old_entry.func_cname = wrapper_func_cname
+                        else:
+                            entry.func_cname = wrapper_func_cname
+                            entry.name = wrapper_func_name
+                            entry = self.add_cfunction(name, type, pos, cname or name, visibility)
+                            defining = 1
                     entry.type = type
-                elif type.same_c_signature_as(entry.type, as_cmethod = 1):
-                    print "not compatible", name
 #                if type.narrower_c_signature_than(entry.type, as_cmethod = 1):
 #                    entry.type = type
                 else:
-                    print "here"
                     error(pos, "Signature not compatible with previous declaration")
         else:
             if self.defined: