rewrite of the argument unpacking code, now using fast C switch statements
author Stefan Behnel <scoder@users.berlios.de>
Sat, 23 Aug 2008 15:32:50 +0000 (17:32 +0200)
committer Stefan Behnel <scoder@users.berlios.de>
Sat, 23 Aug 2008 15:32:50 +0000 (17:32 +0200)
Cython/Compiler/Nodes.py

index dc67330fb369f426ca786e02a57c08dfed4f13a0..6fa53589753dec2d427c17954b66481e5d7b5a13 100644 (file)
@@ -1547,28 +1547,13 @@ class DefNode(FuncDefNode):
     def generate_keyword_list(self, code):
         if self.signature_has_generic_args() and \
                 self.signature_has_nongeneric_args():
-            reqd_kw_flags = []
-            has_reqd_kwds = False
             code.put(
                 "static PyObject **%s[] = {" %
                     Naming.pykwdlist_cname)
             for arg in self.args:
                 if arg.is_generic:
                     code.put('&%s,' % arg.name_entry.pystring_cname)
-                    if arg.kw_only and not arg.default:
-                        has_reqd_kwds = 1
-                        flag = "1"
-                    else:
-                        flag = "0"
-                    reqd_kw_flags.append(flag)
             code.putln("0};")
-            if has_reqd_kwds:
-                flags_name = Naming.reqd_kwds_cname
-                self.reqd_kw_flags_cname = flags_name
-                code.putln(
-                    "static char %s[] = {%s};" % (
-                        flags_name,
-                        ",".join(reqd_kw_flags)))
 
     def generate_argument_parsing_code(self, env, code):
         # Generate PyArg_ParseTuple call for generic
@@ -1755,18 +1740,21 @@ class DefNode(FuncDefNode):
 
     def generate_tuple_and_keyword_parsing_code(self, positional_args,
                                                 kw_only_args, code):
+        all_args = tuple(positional_args) + tuple(kw_only_args)
+
         min_positional_args = self.num_required_args - self.num_required_kw_args
         if len(self.args) > 0 and self.args[0].is_self_arg:
             min_positional_args -= 1
         max_positional_args = len(positional_args)
-        max_args = max_positional_args + len(kw_only_args)
-
-        if not self.star_arg:
-            has_fixed_positional_count = min_positional_args == max_positional_args
-            self.generate_positional_args_check(
-                max_positional_args, has_fixed_positional_count, code)
+        max_args = len(all_args)
+        has_fixed_positional_count = not self.star_arg and \
+            min_positional_args == max_positional_args
 
         if self.star_arg or self.starstar_arg:
+            if not self.star_arg:
+                # need to check the tuple before checking the keywords...
+                self.generate_positional_args_check(
+                    max_positional_args, has_fixed_positional_count, code)
             self.generate_stararg_getting_code(max_positional_args, code)
 
         code.globalstate.use_utility_code(raise_double_keywords_utility_code)
@@ -1775,49 +1763,58 @@ class DefNode(FuncDefNode):
         # --- optimised code when we receive keyword arguments
         code.putln("if (unlikely(%s) && (PyDict_Size(%s) > 0)) {" % (
                 Naming.kwds_cname, Naming.kwds_cname))
-        code.putln("PyObject* values[%d];" % max_args)
-        code.putln("Py_ssize_t arg, kw_args = PyDict_Size(%s);" %
+        code.putln("PyObject* values[%d] = {%s};" % (
+                max_args, ('0,'*max_args)[:-1]))
+        code.putln("Py_ssize_t kw_args = PyDict_Size(%s);" %
                    Naming.kwds_cname)
 
-        # parse arg tuple and check that positional args are not also
-        # passed as kw args
-        code.putln("for (arg=0; arg < PyTuple_GET_SIZE(%s); arg++) {" %
-                   Naming.args_cname)
-        code.putln("values[arg] = PyTuple_GET_ITEM(%s, arg);" %
-                   Naming.args_cname)
-        code.putln("if (unlikely(PyDict_GetItem(%s, *%s[arg]))) {" % (
-                Naming.kwds_cname, Naming.pykwdlist_cname))
-        code.put('__Pyx_RaiseDoubleKeywordsError("%s", *%s[arg]); ' % (
-                self.name.utf8encode(), Naming.pykwdlist_cname))
-        code.putln(code.error_goto(self.pos))
-        code.putln('}')
+        # parse the tuple first, then start parsing the arg tuple and
+        # check that positional args are not also passed as kw args
+        code.putln('switch (PyTuple_GET_SIZE(%s)) {' % Naming.args_cname)
+        for i in range(max_positional_args-1, -1, -1):
+            code.putln('case %d:' % (i+1))
+            code.putln("values[%d] = PyTuple_GET_ITEM(%s, %d);" % (
+                    i, Naming.args_cname, i))
+        if not self.star_arg and not self.starstar_arg:
+            code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
+            code.putln('case 0:')
+            code.putln('break;')
+            code.putln('default:') # more arguments than allowed
+            code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
+                    self.name.utf8encode(), has_fixed_positional_count,
+                    min_positional_args, max_positional_args,
+                    Naming.args_cname))
+            code.putln(code.error_goto(self.pos))
         code.putln('}')
 
