From: Stefan Behnel Date: Fri, 18 Sep 2009 06:02:46 +0000 (+0200) Subject: TreePath implementation for selecting nodes from the code tree X-Git-Tag: 0.12.alpha0~194 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=2787ac2a48e4639efe7f027b89b96925a95b1349;p=cython.git TreePath implementation for selecting nodes from the code tree --- diff --git a/Cython/Compiler/Builtin.py b/Cython/Compiler/Builtin.py index 53ce01b1..923cfc4a 100644 --- a/Cython/Compiler/Builtin.py +++ b/Cython/Compiler/Builtin.py @@ -21,7 +21,7 @@ builtin_function_table = [ #('eval', "", "", ""), #('execfile', "", "", ""), #('filter', "", "", ""), - #('getattr', "OO", "O", "PyObject_GetAttr"), # optimised later on + #('getattr', "OO", "O", "PyObject_GetAttr"), # optimised later on ('getattr3', "OOO", "O", "__Pyx_GetAttr3", "getattr"), ('hasattr', "OO", "b", "PyObject_HasAttr"), ('hash', "O", "l", "PyObject_Hash"), @@ -29,7 +29,7 @@ builtin_function_table = [ #('id', "", "", ""), #('input', "", "", ""), ('intern', "s", "O", "__Pyx_InternFromString"), - ('isinstance', "OO", "b", "PyObject_IsInstance"), + #('isinstance', "OO", "b", "PyObject_IsInstance"), # optimised later on ('issubclass', "OO", "b", "PyObject_IsSubclass"), ('iter', "O", "O", "PyObject_GetIter"), ('len', "O", "Z", "PyObject_Length"), diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 425c7f6e..fe0a7329 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -712,6 +712,41 @@ class OptimizeBuiltinCalls(Visitor.VisitorTransform): "expected 2 or 3, found %d" % len(args)) return node + PyObject_TypeCheck_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, [ + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("type", PyrexTypes.c_py_type_object_ptr_type, None), + ]) + + PyObject_IsInstance_func_type = PyrexTypes.CFuncType( + PyrexTypes.c_bint_type, [ + PyrexTypes.CFuncTypeArg("obj", PyrexTypes.py_object_type, None), + PyrexTypes.CFuncTypeArg("type", PyrexTypes.py_object_type, None), + ]) + + def _handle_simple_function_isinstance(self, node, pos_args): + """Replace generic calls to isinstance(x, type) by a more + efficient type check. + """ + args = pos_args.args + if len(args) != 2: + error(node.pos, "isinstance(x, type) called with wrong number of args, found %d" % + len(args)) + return node + + type_arg = args[1] + if type_arg.type is Builtin.type_type: + function_name = "PyObject_TypeCheck" + function_type = self.PyObject_TypeCheck_func_type + args[1] = ExprNodes.CastNode(type_arg, PyrexTypes.c_py_type_object_ptr_type) + else: + function_name = "PyObject_IsInstance" + function_type = self.PyObject_IsInstance_func_type + + return ExprNodes.PythonCapiCallNode( + node.pos, function_name, function_type, + args = args, is_temp = node.is_temp) + Pyx_Type_func_type = PyrexTypes.CFuncType( Builtin.type_type, [ PyrexTypes.CFuncTypeArg("object", PyrexTypes.py_object_type, None) @@ -1058,8 +1093,8 @@ class FinalOptimizePhase(Visitor.CythonTransform): just before the C code generation phase. The optimizations currently implemented in this class are: - - Eliminate None assignment and refcounting for first assignment. - - isinstance -> typecheck for cdef types + - Eliminate None assignment and refcounting for first assignment. + - Eliminate dead coercion nodes. """ def visit_SingleAssignmentNode(self, node): """Avoid redundant initialisation of local variables before their @@ -1075,18 +1110,23 @@ class FinalOptimizePhase(Visitor.CythonTransform): lhs.entry.init = 0 return node - def visit_SimpleCallNode(self, node): - """Replace generic calls to isinstance(x, type) by a more efficient - type check. + def visit_NoneCheckNode(self, node): + """Remove NoneCheckNode nodes wrapping nodes that cannot + possibly be None. + + FIXME: the list below might be better maintained as a node + class attribute... """ - self.visitchildren(node) - if node.function.type.is_cfunction and isinstance(node.function, ExprNodes.NameNode): - if node.function.name == 'isinstance': - type_arg = node.args[1] - if type_arg.type.is_builtin_type and type_arg.type.name == 'type': - from CythonScope import utility_scope - node.function.entry = utility_scope.lookup('PyObject_TypeCheck') - node.function.type = node.function.entry.type - PyTypeObjectPtr = PyrexTypes.CPtrType(utility_scope.lookup('PyTypeObject').type) - node.args[1] = ExprNodes.CastNode(node.args[1], PyTypeObjectPtr) + target = node.arg + if isinstance(target, ExprNodes.NoneNode): + return node + if not target.type.is_pyobject: + return target + if isinstance(target, (ExprNodes.ConstNode, + ExprNodes.NumBinopNode)): + return target + if isinstance(target, (ExprNodes.SequenceNode, + ExprNodes.ComprehensionNode, + ExprNodes.SetNode, ExprNodes.DictNode)): + return target return node diff --git a/Cython/Compiler/PyrexTypes.py b/Cython/Compiler/PyrexTypes.py index 17d886c5..46f21ab6 100644 --- a/Cython/Compiler/PyrexTypes.py +++ b/Cython/Compiler/PyrexTypes.py @@ -1691,6 +1691,9 @@ c_anon_enum_type = CAnonEnumType(-1, 1) c_py_buffer_type = CStructOrUnionType("Py_buffer", "struct", None, 1, "Py_buffer") c_py_buffer_ptr_type = CPtrType(c_py_buffer_type) +c_py_type_object_type = CStructOrUnionType("PyTypeObject", "struct", None, 1, "PyTypeObject") +c_py_type_object_ptr_type = CPtrType(c_py_type_object_type) + error_type = ErrorType() unspecified_type = UnspecifiedType() diff --git a/Cython/Compiler/Tests/TestTreePath.py b/Cython/Compiler/Tests/TestTreePath.py new file mode 100644 index 00000000..7014a9ad --- /dev/null +++ b/Cython/Compiler/Tests/TestTreePath.py @@ -0,0 +1,57 @@ +import unittest +from Cython.Compiler.Visitor import PrintTree +from Cython.TestUtils import TransformTest +from Cython.Compiler.TreePath import find_first, find_all + +class TestTreePath(TransformTest): + + def test_node_path(self): + t = self.run_pipeline([], u""" + def decorator(fun): # DefNode + return fun # ReturnStatNode, NameNode + @decorator # NameNode + def decorated(): # DefNode + pass + """) + + self.assertEquals(2, len(find_all(t, "//DefNode"))) + self.assertEquals(2, len(find_all(t, "//NameNode"))) + self.assertEquals(1, len(find_all(t, "//ReturnStatNode"))) + self.assertEquals(1, len(find_all(t, "//DefNode//ReturnStatNode"))) + + def test_node_path_child(self): + t = self.run_pipeline([], u""" + def decorator(fun): # DefNode + return fun # ReturnStatNode, NameNode + @decorator # NameNode + def decorated(): # DefNode + pass + """) + + self.assertEquals(1, len(find_all(t, "//DefNode/ReturnStatNode/NameNode"))) + self.assertEquals(1, len(find_all(t, "//ReturnStatNode/NameNode"))) + + def test_node_path_attribute_exists(self): + t = self.run_pipeline([], u""" + def decorator(fun): + return fun + @decorator + def decorated(): + pass + """) + + self.assertEquals(2, len(find_all(t, "//NameNode[@name]"))) + + def test_node_path_attribute_string_predicate(self): + t = self.run_pipeline([], u""" + def decorator(fun): + return fun + @decorator + def decorated(): + pass + """) + + self.assertEquals(1, len(find_all(t, "//NameNode[@name = 'decorator']"))) + +if __name__ == '__main__': + unittest.main() diff --git a/Cython/Compiler/TreePath.py b/Cython/Compiler/TreePath.py new file mode 100644 index 00000000..74aaeba7 --- /dev/null +++ b/Cython/Compiler/TreePath.py @@ -0,0 +1,250 @@ +""" +A simple XPath-like language for tree traversal. + +This works by creating a filter chain of generator functions. Each +function selects a part of the expression, e.g. a child node, a +specific descendant or a node that holds an attribute. +""" + +import re + +path_tokenizer = re.compile( + "(" + "'[^']*'|\"[^\"]*\"|" + "//?|" + "\(\)|" + "==?|" + "[/.*\[\]\(\)@])|" + "([^/\[\]\(\)@=\s]+)|" + "\s+" + ).findall + +def iterchildren(node, attr_name): + # returns an iterable of all child nodes of that name + child = getattr(node, attr_name) + if child is not None: + if type(child) is list: + return child + else: + return [child] + else: + return () + +def _get_first_or_none(it): + try: + try: + _next = it.next + except AttributeError: + return next(it) + else: + return _next() + except StopIteration: + return None + +def type_name(node): + return node.__class__.__name__.split('.')[-1] + +def parse_func(next, token): + name = token[1] + token = next() + if token[0] != '(': + raise ValueError("Expected '(' after function name '%s'" % name) + predicate = handle_predicate(next, token, end_marker=')') + return name, predicate + +def handle_func_not(next, token): + """ + func(...) + """ + name, predicate = parse_func(next, token) + + def select(result): + for node in result: + if _get_first_or_none(predicate(node)) is not None: + yield node + return select + +def handle_name(next, token): + """ + /NodeName/ + or + func(...) + """ + name = token[1] + if name in functions: + return functions[name](next, token) + def select(result): + for node in result: + for attr_name in node.child_attrs: + for child in iterchildren(node, attr_name): + if type_name(child) == name: + yield child + return select + +def handle_star(next, token): + """ + /*/ + """ + def select(result): + for node in result: + for name in node.child_attrs: + for child in iterchildren(node, name): + yield child + return select + +def handle_dot(next, token): + """ + /./ + """ + def select(result): + return result + return select + +def handle_descendants(next, token): + """ + //... + """ + token = next() + if token[0] == "*": + def iter_recursive(node): + for name in node.child_attrs: + for child in iterchildren(node, name): + yield child + for c in iter_recursive(child): + yield c + elif not token[0]: + node_name = token[1] + def iter_recursive(node): + for name in node.child_attrs: + for child in iterchildren(node, name): + if type_name(child) == node_name: + yield child + for c in iter_recursive(child): + yield c + else: + raise ValueError("Expected node name after '//'") + + def select(result): + for node in result: + for child in iter_recursive(node): + yield child + + return select + +def handle_attribute(next, token): + token = next() + if token[0]: + raise ValueError("Expected attribute name") + name = token[1] + token = next() + value = None + if token[0] == '=': + value = parse_path_value(next) + if value is None: + def select(result): + for node in result: + try: + attr_value = getattr(node, name) + except AttributeError: + continue + if attr_value is not None: + yield attr_value + else: + def select(result): + for node in result: + try: + attr_value = getattr(node, name) + except AttributeError: + continue + if attr_value == value: + yield value + return select + +def parse_path_value(next): + token = next() + value = token[0] + if value[:1] == "'" or value[:1] == '"': + value = value[1:-1] + else: + try: + value = int(value) + except ValueError: + raise ValueError("Invalid attribute predicate: '%s'" % value) + return value + +def handle_predicate(next, token, end_marker=']'): + token = next() + selector = [] + while token[0] != end_marker: + selector.append( operations[token[0]](next, token) ) + try: + token = next() + except StopIteration: + break + else: + if token[0] == "/": + token = next() + + def select(result): + for node in result: + subresult = iter((node,)) + for select in selector: + subresult = select(subresult) + predicate_result = _get_first_or_none(subresult) + if predicate_result is not None: + yield predicate_result + return select + +operations = { + "@": handle_attribute, + "": handle_name, + "*": handle_star, + ".": handle_dot, + "//": handle_descendants, + "[": handle_predicate, + } + +functions = { + 'not' : handle_func_not + } + +def _build_path_iterator(path): + # parse pattern + stream = iter([ (special,text) + for (special,text) in path_tokenizer(path) + if special or text ]) + try: + _next = stream.next + except AttributeError: + # Python 3 + def _next(): + return next(stream) + token = _next() + selector = [] + while 1: + try: + selector.append(operations[token[0]](_next, token)) + except StopIteration: + raise ValueError("invalid path") + try: + token = _next() + if token[0] == "/": + token = _next() + except StopIteration: + break + return selector + +# main module API + +def iterfind(node, path): + selector_chain = _build_path_iterator(path) + result = iter((node,)) + for select in selector_chain: + result = select(result) + return result + +def find_first(node, path): + return _get_first_or_none(iterfind(node, path)) + +def find_all(node, path): + return list(iterfind(node, path)) diff --git a/tests/bugs.txt b/tests/bugs.txt index 3c6ad215..9de9fdf2 100644 --- a/tests/bugs.txt +++ b/tests/bugs.txt @@ -8,6 +8,7 @@ unsignedbehaviour_T184 funcexc_iter_T228 bad_c_struct_T252 missing_baseclass_in_predecl_T262 +compile_time_unraisable_T370 # Not yet enabled profile_test