From: Stefan Behnel Date: Tue, 23 Feb 2010 14:26:00 +0000 (+0100) Subject: optimise dict.get() in Py3 (and in Py2 when applicable) X-Git-Tag: 0.13.beta0~319^2~22^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=ff147aad6dc5cd0c5900dd3ee6e1ec44981439f7;p=cython.git optimise dict.get() in Py3 (and in Py2 when applicable) --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 3a01316d..db3397ca 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1350,6 +1350,26 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): node, "PyList_Reverse", self.single_param_func_type, 'reverse', is_unbound_method, args) + Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType( + PyrexTypes.py_object_type, [ + PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None), + ], + exception_value = "NULL") + + def _handle_simple_method_dict_get(self, node, args, is_unbound_method): + if len(args) == 2: + args.append(ExprNodes.NoneNode(node.pos)) + elif len(args) != 3: + self._error_wrong_arg_count('dict.get', node, args, "2 or 3") + return node + + return self._substitute_method_call( + node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type, + 'get', is_unbound_method, args, + utility_code = dict_getitem_default_utility_code) + PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType( Builtin.bytes_type, [ PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None), @@ -1575,7 +1595,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): return (encoding, encoding_node, error_handling, error_handling_node) def _substitute_method_call(self, node, name, func_type, - attr_name, is_unbound_method, args=()): + attr_name, is_unbound_method, args=(), + utility_code=None): args = list(args) if args: self_arg = args[0] @@ -1592,10 +1613,49 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): return ExprNodes.PythonCapiCallNode( node.pos, name, func_type, args = args, - is_temp = node.is_temp + is_temp = node.is_temp, + utility_code = utility_code ) +dict_getitem_default_utility_code = UtilityCode( +proto = ''' +static CYTHON_INLINE PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) { + PyObject* value; +#if PY_MAJOR_VERSION >= 3 + value = PyDict_GetItemWithError(d, key); + if (unlikely(!value)) { + if (unlikely(PyErr_Occurred())) + return NULL; + value = default_value; + } + Py_INCREF(value); +#else + if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) { + /* these presumably have safe hash functions */ + value = PyDict_GetItem(d, key); + if (unlikely(!value)) { + value = default_value; + } + Py_INCREF(value); + } else { + PyObject *m; + m = __Pyx_GetAttrString(d, "get"); + if (!m) return NULL; + if (default_value == Py_None) { + value = PyObject_CallFunctionObjArgs(m, key, default_value, NULL); + } else { + value = PyObject_CallFunctionObjArgs(m, key, NULL); + } + Py_DECREF(m); + } +#endif + return value; +} +''', +impl = "" +) + append_utility_code = UtilityCode( proto = """ static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) { diff --git a/tests/run/dict_get.pyx b/tests/run/dict_get.pyx new file mode 100644 index 00000000..3aa1b2af --- /dev/null +++ b/tests/run/dict_get.pyx @@ -0,0 +1,55 @@ +def get(dict d, key): + """ + >>> d = { 1: 10 } + >>> d.get(1) + 10 + >>> get(d, 1) + 10 + + >>> d.get(2) is None + True + >>> get(d, 2) is None + True + + >>> d.get((1,2)) is None + True + >>> get(d, (1,2)) is None + True + + >>> class Unhashable: + ... def __hash__(self): + ... raise ValueError + + >>> d.get(Unhashable()) + Traceback (most recent call last): + ValueError + >>> get(d, Unhashable()) + Traceback (most recent call last): + ValueError + + >>> None.get(1) + Traceback (most recent call last): + ... + AttributeError: 'NoneType' object has no attribute 'get' + >>> get(None, 1) + Traceback (most recent call last): + ... + AttributeError: 'NoneType' object has no attribute 'get' + """ + return d.get(key) + +def get_default(dict d, key, default): + """ + >>> d = { 1: 10 } + + >>> d.get(1, 2) + 10 + >>> get_default(d, 1, 2) + 10 + + >>> d.get(2, 2) + 2 + >>> get_default(d, 2, 2) + 2 + """ + return d.get(key, default)