-        # parse remaining positional args from the keyword dictionary
-        code.putln("for (arg=PyTuple_GET_SIZE(%s); arg < %d; arg++) {" % (
-                Naming.args_cname, max_args))
-        code.putln('values[arg] = PyDict_GetItem(%s, *%s[arg]);' % (
-                Naming.kwds_cname, Naming.pykwdlist_cname))
-        code.putln('if (values[arg]) kw_args--;');
-        if self.num_required_kw_args:
-            code.putln('else if (%s[arg]) {' % Naming.reqd_kwds_cname)
-            code.put('__Pyx_RaiseKeywordRequired("%s", *%s[arg]); ' %(
-                    self.name.utf8encode(), Naming.pykwdlist_cname))
-            code.putln(code.error_goto(self.pos))
-            code.putln('}')
+        # now fill up the arguments with values from the kw dict
+        code.putln('switch (PyTuple_GET_SIZE(%s)) {' % Naming.args_cname)
+        for i, arg in enumerate(all_args):
+            if i <= max_positional_args:
+                code.putln('case %d:' % i)
+            code.putln('values[%d] = PyDict_GetItem(%s, *%s[%d]);' % (
+                    i, Naming.kwds_cname, Naming.pykwdlist_cname, i))
+            code.putln('if (values[%d]) kw_args--;' % i);
+            if arg.kw_only and not arg.default:
+                code.putln('else {')
+                code.put('__Pyx_RaiseKeywordRequired("%s", *%s[%d]); ' %(
+                        self.name.utf8encode(), Naming.pykwdlist_cname, i))
+                code.putln(code.error_goto(self.pos))
+                code.putln('}')
         code.putln('}')
 
-        # raise an error if not all keywords were read
         code.putln('if (unlikely(kw_args > 0)) {')
-        # __Pyx_CheckKeywords() this does more than strictly
-        # necessary, but this is not performance critical at all
-        code.put('__Pyx_CheckKeywords(%s, "%s", %s); ' % (
-                Naming.kwds_cname, self.name.utf8encode(), Naming.pykwdlist_cname))
+        # __Pyx_CheckKeywords() does more than strictly necessary, but
+        # since we already know we will raise an exception, this is
+        # not performance critical anymore
+        code.put('__Pyx_CheckKeywords(%s, "%s", %s, PyTuple_GET_SIZE(%s)); ' % (
+                Naming.kwds_cname, self.name.utf8encode(),
+                Naming.pykwdlist_cname, Naming.args_cname))
         code.putln(code.error_goto(self.pos))
         code.putln('}')
 
         # convert arg values to their final type and assign them
-        for i, arg in enumerate(tuple(positional_args) + tuple(kw_only_args)):
+        for i, arg in enumerate(all_args):
             if arg.default:
                 code.putln("if (values[%d]) {" % i)
             self.generate_arg_assignment(arg, "values[%d]" % i, code)
@@ -1826,41 +1823,62 @@ class DefNode(FuncDefNode):
 
         # --- optimised code when we do not receive any keyword arguments
         if self.num_required_kw_args:
+            code.putln('} else {')
+            if not self.star_arg:
+                self.generate_positional_args_check(
+                    max_positional_args, has_fixed_positional_count, code)
             # simple case: keywords required but none passed
             for i, arg in enumerate(kw_only_args):
                 if not arg.default:
-                    required_arg = arg
+                    code.put('__Pyx_RaiseKeywordRequired("%s", *%s[%d]); ' % (
+                            self.name.utf8encode(), Naming.pykwdlist_cname,
+                            len(positional_args) + i))
+                    code.putln(code.error_goto(self.pos))
                     break
-            code.putln('} else {')
-            code.put('__Pyx_RaiseKeywordRequired("%s", *%s[%d]); ' % (
-                    self.name.utf8encode(), Naming.pykwdlist_cname,
-                    len(positional_args) + i))
-            code.putln(code.error_goto(self.pos))
-            code.putln('}')
-        else:
-            # check if we have all required positional arguments
+        elif min_positional_args == max_positional_args:
+            # parse the exact number of positional arguments from the
+            # args tuple
             code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
-            exact_count = not self.star_arg and min_positional_args == max_positional_args
-            code.putln('} else if (unlikely(PyTuple_GET_SIZE(%s) < %d)) {' % (
+            code.putln('} else if (PyTuple_GET_SIZE(%s) != %d) {' % (
                     Naming.args_cname, min_positional_args))
             code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
-                    self.name.utf8encode(), exact_count, min_positional_args,
-                    max_positional_args, Naming.args_cname))
+                    self.name.utf8encode(), has_fixed_positional_count,
+                    min_positional_args, max_positional_args,
+                    Naming.args_cname))
             code.putln(code.error_goto(self.pos))
-
-            # parse all positional arguments from the args tuple
             code.putln('} else {')
-            closing = 0
             for i, arg in enumerate(positional_args):
-                if arg.default:
-                    code.putln('if (PyTuple_GET_SIZE(%s) > %s) {' % (Naming.args_cname, i))
-                    closing += 1
-                item = "PyTuple_GET_ITEM(%s, %s)" % (Naming.args_cname, i)
+                item = "PyTuple_GET_ITEM(%s, %d)" % (Naming.args_cname, i)
                 self.generate_arg_assignment(arg, item, code)
-            for _ in range(closing):
-                code.putln('}')
-
+        else:
+            # parse the positional arguments from the variable length
+            # args tuple
+            code.putln('} else {')
+            code.putln('switch (PyTuple_GET_SIZE(%s)) {' % Naming.args_cname)
+            reversed_args = list(enumerate(positional_args))[::-1]
+            for i, arg in reversed_args:
+                if i >= min_positional_args-1:
+                    code.putln('case %d:' % (i+1))
+                item = "PyTuple_GET_ITEM(%s, %d)" % (Naming.args_cname, i)
+                self.generate_arg_assignment(arg, item, code)
+            if not self.star_arg or min_positional_args > 0:
+                code.globalstate.use_utility_code(raise_argtuple_invalid_utility_code)
+                if min_positional_args == 0:
+                    code.putln('case 0:')
+                code.putln('break;')
+                if self.star_arg:
+                    for i in range(min_positional_args-1,-1,-1):
+                        code.putln('case %d:' % i)
+                else:
+                    code.putln('default:') # more arguments than allowed
+                code.put('__Pyx_RaiseArgtupleInvalid("%s", %d, %d, %d, PyTuple_GET_SIZE(%s)); ' % (
+                        self.name.utf8encode(), has_fixed_positional_count,
+                        min_positional_args, max_positional_args,
+                        Naming.args_cname))
+                code.putln(code.error_goto(self.pos))
             code.putln('}')
