inlined *args splitting code
author Stefan Behnel <scoder@users.berlios.de>
Mon, 25 Aug 2008 06:16:32 +0000 (08:16 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Mon, 25 Aug 2008 06:16:32 +0000 (08:16 +0200)
removed some redundancy from arg parsing helper functions (and some of the helper functions)
general cleanup and performance improvements (now a lot faster for common kw passing cases)

Cython/Compiler/Nodes.py

index 3a2d0b643125ef0e8c221fb08ce0d2871ab407ec..5e128d4fb0a533aec1d3cb31cf4cdedbf890313a 100644 (file)
@@ -1339,9 +1339,6 @@ class DefNode(FuncDefNode):
         self.declare_pyfunction(env)
         self.analyse_signature(env)
         self.return_type = self.entry.signature.return_type()
-        env.use_utility_code(raise_keyword_required_utility_code)
-        if self.num_required_kw_args:
-            env.use_utility_code(check_required_keywords_utility_code)
 
     def analyse_signature(self, env):
         any_type_tests_needed = 0
@@ -1649,10 +1646,7 @@ class DefNode(FuncDefNode):
                 error(arg.pos, "Cannot convert Python object argument to type '%s'" % arg.type)
 
     def put_stararg_decrefs(self, code):
-        if self.star_arg:
-            code.put_decref(Naming.args_cname, py_object_type)
-        if self.starstar_arg:
-            code.put_xdecref(Naming.kwds_cname, py_object_type)
+        pass
     
     def generate_arg_xdecref(self, arg, code):
         if arg:
@@ -1687,7 +1681,6 @@ class DefNode(FuncDefNode):
             code.putln("if (unlikely(!%s)) return %s;" % (
                     self.starstar_arg.entry.cname, self.error_value()))
             self.starstar_arg.entry.xdecref_cleanup = 0
-            self.starstar_arg = None
 
         if self.star_arg:
             code.put_incref(Naming.args_cname, py_object_type)
@@ -1695,48 +1688,6 @@ class DefNode(FuncDefNode):
                     self.star_arg.entry.cname,
                     Naming.args_cname))
             self.star_arg.entry.xdecref_cleanup = 0
-            self.star_arg = None
-
-    def generate_stararg_getting_code(self, max_positional_args, code):
-        if self.star_arg:
-            code.globalstate.use_utility_code(split_stararg_utility_code)
-            star_arg_cname = self.star_arg.entry.cname
-            code.putln("if (likely(PyTuple_GET_SIZE(%s) <= %d)) {" % (
-                    Naming.args_cname, max_positional_args))
-            code.put_incref(Naming.args_cname, py_object_type)
-            code.put("%s = %s; " % (star_arg_cname, Naming.empty_tuple))
-            code.put_incref(Naming.empty_tuple, py_object_type)
-            code.putln("} else {")
-            code.putln(
-                "if (unlikely(__Pyx_SplitStarArg(&%s, %d, &%s) < 0)) return %s;" % (
-                    Naming.args_cname,
-                    max_positional_args,
-                    star_arg_cname,
-                    self.error_value()))
-            code.putln("}")
-            self.star_arg.entry.xdecref_cleanup = 0
-
-        if self.starstar_arg:
-            code.globalstate.use_utility_code(split_keywords_utility_code)
-            handle_error = 1
-            code.put(
-                'if (unlikely(__Pyx_SplitKeywords(&%s, %s, &%s, %s, PyTuple_GET_SIZE(%s), "%s") < 0)) ' % (
-                    Naming.kwds_cname,
-                    Naming.pykwdlist_cname,
-                    self.starstar_arg.entry.cname,
-                    self.reqd_kw_flags_cname,
-                    Naming.args_cname,
-                    self.name.utf8encode()))
-            self.starstar_arg.entry.xdecref_cleanup = 0
-            error_return = "return %s;" % self.error_value()
-            if self.star_arg:
-                code.putln("{")
-                code.put_decref(Naming.args_cname, py_object_type)
-                code.put_decref(self.star_arg.entry.cname, py_object_type)
-                code.putln(error_return)
-                code.putln("}")
-            else:
-                code.putln(error_return)
 
     def generate_tuple_and_keyword_parsing_code(self, positional_args,
                                                 kw_only_args, code):
@@ -1750,35 +1701,65 @@ class DefNode(FuncDefNode):
         has_fixed_positional_count = not self.star_arg and \
             min_positional_args == max_positional_args
 
