avoid unpacking for functions that only have star args
authorStefan Behnel <scoder@users.berlios.de>
Wed, 6 Feb 2008 19:59:38 +0000 (20:59 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 6 Feb 2008 19:59:38 +0000 (20:59 +0100)
Cython/Compiler/Nodes.py

index aaceed901fde9e3f4bc17d5aa00298f69a4e7fd2..7d62b61be9354205394a448801ab837f1bb85ab0 100644 (file)
@@ -663,7 +663,7 @@ class FuncDefNode(StatNode, BlockNode):
         if acquire_gil:
             code.putln("PyGILState_STATE _save = PyGILState_Ensure();")
         # ----- Fetch arguments
-        self.generate_argument_parsing_code(code)
+        self.generate_argument_parsing_code(env, code)
         self.generate_argument_increfs(lenv, code)
         # ----- Initialise local variables
         for entry in lenv.var_entries:
@@ -860,7 +860,7 @@ class CFuncDefNode(FuncDefNode):
     def generate_keyword_list(self, code):
         pass
         
-    def generate_argument_parsing_code(self, code):
+    def generate_argument_parsing_code(self, env, code):
         pass
     
     def generate_argument_conversion_code(self, code):
@@ -1044,6 +1044,19 @@ class DefNode(FuncDefNode):
             "%s %s has wrong number of arguments "
             "(%d declared, %s expected)" % (
                 desc, self.name, len(self.args), expected_str))
+
+    def signature_has_nongeneric_args(self):
+        has_generic_args = self.entry.signature.has_generic_args
+        argcount = len(self.args)
+        if argcount == 0:
+            return 0
+        elif argcount == 1:
+            if self.args[0].is_self_arg:
+                return 0
+        return 1
+
+    def signature_has_generic_args(self):
+        return self.entry.signature.has_generic_args
     
     def declare_pyfunction(self, env):
         #print "DefNode.declare_pyfunction:", self.name, "in", env ###
@@ -1174,7 +1187,8 @@ class DefNode(FuncDefNode):
                     code.put_var_declaration(arg.entry)
     
     def generate_keyword_list(self, code):
-        if self.entry.signature.has_generic_args:
+        if self.signature_has_generic_args() and \
+                self.signature_has_nongeneric_args():
             reqd_kw_flags = []
             has_reqd_kwds = False
             code.put(
@@ -1201,16 +1215,19 @@ class DefNode(FuncDefNode):
                         flags_name,
                         ",".join(reqd_kw_flags)))
 
-    def generate_argument_parsing_code(self, code):
+    def generate_argument_parsing_code(self, env, code):
         # Generate PyArg_ParseTuple call for generic
         # arguments, if any.
         has_kwonly_args = self.num_kwonly_args > 0
         has_star_or_kw_args = self.star_arg is not None \
             or self.starstar_arg is not None or has_kwonly_args
-        if not self.entry.signature.has_generic_args:
+        if not self.signature_has_generic_args():
             if has_star_or_kw_args:
                 error(self.pos, "This method cannot have * or keyword arguments")
             self.generate_argument_conversion_code(code)
+        elif not self.signature_has_nongeneric_args():
+            # func(*args) or func(**kw) or func(*args, **kw)
+            self.generate_stararg_copy_code(env, code)
         else:
             arg_addrs = []
             arg_formats = []
@@ -1353,6 +1370,38 @@ class DefNode(FuncDefNode):
         else:
             return 0
 
