speed up tree visitor somewhat by moving code out of the critical methods
authorStefan Behnel <scoder@users.berlios.de>
Wed, 9 Dec 2009 08:31:18 +0000 (09:31 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Wed, 9 Dec 2009 08:31:18 +0000 (09:31 +0100)
Cython/Compiler/Visitor.pxd
Cython/Compiler/Visitor.py

index 5a0edb39cd246d2343b6ba6058823a22cbf7db0e..160d2595763c4fb056065e57ed0d99468112ccd6 100644 (file)
@@ -1,10 +1,15 @@
+cimport cython
+
 cdef class BasicVisitor:
     cdef dict dispatch_table
     cpdef visit(self, obj)
+    cpdef find_handler(self, obj)
 
 cdef class TreeVisitor(BasicVisitor):
     cdef public list access_path
     cpdef visitchild(self, child, parent, attrname, idx)
+    @cython.locals(idx=int)
+    cpdef dict _visitchildren(self, parent, attrs)
 #    cpdef visitchildren(self, parent, attrs=*)
 
 cdef class VisitorTransform(TreeVisitor):
index d06149bd5984ca5450de9d415081b7240ca9e5ac..ee18a52ac5c7ac308b264bf2899fc44546933300 100644 (file)
@@ -1,3 +1,5 @@
+# cython: infer_types=True
+
 #
 #   Tree visitor and transform framework
 #
@@ -19,32 +21,36 @@ class BasicVisitor(object):
         self.dispatch_table = {}
 
     def visit(self, obj):
-        cls = type(obj)
         try:
-            handler_method = self.dispatch_table[cls]
+            handler_method = self.dispatch_table[type(obj)]
         except KeyError:
-            #print "Cache miss for class %s in visitor %s" % (
-            #    cls.__name__, type(self).__name__)
-            # Must resolve, try entire hierarchy
-            pattern = "visit_%s"
-            mro = inspect.getmro(cls)
-            handler_method = None
-            for mro_cls in mro:
-                if hasattr(self, pattern % mro_cls.__name__):
-                    handler_method = getattr(self, pattern % mro_cls.__name__)
-                    break
-            if handler_method is None:
-                print type(self), type(obj)
-                if hasattr(self, 'access_path') and self.access_path:
-                    print self.access_path
-                    if self.access_path:
-                        print self.access_path[-1][0].pos
-                        print self.access_path[-1][0].__dict__
-                raise RuntimeError("Visitor does not accept object: %s" % obj)
-            #print "Caching " + cls.__name__
-            self.dispatch_table[cls] = handler_method
+            handler_method = self.find_handler(obj)
+            self.dispatch_table[type(obj)] = handler_method
         return handler_method(obj)
 
+    def find_handler(self, obj):
+        cls = type(obj)
+        #print "Cache miss for class %s in visitor %s" % (
+        #    cls.__name__, type(self).__name__)
+        # Must resolve, try entire hierarchy
+        pattern = "visit_%s"
+        mro = inspect.getmro(cls)
+        handler_method = None
+        for mro_cls in mro:
+            if hasattr(self, pattern % mro_cls.__name__):
+                handler_method = getattr(self, pattern % mro_cls.__name__)
+                break
+        if handler_method is None:
+            print type(self), cls
+            if hasattr(self, 'access_path') and self.access_path:
+                print self.access_path
+                if self.access_path:
+                    print self.access_path[-1][0].pos
+                    print self.access_path[-1][0].__dict__
+            raise RuntimeError("Visitor does not accept object: %s" % obj)
+        #print "Caching " + cls.__name__
+        return handler_method
+
 class TreeVisitor(BasicVisitor):
     """
     Base class for writing visitors for a Cython tree, contains utilities for
@@ -144,6 +150,29 @@ class TreeVisitor(BasicVisitor):
             stacktrace = stacktrace.tb_next
         return (last_traceback, nodes)
 
+    def _raise_compiler_error(self, child, e):
+        import sys
+        trace = ['']
+        for parent, attribute, index in self.access_path:
+            node = getattr(parent, attribute)
+            if index is None:
+                index = ''
+            else:
+                node = node[index]
+                index = u'[%d]' % index
+            trace.append(u'%s.%s%s = %s' % (
+                parent.__class__.__name__, attribute, index,
+                self.dump_node(node)))
+        stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
+        last_node = child
+        for node, method_name, pos in called_nodes:
+            last_node = node
+            trace.append(u"File '%s', line %d, in %s: %s" % (
+                pos[0], pos[1], method_name, self.dump_node(node)))
+        raise Errors.CompilerCrash(
+            last_node.pos, self.__class__.__name__,
+            u'\n'.join(trace), e, stacktrace)
+
     def visitchild(self, child, parent, attrname, idx):
         self.access_path.append((parent, attrname, idx))
         try:
@@ -151,33 +180,16 @@ class TreeVisitor(BasicVisitor):
         except Errors.CompileError:
             raise
         except Exception, e:
-            import sys
             if DebugFlags.debug_no_exception_intercept:
                 raise
-            trace = ['']
-            for parent, attribute, index in self.access_path:
-                node = getattr(parent, attribute)
-                if index is None:
-                    index = ''
-                else:
-                    node = node[index]
-                    index = u'[%d]' % index
-                trace.append(u'%s.%s%s = %s' % (
-                    parent.__class__.__name__, attribute, index,
-                    self.dump_node(node)))
-            stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
-            last_node = child
-            for node, method_name, pos in called_nodes:
-                last_node = node
-                trace.append(u"File '%s', line %d, in %s: %s" % (
-                    pos[0], pos[1], method_name, self.dump_node(node)))
-            raise Errors.CompilerCrash(
-                last_node.pos, self.__class__.__name__,
-                u'\n'.join(trace), e, stacktrace)
+            self._raise_compiler_error(child, e)
         self.access_path.pop()
         return result
 
     def visitchildren(self, parent, attrs=None):
+        return self._visitchildren(parent, attrs)
+
+    def _visitchildren(self, parent, attrs):
         """
         Visits the children of the given parent. If parent is None, returns
         immediately (returning None).
@@ -223,8 +235,7 @@ class VisitorTransform(TreeVisitor):
     are within a StatListNode or similar before doing this.)
     """
     def visitchildren(self, parent, attrs=None):
-        result = cython.declare(dict)
-        result = TreeVisitor.visitchildren(self, parent, attrs)
+        result = self._visitchildren(parent, attrs)
         for attr, newnode in result.iteritems():
             if not type(newnode) is list:
                 setattr(parent, attr, newnode)