Default cdef args via struct
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 7 Feb 2008 09:46:57 +0000 (01:46 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 7 Feb 2008 09:46:57 +0000 (01:46 -0800)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Naming.py
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Symtab.py

index 15aedfdf5f8003da1b83fb1c77e9511db9ff59e7..74a1cb919f6666c94ba9e67ced9663945a27a0a5 100644 (file)
@@ -1560,14 +1560,25 @@ class SimpleCallNode(ExprNode):
             return "<error>"
         formal_args = func_type.args
         arg_list_code = []
-        for (formal_arg, actual_arg) in \
-            zip(formal_args, self.args):
+        args = zip(formal_args, self.args)
+        max_nargs = len(func_type.args)
+        expected_nargs = max_nargs - func_type.optional_arg_count
+        actual_nargs = len(self.args)
+        for formal_arg, actual_arg in args[:expected_nargs]:
                 arg_code = actual_arg.result_as(formal_arg.type)
                 arg_list_code.append(arg_code)
         if func_type.optional_arg_count:
-            for formal_arg in formal_args[len(self.args):]:
-                arg_list_code.append(formal_arg.type.cast_code('0'))
-            arg_list_code.append(str(max(0, len(formal_args) - len(self.args))))
+            if expected_nargs == actual_nargs:
+                arg_list_code.append(func_type.op_args.cast_code('NULL'))
+            else:
+                optional_arg_code = [str(actual_nargs - expected_nargs)]
+                for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
+                    arg_code = actual_arg.result_as(formal_arg.type)
+                    optional_arg_code.append(arg_code)
+#                for formal_arg in formal_args[actual_nargs:max_nargs]:
+#                    optional_arg_code.append(formal_arg.type.cast_code('0'))
+                optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
+                arg_list_code.append('&' + func_type.op_args.base_type.cast_code(optional_arg_struct))
         for actual_arg in self.args[len(formal_args):]:
             arg_list_code.append(actual_arg.result_code)
         result = "%s(%s)" % (self.function.result_code,
index 59ab7f1c9040587ed35a86a796583f6a1f52a839..64665684d4972127d5589d598539a26e85bc7a58 100644 (file)
@@ -32,6 +32,7 @@ var_prefix        = pyrex_prefix + "v_"
 vtable_prefix     = pyrex_prefix + "vtable_"
 vtabptr_prefix    = pyrex_prefix + "vtabptr_"
 vtabstruct_prefix = pyrex_prefix + "vtabstruct_"
+opt_arg_prefix    = pyrex_prefix + "opt_args_"
 
 args_cname       = pyrex_prefix + "args"
 kwdlist_cname    = pyrex_prefix + "argnames"
@@ -60,7 +61,7 @@ gilstate_cname   = pyrex_prefix + "state"
 skip_dispatch_cname = pyrex_prefix + "skip_dispatch"
 empty_tuple      = pyrex_prefix + "empty_tuple"
 cleanup_cname    = pyrex_prefix + "module_cleanup"
-optional_count_cname = pyrex_prefix + "optional_arg_count"
+optional_args_cname = pyrex_prefix + "optional_args"
 
 
 extern_c_macro  = pyrex_prefix.upper() + "EXTERN_C"
index ae7b89949581ffa2c660eebcde63eb58918520ac..6f6ad2614818f31429e9d215bbcff050061774c0 100644 (file)
@@ -339,6 +339,21 @@ class CFuncDeclaratorNode(CDeclaratorNode):
                 PyrexTypes.CFuncTypeArg(name, type, arg_node.pos))
             if arg_node.default:
                 self.optional_arg_count += 1
+        
+        if self.optional_arg_count:
+            scope = StructOrUnionScope()
+            scope.declare_var('n', PyrexTypes.c_int_type, self.pos)
+            for arg in func_type_args[len(func_type_args)-self.optional_arg_count:]:
+                scope.declare_var(arg.name, arg.type, arg.pos, allow_pyobject = 1)
+            struct_cname = Naming.opt_arg_prefix + self.base.name
+            self.op_args_struct = env.global_scope().declare_struct_or_union(name = struct_cname,
+                                        kind = 'struct',
+                                        scope = scope,
+                                        typedef_flag = 0,
+                                        pos = self.pos,
+                                        cname = struct_cname)
+            self.op_args_struct.used = 1
+        
         exc_val = None
         exc_check = 0
         if return_type.is_pyobject \
@@ -368,6 +383,8 @@ class CFuncDeclaratorNode(CDeclaratorNode):
             exception_value = exc_val, exception_check = exc_check,
             calling_convention = self.base.calling_convention,
             nogil = self.nogil, with_gil = self.with_gil, is_overridable = self.overridable)
+        if self.optional_arg_count:
+            func_type.op_args = PyrexTypes.c_ptr_type(self.op_args_struct.type)
         return self.base.analyse(func_type, env)
 
 
@@ -383,7 +400,8 @@ class CArgDeclNode(Node):
     # is_kw_only     boolean            Is a keyword-only argument
 
     is_self_arg = 0
