From 331758a489379cb12bcf067fe236b35c59848998 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Sun, 3 Aug 2008 04:02:45 -0700 Subject: [PATCH] Fix optional cdef arguments for c++, possible optimization when not all args are used. --- Cython/Compiler/ExprNodes.py | 42 +++++++++++++++++++++++------------- Demos/Setup.py | 1 - tests/run/cdefoptargs.pyx | 15 +++++++++++++ 3 files changed, 42 insertions(+), 16 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 75d3a931..a37f9e75 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -1680,6 +1680,7 @@ class SimpleCallNode(CallNode): # self ExprNode or None used internally # coerced_self ExprNode or None used internally # wrapper_call bool used internally + # has_optional_args bool used internally subexprs = ['self', 'coerced_self', 'function', 'args', 'arg_tuple'] @@ -1687,6 +1688,7 @@ class SimpleCallNode(CallNode): coerced_self = None arg_tuple = None wrapper_call = False + has_optional_args = False def compile_time_value(self, denv): function = self.function.compile_time_value(denv) @@ -1773,6 +1775,11 @@ class SimpleCallNode(CallNode): self.type = PyrexTypes.error_type self.result_code = "" return + if func_type.optional_arg_count and expected_nargs != actual_nargs: + self.has_optional_args = 1 + self.is_temp = 1 + self.opt_arg_struct = env.allocate_temp(func_type.op_arg_struct.base_type) + env.release_temp(self.opt_arg_struct) # Coerce arguments for i in range(min(max_nargs, actual_nargs)): formal_type = func_type.args[i].type @@ -1818,15 +1825,7 @@ class SimpleCallNode(CallNode): if expected_nargs == actual_nargs: optional_args = 'NULL' else: - optional_arg_code = [str(actual_nargs - expected_nargs)] - for formal_arg, actual_arg in args[expected_nargs:actual_nargs]: - arg_code = actual_arg.result_as(formal_arg.type) - optional_arg_code.append(arg_code) -# for formal_arg in formal_args[actual_nargs:max_nargs]: -# optional_arg_code.append(formal_arg.type.cast_code('0')) - optional_arg_struct = '{%s}' % ','.join(optional_arg_code) - optional_args = PyrexTypes.c_void_ptr_type.cast_code( - '&' + func_type.op_arg_struct.base_type.cast_code(optional_arg_struct)) + optional_args = "&%s" % self.opt_arg_struct arg_list_code.append(optional_args) for actual_arg in self.args[len(formal_args):]: @@ -1849,6 +1848,19 @@ class SimpleCallNode(CallNode): arg_code, code.error_goto_if_null(self.result_code, self.pos))) elif func_type.is_cfunction: + if self.has_optional_args: + actual_nargs = len(self.args) + expected_nargs = len(func_type.args) - func_type.optional_arg_count + code.putln("%s.%s = %s;" % ( + self.opt_arg_struct, + Naming.pyrex_prefix + "n", + len(self.args) - expected_nargs)) + args = zip(func_type.args, self.args) + for formal_arg, actual_arg in args[expected_nargs:actual_nargs]: + code.putln("%s.%s = %s;" % ( + self.opt_arg_struct, + formal_arg.name, + actual_arg.result_as(formal_arg.type))) exc_checks = [] if self.type.is_pyobject: exc_checks.append("!%s" % self.result_code) @@ -1883,12 +1895,12 @@ class SimpleCallNode(CallNode): rhs, raise_py_exception, code.error_goto(self.pos))) - return - code.putln( - "%s%s; %s" % ( - lhs, - rhs, - code.error_goto_if(" && ".join(exc_checks), self.pos))) + else: + if exc_checks: + goto_error = code.error_goto_if(" && ".join(exc_checks), self.pos) + else: + goto_error = "" + code.putln("%s%s; %s" % (lhs, rhs, goto_error)) class GeneralCallNode(CallNode): # General Python function call, including keyword, diff --git a/Demos/Setup.py b/Demos/Setup.py index 8980cf06..75b05af0 100644 --- a/Demos/Setup.py +++ b/Demos/Setup.py @@ -7,7 +7,6 @@ from Cython.Distutils import build_ext ext_modules=[ Extension("primes", ["primes.pyx"]), Extension("spam", ["spam.pyx"]), -# Extension("optargs", ["optargs.pyx"], language = "c++"), ] for file in glob.glob("*.pyx"): diff --git a/tests/run/cdefoptargs.pyx b/tests/run/cdefoptargs.pyx index 366a18d3..325da795 100644 --- a/tests/run/cdefoptargs.pyx +++ b/tests/run/cdefoptargs.pyx @@ -2,6 +2,11 @@ __doc__ = u""" >>> call2() >>> call3() >>> call4() + >>> test_foo() + 2 + 3 + 7 + 26 """ # the calls: @@ -19,3 +24,13 @@ def call4(): cdef b(a, b, c=1, d=2): pass + + +cdef int foo(int a, int b=1, int c=1): + return a+b*c + +def test_foo(): + print foo(1) + print foo(1, 2) + print foo(1, 2, 3) + print foo(1, foo(2, 3), foo(4)) -- 2.26.2