-        if self.star_arg or self.starstar_arg:
-            if not self.star_arg:
-                # need to check the tuple before checking the keywords...
-                self.generate_positional_args_check(
-                    max_positional_args, has_fixed_positional_count, code)
-            self.generate_stararg_getting_code(max_positional_args, code)
-
         code.globalstate.use_utility_code(raise_double_keywords_utility_code)
-        code.globalstate.use_utility_code(keyword_check_utility_code)
+        code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
+        if self.num_required_kw_args:
+            code.globalstate.use_utility_code(raise_keyword_required_utility_code)
+
+        if self.starstar_arg:
+            self.starstar_arg.entry.xdecref_cleanup = 0
+            code.putln('%s = PyDict_New(); if (unlikely(!%s)) return %s;' % (
+                    self.starstar_arg.entry.cname,
+                    self.starstar_arg.entry.cname,
+                    self.error_value()))
+        if self.star_arg:
+            self.star_arg.entry.xdecref_cleanup = 0
+            code.putln('if (PyTuple_GET_SIZE(%s) > %d) {' % (
+                    Naming.args_cname,
+                    max_positional_args))
+            code.put('%s = PyTuple_GetSlice(%s, %d, PyTuple_GET_SIZE(%s)); ' % (
+                    self.star_arg.entry.cname, Naming.args_cname,
+                    max_positional_args, Naming.args_cname))
+            if self.starstar_arg:
+                code.putln("")
+                code.putln("if (unlikely(!%s)) {" % self.star_arg.entry.cname)
+                code.put_decref(self.starstar_arg.entry.cname, py_object_type)
+                code.putln('return %s;' % self.error_value())
+                code.putln('}')
+            else:
+                code.putln("if (unlikely(!%s)) return %s;" % (
+                        self.star_arg.entry.cname, self.error_value()))
+            code.putln('} else {')
+            code.put("%s = %s; " % (self.star_arg.entry.cname, Naming.empty_tuple))
+            code.put_incref(Naming.empty_tuple, py_object_type)
+            code.putln('}')
 
         # --- optimised code when we receive keyword arguments
-        code.putln("if (unlikely(%s) && (PyDict_Size(%s) > 0)) {" % (
-                Naming.kwds_cname, Naming.kwds_cname))
+        if self.num_required_kw_args:
+            code.putln("if (likely(%s)) {" % Naming.kwds_cname)
+        else:
+            code.putln("if (unlikely(%s) && (PyDict_Size(%s) > 0)) {" % (
+                    Naming.kwds_cname, Naming.kwds_cname))
         code.putln("PyObject* values[%d] = {%s};" % (
                 max_args, ('0,'*max_args)[:-1]))
         code.putln("Py_ssize_t kw_args = PyDict_Size(%s);" %
                    Naming.kwds_cname)
 
-        # parse the tuple and check that there are not too many
+        # parse the tuple and check that it's not too long
         code.putln('switch (PyTuple_GET_SIZE(%s)) {' % Naming.args_cname)
+        if self.star_arg:
+            code.putln('default:')
         for i in range(max_positional_args-1, -1, -1):
             code.putln('case %d:' % (i+1))
             code.putln("values[%d] = PyTuple_GET_ITEM(%s, %d);" % (
                     i, Naming.args_cname, i))
-        if not self.star_arg and not self.starstar_arg:
-            code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
-            code.putln('case 0:')
+        if self.star_arg:
+            code.putln('case 0: break;')
+        else:
+            if min_positional_args == 0:
+                code.putln('case 0:')
             code.putln('break;')
-            code.putln('default:') # more arguments than allowed
+            code.put('default: ') # more arguments than allowed
             code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
                     self.name.utf8encode(), has_fixed_positional_count,
                     min_positional_args, max_positional_args,
@@ -1790,25 +1771,48 @@ class DefNode(FuncDefNode):
         code.putln('switch (PyTuple_GET_SIZE(%s)) {' % Naming.args_cname)
         for i, arg in enumerate(all_args):
             if i <= max_positional_args:
-                code.putln('case %d:' % i)
+                if self.star_arg and i == max_positional_args:
+                    code.putln('default:')
+                else:
+                    code.putln('case %d:' % i)
             code.putln('values[%d] = PyDict_GetItem(%s, *%s[%d]);' % (
                     i, Naming.kwds_cname, Naming.pykwdlist_cname, i))
-            code.putln('if (values[%d]) kw_args--;' % i);
-            if arg.kw_only and not arg.default:
+            if i < min_positional_args:
+                code.putln('if (likely(values[%d])) kw_args--;' % i);
                 code.putln('else {')
