optimise dict.get() in Py3 (and in Py2 when applicable)
author Stefan Behnel <scoder@users.berlios.de>
Tue, 23 Feb 2010 14:26:00 +0000 (15:26 +0100)
committer Stefan Behnel <scoder@users.berlios.de>
Tue, 23 Feb 2010 14:26:00 +0000 (15:26 +0100)
Cython/Compiler/Optimize.py
tests/run/dict_get.pyx [new file with mode: 0644]

index 3a01316d3aea8dba7cfd1d7ed8c4b6185d0a6075..db3397ca3b329c4d84eb349868fc4f723da3a04d 100644 (file)
@@ -1350,6 +1350,26 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
             node, "PyList_Reverse", self.single_param_func_type,
             'reverse', is_unbound_method, args)
 
+    Pyx_PyDict_GetItem_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.py_object_type, [
+            PyrexTypes.CFuncTypeArg("dict", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("key", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("default", PyrexTypes.py_object_type, None),
+            ],
+        exception_value = "NULL")  # C signature (dict, key, default) -> object used for the dict.get() replacement call
+
+    def _handle_simple_method_dict_get(self, node, args, is_unbound_method):
+        if len(args) == 2:
+            args.append(ExprNodes.NoneNode(node.pos))  # d.get(key): make the implicit None default explicit
+        elif len(args) != 3:
+            self._error_wrong_arg_count('dict.get', node, args, "2 or 3")
+            return node
+        # args now holds (dict, key, default): replace the call with the C helper
+        return self._substitute_method_call(
+            node, "__Pyx_PyDict_GetItemDefault", self.Pyx_PyDict_GetItem_func_type,
+            'get', is_unbound_method, args,
+            utility_code = dict_getitem_default_utility_code)
+
     PyUnicode_AsEncodedString_func_type = PyrexTypes.CFuncType(
         Builtin.bytes_type, [
             PyrexTypes.CFuncTypeArg("obj", Builtin.unicode_type, None),
@@ -1575,7 +1595,8 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         return (encoding, encoding_node, error_handling, error_handling_node)
 
     def _substitute_method_call(self, node, name, func_type,
-                                attr_name, is_unbound_method, args=()):
+                                attr_name, is_unbound_method, args=(),
+                                utility_code=None):
         args = list(args)
         if args:
             self_arg = args[0]
@@ -1592,10 +1613,49 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform):
         return ExprNodes.PythonCapiCallNode(
             node.pos, name, func_type,
             args = args,
-            is_temp = node.is_temp
+            is_temp = node.is_temp,
+            utility_code = utility_code
             )
 
 
+dict_getitem_default_utility_code = UtilityCode(
+proto = '''
+static CYTHON_INLINE PyObject* __Pyx_PyDict_GetItemDefault(PyObject* d, PyObject* key, PyObject* default_value) {
+    PyObject* value;
+#if PY_MAJOR_VERSION >= 3
+    value = PyDict_GetItemWithError(d, key);
+    if (unlikely(!value)) {
+        if (unlikely(PyErr_Occurred()))
+            return NULL;
+        value = default_value;
+    }
+    Py_INCREF(value);
+#else
+    if (PyString_CheckExact(key) || PyUnicode_CheckExact(key) || PyInt_CheckExact(key)) {
+        /* these presumably have safe hash functions, so the error-swallowing PyDict_GetItem() is acceptable */
+        value = PyDict_GetItem(d, key);
+        if (unlikely(!value)) {
+            value = default_value;
+        }
+        Py_INCREF(value);
+    } else {
+        PyObject *m;
+        m = __Pyx_GetAttrString(d, "get");
+        if (!m) return NULL;
+        if (default_value == Py_None) {
+            value = PyObject_CallFunctionObjArgs(m, key, NULL);  /* None is already dict.get's default */
+        } else {
+            value = PyObject_CallFunctionObjArgs(m, key, default_value, NULL);  /* forward the caller's default */
+        }
+        Py_DECREF(m);
+    }
+#endif
+    return value;
+}
+''',
+impl = ""
+)
+
 append_utility_code = UtilityCode(
 proto = """
 static CYTHON_INLINE PyObject* __Pyx_PyObject_Append(PyObject* L, PyObject* x) {
diff --git a/tests/run/dict_get.pyx b/tests/run/dict_get.pyx
new file mode 100644 (file)
index 0000000..3aa1b2a
--- /dev/null
@@ -0,0 +1,67 @@
+def get(dict d, key):
+    """
+    Optimised d.get(key) on a typed dict must behave like the builtin.
+
+    >>> d = { 1: 10 }
+    >>> d.get(1)
+    10
+    >>> get(d, 1)
+    10
+
+    >>> d.get(2) is None
+    True
+    >>> get(d, 2) is None
+    True
+
+    >>> d.get((1,2)) is None
+    True
+    >>> get(d, (1,2)) is None
+    True
+
+    >>> class Unhashable:
+    ...    def __hash__(self):
+    ...        raise ValueError
+
+    >>> d.get(Unhashable())
+    Traceback (most recent call last):
+    ValueError
+    >>> get(d, Unhashable())
+    Traceback (most recent call last):
+    ValueError
+
+    >>> None.get(1)
+    Traceback (most recent call last):
+    ...
+    AttributeError: 'NoneType' object has no attribute 'get'
+    >>> get(None, 1)
+    Traceback (most recent call last):
+    ...
+    AttributeError: 'NoneType' object has no attribute 'get'
+    """
+    return d.get(key)
+
+def get_default(dict d, key, default):
+    """
+    Optimised d.get(key, default) on a typed dict must behave like the builtin.
+
+    >>> d = { 1: 10 }
+
+    >>> d.get(1, 2)
+    10
+    >>> get_default(d, 1, 2)
+    10
+
+    >>> d.get(2, 2)
+    2
+    >>> get_default(d, 2, 2)
+    2
+
+    Keys that are not exact str/unicode/int objects take the generic
+    fallback path of the C helper under Py2, so cover them with a default too:
+
+    >>> d.get((1,2), 2)
+    2
+    >>> get_default(d, (1,2), 2)
+    2
+    """
+    return d.get(key, default)