+    def generate_stararg_copy_code(self, env, code):
+        if not self.starstar_arg:
+            env.use_utility_code(get_keyword_error_utility_code)
+            code.putln("if (unlikely(%s) && unlikely(PyDict_Size(%s))) {" % (
+                    Naming.kwds_cname, Naming.kwds_cname))
+            code.putln("__Pyx_RaiseKeywordError(%s);" % Naming.kwds_cname)
+            code.putln("return %s;" % self.error_value())
+            code.putln("}")
+        if self.star_arg:
+            code.put_incref(Naming.args_cname, py_object_type)
+            code.putln("%s = %s; %s = 0;" % (
+                    self.star_arg.entry.cname,
+                    Naming.args_cname,
+                    Naming.args_cname))
+            self.star_arg.entry.xdecref_cleanup = 0
+            self.star_arg = None
+        else:
+            self.generate_positional_args_check(code, 0)
+        if self.starstar_arg:
+            code.putln("if (%s) {" % Naming.kwds_cname)
+            code.put_incref(Naming.kwds_cname, py_object_type)
+            code.putln("%s = %s; %s = 0;" % (
+                    self.starstar_arg.entry.cname,
+                    Naming.kwds_cname,
+                    Naming.kwds_cname))
+            code.putln("}")
+            code.putln("else {")
+            code.putln("%s = PyDict_New();" % self.starstar_arg.entry.cname)
+            code.putln("}")
+            self.starstar_arg.entry.xdecref_cleanup = 0
+            self.starstar_arg = None
+
     def generate_stararg_getting_code(self, code):
         num_kwonly = self.num_kwonly_args
         fixed_args = self.entry.signature.num_fixed_args()
@@ -1376,17 +1425,11 @@ class DefNode(FuncDefNode):
                     self.error_value()))
             code.putln("}")
             self.star_arg.entry.xdecref_cleanup = 0
-        elif self.entry.signature.has_generic_args:
+        elif self.signature_has_generic_args():
             # make sure supernumerous positional arguments do not run
             # into keyword-only arguments and provide a more helpful
             # message than PyArg_ParseTupelAndKeywords()
-            code.putln("if (unlikely(PyTuple_GET_SIZE(%s) > %d)) {" % (
-                    Naming.args_cname, nargs))
-            error_message = "function takes at most %d positional arguments (%d given)"
-            code.putln("PyErr_Format(PyExc_TypeError, \"%s\", %d, PyTuple_GET_SIZE(%s));" % (
-                    error_message, nargs, Naming.args_cname))
-            code.putln(error_return)
-            code.putln("}")
+            self.generate_positional_args_check(code, nargs)
 
         handle_error = 0
         if self.starstar_arg:
@@ -1415,6 +1458,15 @@ class DefNode(FuncDefNode):
             else:
                 code.putln(error_return)
 
+    def generate_positional_args_check(self, code, nargs):
+        code.putln("if (unlikely(PyTuple_GET_SIZE(%s) > %d)) {" % (
+                Naming.args_cname, nargs))
+        error_message = "function takes at most %d positional arguments (%d given)"
+        code.putln("PyErr_Format(PyExc_TypeError, \"%s\", %d, PyTuple_GET_SIZE(%s));" % (
+                error_message, nargs, Naming.args_cname))
+        code.putln("return %s;" % self.error_value())
+        code.putln("}")
+
     def generate_argument_conversion_code(self, code):
         # Generate code to convert arguments from
         # signature type to declared type, if needed.
@@ -3484,6 +3536,30 @@ static INLINE int __Pyx_SplitStarArg(
 }
 """]
 
+#------------------------------------------------------------------------------------
+#
+#  __Pyx_RaiseKeywordError raises an error that keywords were passed
+#  to a function that does not accept them.
+
+get_keyword_error_utility_code = [
+"""
+static void __Pyx_RaiseKeywordError(PyObject *kwdict); /*proto*/
+""","""
+static void __Pyx_RaiseKeywordError(PyObject *kwdict) {
+    PyObject* key = 0;
+    Py_ssize_t pos = 0;
+    PyDict_Next(kwdict, &pos, &key, 0);
+    if (!PyString_Check(key)) {
+        PyErr_SetString(PyExc_TypeError, "keywords must be strings");
+    }
+    else {
+        PyErr_Format(PyExc_TypeError,
+                     "'%s' is an invalid keyword argument for this function",
+                     PyString_AsString(key));
+    }
+}
+"""]
+
 #------------------------------------------------------------------------------------
 #
 #  __Pyx_SplitKeywords splits the kwds dict into two parts one part