big C++ mergeback
[cython.git] / Cython / Compiler / Visitor.py
index c6add7cbdc0db9e9f21ee2d9689934d5df51b8f1..441716c2617d856433bbfd1f082d2af1a397ec61 100644 (file)
@@ -1,3 +1,5 @@
+# cython: infer_types=True
+
 #
 #   Tree visitor and transform framework
 #
@@ -7,7 +9,7 @@ import Nodes
 import ExprNodes
 import Naming
 import Errors
-from StringEncoding import EncodedString
+import DebugFlags
 
 class BasicVisitor(object):
     """A generic visitor base class which can be used for visiting any kind of object."""
@@ -18,32 +20,36 @@ class BasicVisitor(object):
         self.dispatch_table = {}
 
     def visit(self, obj):
-        cls = type(obj)
         try:
-            handler_method = self.dispatch_table[cls]
+            handler_method = self.dispatch_table[type(obj)]
         except KeyError:
-            #print "Cache miss for class %s in visitor %s" % (
-            #    cls.__name__, type(self).__name__)
-            # Must resolve, try entire hierarchy
-            pattern = "visit_%s"
-            mro = inspect.getmro(cls)
-            handler_method = None
-            for mro_cls in mro:
-                if hasattr(self, pattern % mro_cls.__name__):
-                    handler_method = getattr(self, pattern % mro_cls.__name__)
-                    break
-            if handler_method is None:
-                print type(self), type(obj)
-                if hasattr(self, 'access_path') and self.access_path:
-                    print self.access_path
-                    if self.access_path:
-                        print self.access_path[-1][0].pos
-                        print self.access_path[-1][0].__dict__
-                raise RuntimeError("Visitor does not accept object: %s" % obj)
-            #print "Caching " + cls.__name__
-            self.dispatch_table[cls] = handler_method
+            handler_method = self.find_handler(obj)
+            self.dispatch_table[type(obj)] = handler_method
         return handler_method(obj)
 
+    def find_handler(self, obj):
+        cls = type(obj)
+        #print "Cache miss for class %s in visitor %s" % (
+        #    cls.__name__, type(self).__name__)
+        # Must resolve, try entire hierarchy
+        pattern = "visit_%s"
+        mro = inspect.getmro(cls)
+        handler_method = None
+        for mro_cls in mro:
+            if hasattr(self, pattern % mro_cls.__name__):
+                handler_method = getattr(self, pattern % mro_cls.__name__)
+                break
+        if handler_method is None:
+            print type(self), cls
+            if hasattr(self, 'access_path') and self.access_path:
+                print self.access_path
+                if self.access_path:
+                    print self.access_path[-1][0].pos
+                    print self.access_path[-1][0].__dict__
+            raise RuntimeError("Visitor does not accept object: %s" % obj)
+        #print "Caching " + cls.__name__
+        return handler_method
+
 class TreeVisitor(BasicVisitor):
     """
     Base class for writing visitors for a Cython tree, contains utilities for
@@ -144,6 +150,29 @@ class TreeVisitor(BasicVisitor):
             stacktrace = stacktrace.tb_next
         return (last_traceback, nodes)
 
+    def _raise_compiler_error(self, child, e):
+        import sys
+        trace = ['']
+        for parent, attribute, index in self.access_path:
+            node = getattr(parent, attribute)
+            if index is None:
+                index = ''
+            else:
+                node = node[index]
+                index = u'[%d]' % index
+            trace.append(u'%s.%s%s = %s' % (
+                parent.__class__.__name__, attribute, index,
+                self.dump_node(node)))
+        stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
+        last_node = child
+        for node, method_name, pos in called_nodes:
+            last_node = node
+            trace.append(u"File '%s', line %d, in %s: %s" % (
+                pos[0], pos[1], method_name, self.dump_node(node)))
+        raise Errors.CompilerCrash(
+            last_node.pos, self.__class__.__name__,
+            u'\n'.join(trace), e, stacktrace)
+
     def visitchild(self, child, parent, attrname, idx):
         self.access_path.append((parent, attrname, idx))
         try:
@@ -151,31 +180,16 @@ class TreeVisitor(BasicVisitor):
         except Errors.CompileError:
             raise
         except Exception, e:
-            import sys
-            trace = ['']
-            for parent, attribute, index in self.access_path:
-                node = getattr(parent, attribute)
-                if index is None:
-                    index = ''
-                else:
-                    node = node[index]
-                    index = u'[%d]' % index
-                trace.append(u'%s.%s%s = %s' % (
-                    parent.__class__.__name__, attribute, index,
-                    self.dump_node(node)))
-            stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
-            last_node = child
-            for node, method_name, pos in called_nodes:
-                last_node = node
-                trace.append(u"File '%s', line %d, in %s: %s" % (
-                    pos[0], pos[1], method_name, self.dump_node(node)))
-            raise Errors.CompilerCrash(
-                last_node.pos, self.__class__.__name__,
-                u'\n'.join(trace), e, stacktrace)
+            if DebugFlags.debug_no_exception_intercept:
+                raise
+            self._raise_compiler_error(child, e)
         self.access_path.pop()
         return result
 
     def visitchildren(self, parent, attrs=None):
+        return self._visitchildren(parent, attrs)
+
+    def _visitchildren(self, parent, attrs):
         """
         Visits the children of the given parent. If parent is None, returns
         immediately (returning None).