-    
+    is_generic = 1
+
     def analyse(self, env):
         #print "CArgDeclNode.analyse: is_self_arg =", self.is_self_arg ###
         base_type = self.base_type.analyse(env)
@@ -813,6 +831,9 @@ class CFuncDefNode(FuncDefNode):
         # from the base type of an extension type.
         self.type = type
         type.is_overridable = self.overridable
+        for formal_arg, type_arg in zip(self.declarator.args, type.args):
+            formal_arg.type = type_arg.type
+            formal_arg.cname = type_arg.cname
         name = name_declarator.name
         cname = name_declarator.cname
         self.entry = env.declare_cfunction(
@@ -821,7 +842,7 @@ class CFuncDefNode(FuncDefNode):
             defining = self.body is not None,
             api = self.api)
         self.return_type = type.return_type
-
+        
         if self.overridable:
             import ExprNodes
             arg_names = [arg.name for arg in self.type.args]
@@ -869,16 +890,16 @@ class CFuncDefNode(FuncDefNode):
         arg_decls = []
         type = self.type
         visibility = self.entry.visibility
-        for arg in type.args:
+        for arg in type.args[:len(type.args)-type.optional_arg_count]:
             arg_decls.append(arg.declaration_code())
         if type.optional_arg_count:
-            arg_decls.append("int %s" % Naming.optional_count_cname)
+            arg_decls.append(type.op_args.declaration_code(Naming.optional_args_cname))
         if type.has_varargs:
             arg_decls.append("...")
         if not arg_decls:
             arg_decls = ["void"]
         entity = type.function_header_code(self.entry.func_cname,
-            string.join(arg_decls, ","))
+            string.join(arg_decls, ", "))
         if visibility == 'public':
             dll_linkage = "DL_EXPORT"
         else:
@@ -895,24 +916,26 @@ class CFuncDefNode(FuncDefNode):
             header))
 
     def generate_argument_declarations(self, env, code):
-        # Arguments already declared in function header
-        pass
-    
+        for arg in self.declarator.args:
+            if arg.default:
+                code.putln('%s = %s;' % (arg.type.declaration_code(arg.cname), arg.default_entry.cname))
+
     def generate_keyword_list(self, code):
         pass
         
     def generate_argument_parsing_code(self, code):
-        rev_args = zip(self.declarator.args, self.type.args)
-        rev_args.reverse()
+        rev_args = self.declarator.args
         i = 0
-        for darg, targ in rev_args:
-            if darg.default:
-                code.putln('if (%s > %s) {' % (Naming.optional_count_cname, i))
-                code.putln('%s = %s;' % (targ.cname, darg.default_entry.cname))
-                i += 1
-        for _ in range(i):
+        if self.type.optional_arg_count:
+            code.putln('if (%s) {' % Naming.optional_args_cname)
+            for arg in rev_args:
+                if arg.default:
+                    code.putln('if (%s->n > %s) {' % (Naming.optional_args_cname, i))
+                    code.putln('%s = %s->%s;' % (arg.cname, Naming.optional_args_cname, arg.declarator.name))
+                    i += 1
+            for _ in range(self.type.optional_arg_count):
+                code.putln('}')
             code.putln('}')
-        code.putln('/* defaults */')
     
     def generate_argument_conversion_code(self, code):
         pass
index 444e89df3fcfca5bd684a8bf1a68fcfc2cdf35a6..9928dad771f42bf60d3680b9220bc8c9e2aa6c7d 100644 (file)
@@ -693,14 +693,14 @@ class CFuncType(CType):
     def declaration_code(self, entity_code, 
             for_display = 0, dll_linkage = None, pyrex = 0):
         arg_decl_list = []
-        for arg in self.args:
+        for arg in self.args[:len(self.args)-self.optional_arg_count]:
             arg_decl_list.append(
                 arg.type.declaration_code("", for_display, pyrex = pyrex))
         if self.optional_arg_count:
-            arg_decl_list.append("int %s" % Naming.optional_count_cname)
+            arg_decl_list.append(self.op_args.declaration_code(Naming.optional_args_cname))
         if self.has_varargs:
             arg_decl_list.append("...")
-        arg_decl_code = string.join(arg_decl_list, ",")
+        arg_decl_code = string.join(arg_decl_list, ", ")
         if not arg_decl_code and not pyrex:
             arg_decl_code = "void"
         exc_clause = ""
index 2324136f5c91715f53e8f0de16c352bfe02bebab..aeeee9772073720e43104e4b47e1add80cf7dd7c 100644 (file)
@@ -1058,14 +1058,14 @@ class StructOrUnionScope(Scope):
         Scope.__init__(self, "?", None, None)
 
     def declare_var(self, name, type, pos, 
-            cname = None, visibility = 'private', is_cdef = 0):
+            cname = None, visibility = 'private', is_cdef = 0, allow_pyobject = 0):
         # Add an entry for an attribute.
         if not cname:
             cname = name
         entry = self.declare(name, cname, type, pos)
         entry.is_variable = 1
         self.var_entries.append(entry)
-        if type.is_pyobject:
+        if type.is_pyobject and not allow_pyobject:
             error(pos,
                 "C struct/union member cannot be a Python object")
         if visibility <> 'private':