GetStarArgs: separate out keyword handling, generate specific code for functions...
authorStefan Behnel <scoder@users.berlios.de>
Tue, 15 Jan 2008 22:20:02 +0000 (23:20 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 15 Jan 2008 22:20:02 +0000 (23:20 +0100)
Cython/Compiler/Nodes.py

index 27a07fd509f4bbda0593be7c7aee908f10ccf28a..02fefce450de101425dd743a0604c88ad60c4825 100644 (file)
@@ -930,15 +930,19 @@ class DefNode(FuncDefNode):
     
     assmt = None
     num_kwonly_args = 0
+    num_required_kw_args = 0
     reqd_kw_flags_cname = "0"
     
     def __init__(self, pos, **kwds):
         FuncDefNode.__init__(self, pos, **kwds)
-        n = 0
+        n = r = 0
         for arg in self.args:
             if arg.kw_only:
                 n += 1
+                if not arg.default:
+                    r += 1
         self.num_kwonly_args = n
+        self.num_required_kw_args = r
     
     def analyse_declarations(self, env):
         for arg in self.args:
@@ -963,7 +967,8 @@ class DefNode(FuncDefNode):
         if self.star_arg or self.starstar_arg or self.num_kwonly_args > 0:
             env.use_utility_code(get_stararg_utility_code)
             env.use_utility_code(get_splitkeywords_utility_code)
-    
+            env.use_utility_code(get_checkkeywords_utility_code)
+
     def analyse_signature(self, env):
         any_type_tests_needed = 0
         # Use the simpler calling signature for zero- and one-argument functions.
@@ -1246,9 +1251,8 @@ class DefNode(FuncDefNode):
                     pt_argstring)
             if has_star_or_kw_args:
                 code.putln("{")
-                code.put_xdecref(Naming.args_cname, py_object_type)
-                code.put_xdecref(Naming.kwds_cname, py_object_type)
-                self.generate_arg_xdecref(self.star_arg, code)
+                self.put_stararg_decrefs(code)
+                self.generate_arg_decref(self.star_arg, code)
                 self.generate_arg_xdecref(self.starstar_arg, code)
                 code.putln(error_return_code)
                 code.putln("}")
@@ -1256,14 +1260,19 @@ class DefNode(FuncDefNode):
                 code.putln(error_return_code)
 
     def put_stararg_decrefs(self, code):
-        if self.star_arg or self.starstar_arg or self.num_kwonly_args > 0:
-            code.put_xdecref(Naming.args_cname, py_object_type)
+        if self.star_arg:
+            code.put_decref(Naming.args_cname, py_object_type)
+        if self.starstar_arg:
             code.put_xdecref(Naming.kwds_cname, py_object_type)
     
     def generate_arg_xdecref(self, arg, code):
         if arg:
             code.put_var_xdecref(arg.entry)
     
+    def generate_arg_decref(self, arg, code):
+        if arg:
+            code.put_var_decref(arg.entry)
+    
     def arg_address(self, arg):
         if arg:
             return "&%s" % arg.entry.cname
@@ -1271,9 +1280,11 @@ class DefNode(FuncDefNode):
             return 0
 
     def generate_stararg_getting_code(self, code):
+        error_return = "return %s;" % self.error_value()
         num_kwonly = self.num_kwonly_args
         fixed_args = self.entry.signature.num_fixed_args()
         nargs = len(self.args) - num_kwonly - fixed_args