-                code.put('__Pyx_RaiseKeywordRequired("%s", *%s[%d]); ' %(
-                        self.name.utf8encode(), Naming.pykwdlist_cname, i))
+                code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
+                        self.name.utf8encode(), has_fixed_positional_count,
+                        min_positional_args, max_positional_args,
+                        Naming.args_cname))
                 code.putln(code.error_goto(self.pos))
                 code.putln('}')
+            else:
+                code.putln('if (values[%d]) kw_args--;' % i);
+                if arg.kw_only and not arg.default:
+                    code.putln('else {')
+                    code.put('__Pyx_RaiseKeywordRequired("%s", *%s[%d]); ' %(
+                            self.name.utf8encode(), Naming.pykwdlist_cname, i))
+                    code.putln(code.error_goto(self.pos))
+                    code.putln('}')
         code.putln('}')
 
         code.putln('if (unlikely(kw_args > 0)) {')
-        # __Pyx_CheckKeywords() does more than strictly necessary, but
-        # since we already know we will raise an exception, this is
-        # not performance critical anymore
-        code.put('__Pyx_CheckKeywords(%s, "%s", %s, PyTuple_GET_SIZE(%s)); ' % (
-                Naming.kwds_cname, self.name.utf8encode(),
-                Naming.pykwdlist_cname, Naming.args_cname))
+        # non-positional kw args left in the dict: **kwargs or error
+        if self.star_arg:
+            code.putln("const Py_ssize_t used_pos_args = (PyTuple_GET_SIZE(%s) < %d) ? PyTuple_GET_SIZE(%s) : %d;" % (
+                    Naming.args_cname, max_positional_args,
+                    Naming.args_cname, max_positional_args))
+            pos_arg_count = "used_pos_args"
+        else:
+            pos_arg_count = "PyTuple_GET_SIZE(%s)" % Naming.args_cname
+        code.globalstate.use_utility_code(split_keywords_utility_code)
+        code.put(
+            'if (unlikely(__Pyx_SplitKeywords(%s, %s, %s, %s, "%s") < 0)) ' % (
+                Naming.kwds_cname,
+                Naming.pykwdlist_cname,
+                self.starstar_arg and self.starstar_arg.entry.cname or '0',
+                pos_arg_count,
+                self.name.utf8encode()))
         code.putln(code.error_goto(self.pos))
         code.putln('}')
 
@@ -1821,6 +1825,21 @@ class DefNode(FuncDefNode):
                 code.putln('}')
 
         # --- optimised code when we do not receive any keyword arguments
+        if min_positional_args > 0 or min_positional_args == max_positional_args:
+            # Python raises arg tuple related errors first, so we must
+            # check the length here
+            if min_positional_args == max_positional_args and not self.star_arg:
+                compare = '!='
+            else:
+                compare = '<'
+            code.putln('} else if (PyTuple_GET_SIZE(%s) %s %d) {' % (
+                    Naming.args_cname, compare, min_positional_args))
+            code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
+                    self.name.utf8encode(), has_fixed_positional_count,
+                    min_positional_args, max_positional_args,
+                    Naming.args_cname))
+            code.putln(code.error_goto(self.pos))
+
         if self.num_required_kw_args:
             # pure error case: keywords required but not passed
             code.putln('} else {')
@@ -1839,14 +1858,6 @@ class DefNode(FuncDefNode):
         elif min_positional_args == max_positional_args:
             # parse the exact number of positional arguments from the
             # args tuple
-            code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
-            code.putln('} else if (PyTuple_GET_SIZE(%s) != %d) {' % (
-                    Naming.args_cname, min_positional_args))
-            code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
-                    self.name.utf8encode(), has_fixed_positional_count,
-                    min_positional_args, max_positional_args,
-                    Naming.args_cname))
-            code.putln(code.error_goto(self.pos))
             code.putln('} else {')
             for i, arg in enumerate(positional_args):
                 item = "PyTuple_GET_ITEM(%s, %d)" % (Naming.args_cname, i)
@@ -1862,16 +1873,11 @@ class DefNode(FuncDefNode):
                     code.putln('case %d:' % (i+1))
                 item = "PyTuple_GET_ITEM(%s, %d)" % (Naming.args_cname, i)
                 self.generate_arg_assignment(arg, item, code)
-            if not self.star_arg or min_positional_args > 0:
-                code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
+            if not self.star_arg:
                 if min_positional_args == 0:
                     code.putln('case 0:')
                 code.putln('break;')