@@ -185,14 +199,13 @@ class TreeVisitor(BasicVisitor):
         or a list of return values (in the case of multiple children
         in an attribute)).
         """
-
         if parent is None: return None
         result = {}
         for attr in parent.child_attrs:
             if attrs is not None and attr not in attrs: continue
             child = getattr(parent, attr)
             if child is not None:
-                if isinstance(child, list):
+                if type(child) is list:
                     childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
                 else:
                     childretval = self.visitchild(child, parent, attr, None)
@@ -221,22 +234,17 @@ class VisitorTransform(TreeVisitor):
     was not, an exception will be raised. (Typically you want to ensure that you
     are within a StatListNode or similar before doing this.)
     """
-    def __init__(self):
-        super(VisitorTransform, self).__init__()
-        self._super_visitchildren = super(VisitorTransform, self).visitchildren
-
     def visitchildren(self, parent, attrs=None):
-        result = cython.declare(dict)
-        result = self._super_visitchildren(parent, attrs)
+        result = self._visitchildren(parent, attrs)
         for attr, newnode in result.iteritems():
-            if not isinstance(newnode, list):
+            if not type(newnode) is list:
                 setattr(parent, attr, newnode)
             else:
                 # Flatten the list one level and remove any None
                 newlist = []
                 for x in newnode:
                     if x is not None:
-                        if isinstance(x, list):
+                        if type(x) is list:
                             newlist += x
                         else:
                             newlist.append(x)
@@ -253,6 +261,9 @@ class VisitorTransform(TreeVisitor):
 class CythonTransform(VisitorTransform):
     """
     Certain common conventions and utilitues for Cython transforms.
+
+     - Sets up the context of the pipeline in self.context
+     - Tracks directives in effect in self.current_directives
     """
     def __init__(self, context):
         super(CythonTransform, self).__init__()
@@ -275,6 +286,69 @@ class CythonTransform(VisitorTransform):
         self.visitchildren(node)
         return node
 
+class ScopeTrackingTransform(CythonTransform):
+    # Keeps track of type of scopes
+    scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass'
+    scope_node = None
+    
+    def visit_ModuleNode(self, node):
+        self.scope_type = 'module'
+        self.scope_node = node
+        self.visitchildren(node)
+        return node
+
+    def visit_scope(self, node, scope_type):
+        prev = self.scope_type, self.scope_node
+        self.scope_type = scope_type
+        self.scope_node = node
+        self.visitchildren(node)
+        self.scope_type, self.scope_node = prev
+        return node
+    
+    def visit_CClassDefNode(self, node):
+        return self.visit_scope(node, 'cclass')
+
+    def visit_PyClassDefNode(self, node):
+        return self.visit_scope(node, 'pyclass')
+
+    def visit_FuncDefNode(self, node):
+        return self.visit_scope(node, 'function')
+
+    def visit_CStructOrUnionDefNode(self, node):
+        return self.visit_scope(node, 'struct')
+
+
+class EnvTransform(CythonTransform):
+    """
+    This transformation keeps a stack of the environments. 
+    """
+    def __call__(self, root):
+        self.env_stack = [root.scope]
+        return super(EnvTransform, self).__call__(root)        
+    
+    def visit_FuncDefNode(self, node):
+        self.env_stack.append(node.local_scope)
+        self.visitchildren(node)
+        self.env_stack.pop()
+        return node
+
+
+class RecursiveNodeReplacer(VisitorTransform):
+    """
+    Recursively replace all occurrences of a node in a subtree by
+    another node.
+    """
+    def __init__(self, orig_node, new_node):
+        super(RecursiveNodeReplacer, self).__init__()
+        self.orig_node, self.new_node = orig_node, new_node
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        if node is self.orig_node:
+            return self.new_node
+        else:
+            return node
+