+
         if self.star_arg:
             star_arg_addr = self.arg_address(self.star_arg)
             code.putln(
@@ -1290,17 +1301,34 @@ class DefNode(FuncDefNode):
             error_message = "function takes at most %d positional arguments (%d given)"
             code.putln("PyErr_Format(PyExc_TypeError, \"%s\", %d, PyTuple_GET_SIZE(%s));" % (
                     error_message, nargs, Naming.args_cname))
-            code.putln("return %s;" % self.error_value())
+            code.putln(error_return)
             code.putln("}")
-        if self.starstar_arg or num_kwonly:
-            starstar_arg_addr = self.arg_address(self.starstar_arg)
-            code.putln(
-                "if (unlikely(__Pyx_SplitKeywords(&%s, %s, %s, %s) < 0)) return %s;" % (
+
+        handle_error = 0
+        if self.starstar_arg:
+            handle_error = 1
+            code.put(
+                "if (unlikely(__Pyx_SplitKeywords(&%s, %s, %s, %s) < 0)) " % (
                     Naming.kwds_cname,
                     Naming.kwdlist_cname,
-                    starstar_arg_addr,
-                    self.reqd_kw_flags_cname,
-                    self.error_value()))
+                    self.arg_address(self.starstar_arg),
+                    self.reqd_kw_flags_cname))
+        elif self.num_required_kw_args:
+            handle_error = 1
+            code.put("if (unlikely(__Pyx_CheckRequiredKeywords(%s, %s, %s) < 0)) " % (
+                    Naming.kwds_cname,
+                    Naming.kwdlist_cname,
+                    self.reqd_kw_flags_cname))
+
+        if handle_error:
+            if self.star_arg:
+                code.putln("{")
+                code.put_decref(Naming.args_cname, py_object_type)
+                code.put_decref(self.star_arg.entry.cname, py_object_type)
+                code.putln(error_return)
+                code.putln("}")
+            else:
+                code.putln(error_return)
 
     def generate_argument_conversion_code(self, code):
         # Generate code to convert arguments from
@@ -3344,17 +3372,12 @@ static int __Pyx_ArgTypeTest(PyObject *obj, PyTypeObject *type, int none_allowed
 #  reference *args with references to a new tuple, and passes back a
 #  new reference in *args2.  Does not touch any of its arguments on
 #  failure.
-#
-#  If rqd_kwds is not 0, it is an array of booleans corresponding to the
-#  names in kwd_list, indicating required keyword arguments. If any of
-#  these are not present in kwds, an exception is raised.
-#
 
 get_stararg_utility_code = [
 """
-static int __Pyx_GetStarArg(PyObject **args, Py_ssize_t nargs, PyObject **args2); /*proto*/
+static INLINE int __Pyx_GetStarArg(PyObject **args, Py_ssize_t nargs, PyObject **args2); /*proto*/
 ""","""
-static int __Pyx_GetStarArg(
+static INLINE int __Pyx_GetStarArg(
     PyObject **args, 
     Py_ssize_t nargs,
     PyObject **args2)
@@ -3364,19 +3387,14 @@ static int __Pyx_GetStarArg(
     *args2 = 0;
     args1 = PyTuple_GetSlice(*args, 0, nargs);
     if (!args1)
-        goto bad;
+        return -1;
     *args2 = PyTuple_GetSlice(*args, nargs, PyTuple_GET_SIZE(*args));
-    if (!*args2)
-        goto bad;
-
+    if (!*args2) {
+        Py_DECREF(args1);
+        return -1;
+    }
     *args = args1;
     return 0;
-bad:
-    Py_XDECREF(args1);
-    if (args2) {
-        Py_XDECREF(*args2);
-    }
-    return -1;
 }
 """]
 
@@ -3413,41 +3431,27 @@ static int __Pyx_SplitKeywords(
     int i;
     char **p;
     
-    if (kwds2)
-        *kwds2 = 0;
-    
     if (*kwds) {
-        if (kwds2) {
-            kwds1 = PyDict_New();
-            if (!kwds1)
-                goto bad;
-            *kwds2 = PyDict_Copy(*kwds);
-            if (!*kwds2)
-                goto bad;
-            for (i = 0, p = kwd_list; *p; i++, p++) {
-                s = PyString_FromString(*p);
-                x = PyDict_GetItem(*kwds, s);
-                if (x) {
-                    if (PyDict_SetItem(kwds1, s, x) < 0)
-                        goto bad;
-                    if (PyDict_DelItem(*kwds2, s) < 0)
-                        goto bad;
-                }
-                else if (rqd_kwds && rqd_kwds[i])
-                    goto missing_kwarg;
-                Py_DECREF(s);
-            }
-            s = 0;
-        }
-        else {
-            kwds1 = *kwds;
-            Py_INCREF(kwds1);
-            if (rqd_kwds) {
-                for (i = 0, p = kwd_list; *p; i++, p++)
-                    if (rqd_kwds[i] && !PyDict_GetItemString(kwds1, *p))
-                        goto missing_kwarg;
+        kwds1 = PyDict_New();
+        if (!kwds1)
+            goto bad;
+        *kwds2 = PyDict_Copy(*kwds);
+        if (!*kwds2)
+            goto bad;
+        for (i = 0, p = kwd_list; *p; i++, p++) {
+            s = PyString_FromString(*p);
+            x = PyDict_GetItem(*kwds, s);
+            if (x) {
+                if (PyDict_SetItem(kwds1, s, x) < 0)
+                    goto bad;
+                if (PyDict_DelItem(*kwds2, s) < 0)
+                    goto bad;
             }
+            else if (rqd_kwds && rqd_kwds[i])
+                goto missing_kwarg;
+            Py_DECREF(s);
         }
+        s = 0;
     }
     else {
         if (rqd_kwds) {
@@ -3455,11 +3459,9 @@ static int __Pyx_SplitKeywords(
                 if (rqd_kwds[i])
                     goto missing_kwarg;
         }
-        if (kwds2) {
-            *kwds2 = PyDict_New();
-            if (!*kwds2)
-                goto bad;
-        }
+        *kwds2 = PyDict_New();
+        if (!*kwds2)
+            goto bad;
     }
 
     *kwds = kwds1;
@@ -3470,9 +3472,39 @@ missing_kwarg:
 bad:
     Py_XDECREF(s);
     Py_XDECREF(kwds1);
-    if (kwds2) {
-        Py_XDECREF(*kwds2);
+    Py_XDECREF(*kwds2);
+    return -1;
+}
+"""]
+
+get_checkkeywords_utility_code = [
+"""
+static INLINE int __Pyx_CheckRequiredKeywords(PyObject *kwds, char *kwd_list[],
+    char rqd_kwds[]); /*proto*/
+""","""
+static INLINE int __Pyx_CheckRequiredKeywords(
+    PyObject *kwds,
+    char *kwd_list[],
+    char rqd_kwds[])
+{
+    int i;
+    char **p;
+
+    if (kwds) {
+        for (i = 0, p = kwd_list; *p; i++, p++)
+            if (rqd_kwds[i] && !PyDict_GetItemString(kwds, *p))
+                goto missing_kwarg;
+    }
+    else {
+        for (i = 0, p = kwd_list; *p; i++, p++)
+            if (rqd_kwds[i])
+                goto missing_kwarg;
     }
+
+    return 0;
+missing_kwarg:
+    PyErr_Format(PyExc_TypeError,
+        "required keyword argument '%s' is missing", *p);
     return -1;
 }
 """]