-                if self.star_arg:
-                    for i in range(min_positional_args-1,-1,-1):
-                        code.putln('case %d:' % i)
-                else:
-                    code.putln('default:') # more arguments than allowed
+                code.put('default:')
                 code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
                         self.name.utf8encode(), has_fixed_positional_count,
                         min_positional_args, max_positional_args,
@@ -1880,7 +1886,6 @@ class DefNode(FuncDefNode):
             code.putln('}')
 
         code.putln('}')
-        return
 
     def generate_positional_args_check(self, max_positional_args,
                                        has_fixed_pos_count, code):
@@ -4453,27 +4458,37 @@ static int __Pyx_CheckKeywordStrings(
 
 #------------------------------------------------------------------------------------
 #
-#  __Pyx_CheckKeywords raises an error if any non-string or
-#  unsupported keywords were passed to a function, or if a keyword was
-#  already passed as positional argument.
+#  __Pyx_SplitKeywords copies the keyword arguments that are not named
+#  in argnames[] from the kwds dict into kwds2.  If kwds2 is NULL,
+#  these keywords will raise an invalid keyword error.
+#
+#  Three kinds of errors are checked: 1) non-string keywords, 2)
+#  unexpected keywords and 3) overlap with positional arguments.
+#
+#  If num_posargs is greater 0, it denotes the number of positional
+#  arguments that were passed and that must therefore not appear
+#  amongst the keywords as well.
+#
+#  This method does not check for required keyword arguments.
 #
-#  It generally does the right thing. :)
 
-keyword_check_utility_code = [
+split_keywords_utility_code = [
 """
-static int __Pyx_CheckKeywords(PyObject *kwdict, const char* function_name,
-    PyObject** argnames[], Py_ssize_t num_pos_args); /*proto*/
+static int __Pyx_SplitKeywords(PyObject *kwds, PyObject **argnames[], \
+    PyObject *kwds2, Py_ssize_t num_pos_args, char* function_name); /*proto*/
 ""","""
-static int __Pyx_CheckKeywords(
-    PyObject *kwdict,
-    const char* function_name,
-    PyObject** argnames[],
-    Py_ssize_t num_pos_args)
+static int __Pyx_SplitKeywords(
+    PyObject *kwds,
+    PyObject **argnames[],
+    PyObject *kwds2,
+    Py_ssize_t num_pos_args,
+    char* function_name)
 {
-    PyObject* key = 0;
+    PyObject* key = 0, *value = 0;
     Py_ssize_t pos = 0;
     PyObject*** name;
-    while (PyDict_Next(kwdict, &pos, &key, 0)) {
+
+    while (PyDict_Next(kwds, &pos, &key, &value)) {
         #if PY_MAJOR_VERSION < 3
         if (unlikely(!PyString_CheckExact(key)) && unlikely(!PyString_Check(key))) {
         #else
@@ -4482,7 +4497,7 @@ static int __Pyx_CheckKeywords(
             PyErr_Format(PyExc_TypeError,
                          "%s() keywords must be strings", function_name);
             return 0;
-        } else if (argnames) {
+        } else {
             name = argnames;
             while (*name && (**name != key)) name++;
             if (!*name) {
@@ -4494,17 +4509,23 @@ static int __Pyx_CheckKeywords(
                                PyString_AS_STRING(key)) == 0) break;
                     #endif
                 }
-                if (!*name)
-                    goto invalid_keyword;
+                if (!*name) {
+                    if (kwds2) {
+                        if (unlikely(PyDict_SetItem(kwds2, key, value))) goto bad;
+                    } else {
+                        goto split_kw_invalid_keyword;
+                    }
+                }
             }
-            if (*name && ((name-argnames) < num_pos_args)) {
-                __Pyx_RaiseDoubleKeywordsError(function_name, **name);
-               return -1;
+            if (*name && ((name-argnames) < num_pos_args))
+                goto split_kw_arg_passed_twice;
             }
-        }
     }
     return 0;
-invalid_keyword:
+split_kw_arg_passed_twice:
+    __Pyx_RaiseDoubleKeywordsError(function_name, **name);
+    goto bad;
+split_kw_invalid_keyword:
     PyErr_Format(PyExc_TypeError,
     #if PY_MAJOR_VERSION < 3
         "'%s' is an invalid keyword argument for this function",
@@ -4513,135 +4534,7 @@ invalid_keyword:
         "'%U' is an invalid keyword argument for this function",
         key);
     #endif
-    return -1;
-}
-"""]
-
-#------------------------------------------------------------------------------------
-#
-#  __Pyx_SplitKeywords splits the kwds dict into two parts one part
-#  suitable for passing to PyArg_ParseTupleAndKeywords, and the other
-#  containing any extra arguments. On success, replaces the borrowed
-#  reference *kwds with references to a new dict, and passes back a
-#  new reference in *kwds2.  Does not touch any of its arguments on
-#  failure.
-#
-#  Any of *kwds and kwds2 may be 0 (but not kwds). If *kwds == 0, it
-#  is not changed. If kwds2 == 0 and *kwds != 0, a new reference to
-#  the same dictionary is passed back in *kwds.
-#
-#  If rqd_kwds is not 0, it is an array of booleans corresponding to
-#  the names in kwd_list, indicating required keyword arguments. If
-#  any of these are not present in kwds, an exception is raised.
-#
-#  If num_posargs is greater 0, it denotes the number of positional
-#  arguments that were passed and that must therefore not get passed
-#  as keyword arguments as well.
-#
-
-split_keywords_utility_code = [
-"""
-static int __Pyx_SplitKeywords(PyObject **kwds, PyObject **kwd_list[], \
-    PyObject **kwds2, char rqd_kwds[],
-    Py_ssize_t num_posargs, char* function_name); /*proto*/
-""","""
-static int __Pyx_SplitKeywords(
-    PyObject **kwds,
-    PyObject **kwd_list[],
-    PyObject **kwds2,
-    char rqd_kwds[],
-    Py_ssize_t num_posargs,
-    char* function_name)
-{
-    PyObject *x = 0, *kwds1 = 0;
-    int i;
-    PyObject ***p;
-    
-    if (*kwds) {
-        kwds1 = PyDict_New();
-        if (!kwds1)
-            goto bad;
-        *kwds2 = PyDict_Copy(*kwds);
-        if (!*kwds2)
-            goto bad;
-        for (i = 0, p = kwd_list; *p; i++, p++) {
-            x = PyDict_GetItem(*kwds, **p);
-            if (x) {
-                if (i < num_posargs)
-                    goto arg_passed_twice;
-                if (PyDict_SetItem(kwds1, **p, x) < 0)
-                    goto bad;
-                if (PyDict_DelItem(*kwds2, **p) < 0)
-                    goto bad;
-            }
-            else if (rqd_kwds && rqd_kwds[i])
-                goto missing_kwarg;
-        }
-    }
-    else {
-        if (rqd_kwds) {
-            for (i = 0, p = kwd_list; *p; i++, p++)
-                if (rqd_kwds[i])
-                    goto missing_kwarg;
-        }
-        *kwds2 = PyDict_New();
-        if (!*kwds2)
-            goto bad;
-    }
-
-    *kwds = kwds1;
-    return 0;
-arg_passed_twice:
-    __Pyx_RaiseDoubleKeywordsError(function_name, **p);
-    goto bad;
-missing_kwarg:
-    __Pyx_RaiseKeywordRequired(function_name, **p);
 bad:
-    Py_XDECREF(kwds1);
-    Py_XDECREF(*kwds2);
-    return -1;
-}
-"""]
-
-check_required_keywords_utility_code = [
-"""
-static INLINE int __Pyx_CheckRequiredKeywords(PyObject *kwds, PyObject **kwd_list[],
-    char rqd_kwds[], Py_ssize_t num_posargs, char* function_name); /*proto*/
-""","""
-static INLINE int __Pyx_CheckRequiredKeywords(
-    PyObject *kwds,
-    PyObject **kwd_list[],
-    char rqd_kwds[],
-    Py_ssize_t num_posargs,
-    char* function_name)
-{
-    int i;
-    PyObject ***p;
-
-    if (kwds) {
-        p = kwd_list;
-        for (i=0; i < num_posargs && *p; i++, p++) {
-            if (PyDict_GetItem(kwds, **p))
-                goto arg_passed_twice;
-        }
-        while (*p) {
-            if (rqd_kwds[i] && !PyDict_GetItem(kwds, **p))
-                goto missing_kwarg;
-            i++; p++;
-        }
-    }
-    else {
-        for (i = 0, p = kwd_list; *p; i++, p++)
-            if (rqd_kwds[i])
-                goto missing_kwarg;
-    }
-
-    return 0;
-arg_passed_twice:
-    __Pyx_RaiseDoubleKeywordsError(function_name, **p);
-    return -1;
-missing_kwarg:
-    __Pyx_RaiseKeywordRequired(function_name, **p);
     return -1;
 }
 """]