+
+        code.putln('}')
         return
 
     def generate_positional_args_check(self, max_positional_args,
@@ -4396,7 +4414,8 @@ static INLINE void __Pyx_RaiseDoubleKeywordsError(
 
 keyword_string_check_utility_code = [
 """
-static int __Pyx_CheckKeywordStrings(PyObject *kwdict, const char* function_name, int kw_allowed); /*proto*/
+static int __Pyx_CheckKeywordStrings(PyObject *kwdict,
+    const char* function_name, int kw_allowed); /*proto*/
 ""","""
 static int __Pyx_CheckKeywordStrings(
     PyObject *kwdict,
@@ -4433,19 +4452,22 @@ static int __Pyx_CheckKeywordStrings(
 
 #------------------------------------------------------------------------------------
 #
-#  __Pyx_CheckKeywords raises an error if non-string keywords were
-#  passed to a function, or if any unsupported keywords were passed to
-#  a function.
+#  __Pyx_CheckKeywords raises an error if any non-string or
+#  unsupported keywords were passed to a function, or if a keyword was
+#  already passed as positional argument.
+#
+#  It generally does the right thing. :)
 
 keyword_check_utility_code = [
 """
 static int __Pyx_CheckKeywords(PyObject *kwdict, const char* function_name,
-    PyObject** argnames[]); /*proto*/
+    PyObject** argnames[], Py_ssize_t num_pos_args); /*proto*/
 ""","""
 static int __Pyx_CheckKeywords(
     PyObject *kwdict,
     const char* function_name,
-    PyObject** argnames[])
+    PyObject** argnames[],
+    Py_ssize_t num_pos_args)
 {
     PyObject* key = 0;
     Py_ssize_t pos = 0;
@@ -4474,9 +4496,13 @@ static int __Pyx_CheckKeywords(
                 if (!*name)
                     goto invalid_keyword;
             }
+            if (*name && ((name-argnames) < num_pos_args)) {
+                __Pyx_RaiseDoubleKeywordsError(function_name, **name);
+               return -1;
+            }
         }
     }
-    return 1;
+    return 0;
 invalid_keyword:
     PyErr_Format(PyExc_TypeError,
     #if PY_MAJOR_VERSION < 3
@@ -4486,7 +4512,7 @@ invalid_keyword:
         "'%U' is an invalid keyword argument for this function",
         key);
     #endif
-    return 0;
+    return -1;
 }
 """]
 
@@ -4516,7 +4542,7 @@ split_keywords_utility_code = [
 """
 static int __Pyx_SplitKeywords(PyObject **kwds, PyObject **kwd_list[], \
     PyObject **kwds2, char rqd_kwds[],
-    Py_ssize_t num_posargs, char* func_name); /*proto*/
+    Py_ssize_t num_posargs, char* function_name); /*proto*/
 ""","""
 static int __Pyx_SplitKeywords(
     PyObject **kwds,
@@ -4524,7 +4550,7 @@ static int __Pyx_SplitKeywords(
     PyObject **kwds2,
     char rqd_kwds[],
     Py_ssize_t num_posargs,
-    char* func_name)
+    char* function_name)
 {
     PyObject *x = 0, *kwds1 = 0;
     int i;
@@ -4565,10 +4591,10 @@ static int __Pyx_SplitKeywords(
     *kwds = kwds1;
     return 0;
 arg_passed_twice:
-    __Pyx_RaiseDoubleKeywordsError(func_name, **p);
+    __Pyx_RaiseDoubleKeywordsError(function_name, **p);
     goto bad;
 missing_kwarg:
-    __Pyx_RaiseKeywordRequired(func_name, **p);
+    __Pyx_RaiseKeywordRequired(function_name, **p);
 bad:
     Py_XDECREF(kwds1);
     Py_XDECREF(*kwds2);
@@ -4579,14 +4605,14 @@ bad:
 check_required_keywords_utility_code = [
 """
 static INLINE int __Pyx_CheckRequiredKeywords(PyObject *kwds, PyObject **kwd_list[],
-    char rqd_kwds[], Py_ssize_t num_posargs, char* func_name); /*proto*/
+    char rqd_kwds[], Py_ssize_t num_posargs, char* function_name); /*proto*/
 ""","""
 static INLINE int __Pyx_CheckRequiredKeywords(
     PyObject *kwds,
     PyObject **kwd_list[],
     char rqd_kwds[],
     Py_ssize_t num_posargs,
-    char* func_name)
+    char* function_name)
 {
     int i;
     PyObject ***p;
@@ -4611,10 +4637,10 @@ static INLINE int __Pyx_CheckRequiredKeywords(
 
     return 0;
 arg_passed_twice:
-    __Pyx_RaiseDoubleKeywordsError(func_name, **p);
+    __Pyx_RaiseDoubleKeywordsError(function_name, **p);
     return -1;
 missing_kwarg:
-    __Pyx_RaiseKeywordRequired(func_name, **p);
+    __Pyx_RaiseKeywordRequired(function_name, **p);
     return -1;
 }
 """]