TreePath implementation for selecting nodes from the code tree
authorStefan Behnel <scoder@users.berlios.de>
Fri, 18 Sep 2009 06:02:46 +0000 (08:02 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Fri, 18 Sep 2009 06:02:46 +0000 (08:02 +0200)
Cython/Compiler/Builtin.py
Cython/Compiler/Optimize.py
Cython/Compiler/PyrexTypes.py
Cython/Compiler/Tests/TestTreePath.py [new file with mode: 0644]
Cython/Compiler/TreePath.py [new file with mode: 0644]
tests/bugs.txt

index 53ce01b13382c282bb68d26e9cf253e40460a12b..923cfc4a3e4285ff148cea4ca855f2662cb9aeaf 100644 (file)
@@ -21,7 +21,7 @@ builtin_function_table = [
     #('eval',      "",     "",      ""),
     #('execfile',  "",     "",      ""),
     #('filter',    "",     "",      ""),
-    #('getattr',    "OO",   "O",     "PyObject_GetAttr"),   # optimised later on
+    #('getattr',    "OO",   "O",     "PyObject_GetAttr"),   #  optimised later on
     ('getattr3',   "OOO",  "O",     "__Pyx_GetAttr3",       "getattr"),
     ('hasattr',    "OO",   "b",     "PyObject_HasAttr"),
     ('hash',       "O",    "l",     "PyObject_Hash"),
@@ -29,7 +29,7 @@ builtin_function_table = [
     #('id',        "",     "",      ""),
     #('input',     "",     "",      ""),
     ('intern',     "s",    "O",     "__Pyx_InternFromString"),
-    ('isinstance', "OO",   "b",     "PyObject_IsInstance"),
+    #('isinstance', "OO",   "b",     "PyObject_IsInstance"),    # optimised later on
     ('issubclass', "OO",   "b",     "PyObject_IsSubclass"),
     ('iter',       "O",    "O",     "PyObject_GetIter"),
     ('len',        "O",    "Z",     "PyObject_Length"),
index 425c7f6eff136fa95fe1bc7fe3bfdfdb0894dd9f..fe0a73299ddcd8f508547a0e7121ed8147cf189b 100644 (file)
@@ -712,6 +712,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform):
                   "expected 2 or 3, found %d" % len(args))
         return node
 
+    PyObject_TypeCheck_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.c_bint_type, [
+            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("type", PyrexTypes.c_py_type_object_ptr_type, None),
+            ])
+
+    PyObject_IsInstance_func_type = PyrexTypes.CFuncType(
+        PyrexTypes.c_bint_type, [
+            PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None),
+            PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None),
+            ])
+
+    def _handle_simple_function_isinstance(self, node, pos_args):
+        """Replace generic calls to isinstance(x, type) by a more
+        efficient type check.
+        """
+        args = pos_args.args
+        if len(args) != 2:
+            error(node.pos, "isinstance(x, type) called with wrong number of args, found %d" %
+                  len(args))
+            return node
+
+        type_arg = args[1]
+        if type_arg.type is Builtin.type_type:
+            function_name = "PyObject_TypeCheck"
+            function_type = self.PyObject_TypeCheck_func_type
+            args[1] = ExprNodes.CastNode(type_arg, PyrexTypes.c_py_type_object_ptr_type)
+        else:
+            function_name = "PyObject_IsInstance"
+            function_type = self.PyObject_IsInstance_func_type
+
+        return ExprNodes.PythonCapiCallNode(
+            node.pos, function_name, function_type,
+            args = args, is_temp = node.is_temp)
+
     Pyx_Type_func_type = PyrexTypes.CFuncType(
         Builtin.type_type, [
             PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None)
@@ -1058,8 +1093,8 @@ class FinalOptimizePhase(Visitor.CythonTransform):
     just before the C code generation phase. 
     
     The optimizations currently implemented in this class are: 
-        - Eliminate None assignment and refcounting for first assignment. 
-        - isinstance -> typecheck for cdef types
+        - Eliminate None assignment and refcounting for first assignment.
+        - Eliminate dead coercion nodes.
     """
     def visit_SingleAssignmentNode(self, node):
         """Avoid redundant initialisation of local variables before their
@@ -1075,18 +1110,23 @@ class FinalOptimizePhase(Visitor.CythonTransform):
                 lhs.entry.init = 0
         return node
 
-    def visit_SimpleCallNode(self, node):
-        """Replace generic calls to isinstance(x, type) by a more efficient
-        type check.
+    def visit_NoneCheckNode(self, node):
+        """Remove NoneCheckNode nodes wrapping nodes that cannot
+        possibly be None.
+
+        FIXME: the list below might be better maintained as a node
+        class attribute...
         """
-        self.visitchildren(node)
-        if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode):
-            if node.function.name == 'isinstance':
-                type_arg = node.args[1]
-                if type_arg.type.is_builtin_type and type_arg.type.name == 'type':
-                    from CythonScope import utility_scope
-                    node.function.entry = utility_scope.lookup('PyObject_TypeCheck')
-                    node.function.type = node.function.entry.type
-                    PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type)
-                    node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr)
+        target = node.arg
+        if isinstance(target, ExprNodes.NoneNode):
+            return node
+        if not target.type.is_pyobject:
+            return target
+        if isinstance(target, (ExprNodes.ConstNode,
+                               ExprNodes.NumBinopNode)):
+            return target
+        if isinstance(target, (ExprNodes.SequenceNode,
+                               ExprNodes.ComprehensionNode,
+                               ExprNodes.SetNode, ExprNodes.DictNode)):
+            return target
         return node
index 17d886c53589a40209e6eb2f27fd0e35914e1cee..46f21ab626337e4123216a31c0af0e6d0c46dda5 100644 (file)
@@ -1691,6 +1691,9 @@ c_anon_enum_type =    CAnonEnumType(-1, 1)
 c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer")
 c_py_buffer_ptr_type = CPtrType(c_py_buffer_type)
 
+c_py_type_object_type = CStructOrUnionType("PyTypeObject", "struct", None, 1, "PyTypeObject")
+c_py_type_object_ptr_type = CPtrType(c_py_type_object_type)
+
 error_type =    ErrorType()
 unspecified_type = UnspecifiedType()
 
diff --git a/Cython/Compiler/Tests/TestTreePath.py b/Cython/Compiler/Tests/TestTreePath.py
new file mode 100644 (file)
index 0000000..7014a9a
--- /dev/null
@@ -0,0 +1,57 @@
+import unittest
+from Cython.Compiler.Visitor import PrintTree
+from Cython.TestUtils import TransformTest
+from Cython.Compiler.TreePath import find_first, find_all
+
+class TestTreePath(TransformTest):
+
+    def test_node_path(self):
+        t = self.run_pipeline([], u"""
+        def decorator(fun):  # DefNode
+            return fun       # ReturnStatNode, NameNode
+        @decorator           # NameNode
+        def decorated():     # DefNode
+            pass
+        """)
+
+        self.assertEquals(2, len(find_all(t, "//DefNode")))
+        self.assertEquals(2, len(find_all(t, "//NameNode")))
+        self.assertEquals(1, len(find_all(t, "//ReturnStatNode")))
+        self.assertEquals(1, len(find_all(t, "//DefNode//ReturnStatNode")))
+
+    def test_node_path_child(self):
+        t = self.run_pipeline([], u"""
+        def decorator(fun):  # DefNode
+            return fun       # ReturnStatNode, NameNode
+        @decorator           # NameNode
+        def decorated():     # DefNode
+            pass
+        """)
+
+        self.assertEquals(1, len(find_all(t, "//DefNode/ReturnStatNode/NameNode")))
+        self.assertEquals(1, len(find_all(t, "//ReturnStatNode/NameNode")))
+
+    def test_node_path_attribute_exists(self):
+        t = self.run_pipeline([], u"""
+        def decorator(fun):
+            return fun
+        @decorator
+        def decorated():
+            pass
+        """)
+
+        self.assertEquals(2, len(find_all(t, "//NameNode[@name]")))
+
+    def test_node_path_attribute_string_predicate(self):
+        t = self.run_pipeline([], u"""
+        def decorator(fun):
+            return fun
+        @decorator
+        def decorated():
+            pass
+        """)
+
+        self.assertEquals(1, len(find_all(t, "//NameNode[@name = 'decorator']")))
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/Cython/Compiler/TreePath.py b/Cython/Compiler/TreePath.py
new file mode 100644 (file)
index 0000000..74aaeba
--- /dev/null
@@ -0,0 +1,250 @@
+"""
+A simple XPath-like language for tree traversal.
+
+This works by creating a filter chain of generator functions.  Each
+function selects a part of the expression, e.g. a child node, a
+specific descendant or a node that holds an attribute.
+"""
+
+import re
+
+path_tokenizer = re.compile(
+    "("
+    "'[^']*'|\"[^\"]*\"|"
+    "//?|"
+    "\(\)|"
+    "==?|"
+    "[/.*\[\]\(\)@])|"
+    "([^/\[\]\(\)@=\s]+)|"
+    "\s+"
+    ).findall
+
+def iterchildren(node, attr_name):
+    # returns an iterable of all child nodes of that name
+    child = getattr(node, attr_name)
+    if child is not None:
+        if type(child) is list:
+            return child
+        else:
+            return [child]
+    else:
+        return ()
+
+def _get_first_or_none(it):
+    try:
+        try:
+            _next = it.next
+        except AttributeError:
+            return next(it)
+        else:
+            return _next()
+    except StopIteration:
+        return None
+
+def type_name(node):
+    return node.__class__.__name__.split('.')[-1]
+
+def parse_func(next, token):
+    name = token[1]
+    token = next()
+    if token[0] != '(':
+        raise ValueError("Expected '(' after function name '%s'" % name)
+    predicate = handle_predicate(next, token, end_marker=')')
+    return name, predicate
+
+def handle_func_not(next, token):
+    """
+    func(...)
+    """
+    name, predicate = parse_func(next, token)
+
+    def select(result):
+        for node in result:
+            if _get_first_or_none(predicate(node)) is not None:
+                yield node
+    return select
+
+def handle_name(next, token):
+    """
+    /NodeName/
+    or
+    func(...)
+    """
+    name = token[1]
+    if name in functions:
+        return functions[name](next, token)
+    def select(result):
+        for node in result:
+            for attr_name in node.child_attrs:
+                for child in iterchildren(node, attr_name):
+                    if type_name(child) == name:
+                        yield child
+    return select
+
+def handle_star(next, token):
+    """
+    /*/
+    """
+    def select(result):
+        for node in result:
+            for name in node.child_attrs:
+                for child in iterchildren(node, name):
+                    yield child
+    return select
+
+def handle_dot(next, token):
+    """
+    /./
+    """
+    def select(result):
+        return result
+    return select
+
+def handle_descendants(next, token):
+    """
+    //...
+    """
+    token = next()
+    if token[0] == "*":
+        def iter_recursive(node):
+            for name in node.child_attrs:
+                for child in iterchildren(node, name):
+                    yield child
+                    for c in iter_recursive(child):
+                        yield c
+    elif not token[0]:
+        node_name = token[1]
+        def iter_recursive(node):
+            for name in node.child_attrs:
+                for child in iterchildren(node, name):
+                    if type_name(child) == node_name:
+                        yield child
+                    for c in iter_recursive(child):
+                        yield c
+    else:
+        raise ValueError("Expected node name after '//'")
+
+    def select(result):
+        for node in result:
+            for child in iter_recursive(node):
+                yield child
+                    
+    return select
+
+def handle_attribute(next, token):
+    token = next()
+    if token[0]:
+        raise ValueError("Expected attribute name")
+    name = token[1]
+    token = next()
+    value = None
+    if token[0] == '=':
+        value = parse_path_value(next)
+    if value is None:
+        def select(result):
+            for node in result:
+                try:
+                    attr_value = getattr(node, name)
+                except AttributeError:
+                    continue
+                if attr_value is not None:
+                    yield attr_value
+    else:
+        def select(result):
+            for node in result:
+                try:
+                    attr_value = getattr(node, name)
+                except AttributeError:
+                    continue
+                if attr_value == value:
+                    yield value
+    return select
+
+def parse_path_value(next):
+    token = next()
+    value = token[0]
+    if value[:1] == "'" or value[:1] == '"':
+        value = value[1:-1]
+    else:
+        try:
+            value = int(value)
+        except ValueError:
+            raise ValueError("Invalid attribute predicate: '%s'" % value)
+    return value
+
+def handle_predicate(next, token, end_marker=']'):
+    token = next()
+    selector = []
+    while token[0] != end_marker:
+        selector.append( operations[token[0]](next, token) )
+        try:
+            token = next()
+        except StopIteration:
+            break
+        else:
+            if token[0] == "/":
+                token = next()
+
+    def select(result):
+        for node in result:
+            subresult = iter((node,))
+            for select in selector:
+                subresult = select(subresult)
+            predicate_result = _get_first_or_none(subresult)
+            if predicate_result is not None:
+                yield predicate_result
+    return select
+
+operations = {
+    "@":  handle_attribute,
+    "":   handle_name,
+    "*":  handle_star,
+    ".":  handle_dot,
+    "//": handle_descendants,
+    "[":  handle_predicate,
+    }
+
+functions = {
+    'not' : handle_func_not
+    }
+
+def _build_path_iterator(path):
+    # parse pattern
+    stream = iter([ (special,text)
+                    for (special,text) in path_tokenizer(path)
+                    if special or text ])
+    try:
+        _next = stream.next
+    except AttributeError:
+        # Python 3
+        def _next():
+            return next(stream)
+    token = _next()
+    selector = []
+    while 1:
+        try:
+            selector.append(operations[token[0]](_next, token))
+        except StopIteration:
+            raise ValueError("invalid path")
+        try:
+            token = _next()
+            if token[0] == "/":
+                token = _next()
+        except StopIteration:
+            break
+    return selector
+
+# main module API
+
+def iterfind(node, path):
+    selector_chain = _build_path_iterator(path)
+    result = iter((node,))
+    for select in selector_chain:
+        result = select(result)
+    return result
+
+def find_first(node, path):
+    return _get_first_or_none(iterfind(node, path))
+
+def find_all(node, path):
+    return list(iterfind(node, path))
index 3c6ad215bfde7d2c233745a6d0802ad6bc221bdf..9de9fdf2ee37abd3b7dfd90b59313d5359dc59c8 100644 (file)
@@ -8,6 +8,7 @@ unsignedbehaviour_T184
 funcexc_iter_T228
 bad_c_struct_T252
 missing_baseclass_in_predecl_T262
+compile_time_unraisable_T370
 
 # Not yet enabled
 profile_test