Fix optional cdef arguments for c++, possible optimization when not all args are...
authorRobert Bradshaw <robertwb@math.washington.edu>
Sun, 3 Aug 2008 11:02:45 +0000 (04:02 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Sun, 3 Aug 2008 11:02:45 +0000 (04:02 -0700)
Cython/Compiler/ExprNodes.py
Demos/Setup.py
tests/run/cdefoptargs.pyx

index 75d3a93107dc08197f90c1f70a5867a53fc1e417..a37f9e751c852fc84202e03ec05e84384a9f6b9a 100644 (file)
@@ -1680,6 +1680,7 @@ class SimpleCallNode(CallNode):
     #  self           ExprNode or None     used internally
     #  coerced_self   ExprNode or None     used internally
     #  wrapper_call   bool                 used internally
+    #  has_optional_args   bool            used internally
     
     subexprs = ['self', 'coerced_self', 'function', 'args', 'arg_tuple']
     
@@ -1687,6 +1688,7 @@ class SimpleCallNode(CallNode):
     coerced_self = None
     arg_tuple = None
     wrapper_call = False
+    has_optional_args = False
     
     def compile_time_value(self, denv):
         function = self.function.compile_time_value(denv)
@@ -1773,6 +1775,11 @@ class SimpleCallNode(CallNode):
                 self.type = PyrexTypes.error_type
                 self.result_code = "<error>"
                 return
+        if func_type.optional_arg_count and expected_nargs != actual_nargs:
+            self.has_optional_args = 1
+            self.is_temp = 1
+            self.opt_arg_struct = env.allocate_temp(func_type.op_arg_struct.base_type)
+            env.release_temp(self.opt_arg_struct)
         # Coerce arguments
         for i in range(min(max_nargs, actual_nargs)):
             formal_type = func_type.args[i].type
@@ -1818,15 +1825,7 @@ class SimpleCallNode(CallNode):
             if expected_nargs == actual_nargs:
                 optional_args = 'NULL'
             else:
-                optional_arg_code = [str(actual_nargs - expected_nargs)]
-                for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
-                    arg_code = actual_arg.result_as(formal_arg.type)
-                    optional_arg_code.append(arg_code)
-#                for formal_arg in formal_args[actual_nargs:max_nargs]:
-#                    optional_arg_code.append(formal_arg.type.cast_code('0'))
-                optional_arg_struct = '{%s}' % ','.join(optional_arg_code)
-                optional_args = PyrexTypes.c_void_ptr_type.cast_code(
-                    '&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct))
+                optional_args = "&%s" % self.opt_arg_struct
             arg_list_code.append(optional_args)
             
         for actual_arg in self.args[len(formal_args):]:
@@ -1849,6 +1848,19 @@ class SimpleCallNode(CallNode):
                     arg_code,
                     code.error_goto_if_null(self.result_code, self.pos)))
         elif func_type.is_cfunction:
+            if self.has_optional_args:
+                actual_nargs = len(self.args)
+                expected_nargs = len(func_type.args) - func_type.optional_arg_count
+                code.putln("%s.%s = %s;" % (
+                        self.opt_arg_struct,
+                        Naming.pyrex_prefix + "n",
+                        len(self.args) - expected_nargs))
+                args = zip(func_type.args, self.args)
+                for formal_arg, actual_arg in args[expected_nargs:actual_nargs]:
+                    code.putln("%s.%s = %s;" % (
+                            self.opt_arg_struct,
+                            formal_arg.name,
+                            actual_arg.result_as(formal_arg.type)))
             exc_checks = []
             if self.type.is_pyobject:
                 exc_checks.append("!%s" % self.result_code)
@@ -1883,12 +1895,12 @@ class SimpleCallNode(CallNode):
                         rhs,
                         raise_py_exception,
                         code.error_goto(self.pos)))
-                    return
-                code.putln(
-                    "%s%s; %s" % (
-                        lhs,
-                        rhs,
-                        code.error_goto_if(" && ".join(exc_checks), self.pos)))    
+                else:
+                    if exc_checks:
+                        goto_error = code.error_goto_if(" && ".join(exc_checks), self.pos)
+                    else:
+                        goto_error = ""
+                    code.putln("%s%s; %s" % (lhs, rhs, goto_error))
 
 class GeneralCallNode(CallNode):
     #  General Python function call, including keyword,
index 8980cf0670ddf30c198483f297bd074026433c40..75b05af0fce37948d547cb1fb88ae7303ba7712c 100644 (file)
@@ -7,7 +7,6 @@ from Cython.Distutils import build_ext
 ext_modules=[ 
     Extension("primes",       ["primes.pyx"]),
     Extension("spam",         ["spam.pyx"]),
-#    Extension("optargs",      ["optargs.pyx"], language = "c++"),
 ]
 
 for file in glob.glob("*.pyx"):
index 366a18d38807eb2c1f3eae62a773b5853d7d9358..325da795252c8b89b0738d74b743486495fddf85 100644 (file)
@@ -2,6 +2,11 @@ __doc__ = u"""
     >>> call2()
     >>> call3()
     >>> call4()
+    >>> test_foo()
+    2
+    3
+    7
+    26
 """
 
 # the calls:
@@ -19,3 +24,13 @@ def call4():
 
 cdef b(a, b, c=1, d=2):
     pass
+
+
+cdef int foo(int a, int b=1, int c=1):
+    return a+b*c
+    
+def test_foo():
+    print foo(1)
+    print foo(1, 2)
+    print foo(1, 2, 3)
+    print foo(1, foo(2, 3), foo(4))