extended TreePath test case, fix predicate evaluation
authorStefan Behnel <scoder@users.berlios.de>
Sun, 4 Oct 2009 16:56:10 +0000 (18:56 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 4 Oct 2009 16:56:10 +0000 (18:56 +0200)
Cython/Compiler/Tests/TestTreePath.py
Cython/Compiler/TreePath.py

index 8c6e7fe5072d9c19ab0728a3d5430150e5adab0b..1ae0d374a89e2e9461f677fe8e18ab063344a79f 100644 (file)
@@ -2,6 +2,7 @@ import unittest
 from Cython.Compiler.Visitor import PrintTree
 from Cython.TestUtils import TransformTest
 from Cython.Compiler.TreePath import find_first, find_all
+from Cython.Compiler import Nodes, ExprNodes
 
 class TestTreePath(TransformTest):
     _tree = None
@@ -24,6 +25,12 @@ class TestTreePath(TransformTest):
         self.assertEquals(1, len(find_all(t, "//ReturnStatNode")))
         self.assertEquals(1, len(find_all(t, "//DefNode//ReturnStatNode")))
 
+    def test_node_path_star(self):
+        t = self._build_tree()
+        self.assertEquals(10, len(find_all(t, "//*")))
+        self.assertEquals(8, len(find_all(t, "//DefNode//*")))
+        self.assertEquals(0, len(find_all(t, "//NameNode//*")))
+
     def test_node_path_attribute(self):
         t = self._build_tree()
         self.assertEquals(2, len(find_all(t, "//NameNode/@name")))
@@ -34,9 +41,27 @@ class TestTreePath(TransformTest):
         self.assertEquals(1, len(find_all(t, "//DefNode/ReturnStatNode/NameNode")))
         self.assertEquals(1, len(find_all(t, "//ReturnStatNode/NameNode")))
 
+    def test_node_path_node_predicate(self):
+        t = self._build_tree()
+        self.assertEquals(0, len(find_all(t, "//DefNode[.//ForInStatNode]")))
+        self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode]")))
+        self.assertEquals(1, len(find_all(t, "//ReturnStatNode[./NameNode]")))
+        self.assertEquals(Nodes.ReturnStatNode,
+                          type(find_first(t, "//ReturnStatNode[./NameNode]")))
+
+    def test_node_path_node_predicate_step(self):
+        t = self._build_tree()
+        self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode]")))
+        self.assertEquals(8, len(find_all(t, "//DefNode[.//NameNode]//*")))
+        self.assertEquals(1, len(find_all(t, "//DefNode[.//NameNode]//ReturnStatNode")))
+        self.assertEquals(Nodes.ReturnStatNode,
+                          type(find_first(t, "//DefNode[.//NameNode]//ReturnStatNode")))
+
     def test_node_path_attribute_exists(self):
         t = self._build_tree()
         self.assertEquals(2, len(find_all(t, "//NameNode[@name]")))
+        self.assertEquals(ExprNodes.NameNode,
+                          type(find_first(t, "//NameNode[@name]")))
 
     def test_node_path_attribute_exists_not(self):
         t = self._build_tree()
@@ -47,5 +72,11 @@ class TestTreePath(TransformTest):
         t = self._build_tree()
         self.assertEquals(1, len(find_all(t, "//NameNode[@name = 'decorator']")))
 
+    def test_node_path_recursive_predicate(self):
+        t = self._build_tree()
+        self.assertEquals(2, len(find_all(t, "//DefNode[.//NameNode[@name]]")))
+        self.assertEquals(1, len(find_all(t, "//DefNode[.//NameNode[@name = 'decorator']]")))
+        self.assertEquals(1, len(find_all(t, "//DefNode[.//ReturnStatNode[./NameNode[@name = 'fun']]/NameNode]")))
+
 if __name__ == '__main__':
     unittest.main()
index 30612a50605b39d8bcaf07718502c62847392a7d..432c8652ea1b3da357852075e2093dee35a5b4ad 100644 (file)
@@ -54,7 +54,7 @@ def parse_func(next, token):
 
 def handle_func_not(next, token):
     """
-    func(...)
+    not(...)
     """
     name, predicate = parse_func(next, token)
 
@@ -196,7 +196,7 @@ def handle_predicate(next, token):
                 subresult = select(subresult)
             predicate_result = _get_first_or_none(subresult)
             if predicate_result is not None:
-                yield predicate_result
+                yield node
     return select
 
 operations = {