big C++ mergeback
[cython.git] / Cython / Compiler / Visitor.py
index c6add7cbdc0db9e9f21ee2d9689934d5df51b8f1..441716c2617d856433bbfd1f082d2af1a397ec61 100644 (file)
@@ -1,3 +1,5 @@
+# cython: infer_types=True
+
 #
 #   Tree visitor and transform framework
 #
 #
 #   Tree visitor and transform framework
 #
@@ -7,7 +9,7 @@ import Nodes
 import ExprNodes
 import Naming
 import Errors
 import ExprNodes
 import Naming
 import Errors
-from StringEncoding import EncodedString
+import DebugFlags
 
 class BasicVisitor(object):
     """A generic visitor base class which can be used for visiting any kind of object."""
 
 class BasicVisitor(object):
     """A generic visitor base class which can be used for visiting any kind of object."""
@@ -18,32 +20,36 @@ class BasicVisitor(object):
         self.dispatch_table = {}
 
     def visit(self, obj):
         self.dispatch_table = {}
 
     def visit(self, obj):
-        cls = type(obj)
         try:
         try:
-            handler_method = self.dispatch_table[cls]
+            handler_method = self.dispatch_table[type(obj)]
         except KeyError:
         except KeyError:
-            #print "Cache miss for class %s in visitor %s" % (
-            #    cls.__name__, type(self).__name__)
-            # Must resolve, try entire hierarchy
-            pattern = "visit_%s"
-            mro = inspect.getmro(cls)
-            handler_method = None
-            for mro_cls in mro:
-                if hasattr(self, pattern % mro_cls.__name__):
-                    handler_method = getattr(self, pattern % mro_cls.__name__)
-                    break
-            if handler_method is None:
-                print type(self), type(obj)
-                if hasattr(self, 'access_path') and self.access_path:
-                    print self.access_path
-                    if self.access_path:
-                        print self.access_path[-1][0].pos
-                        print self.access_path[-1][0].__dict__
-                raise RuntimeError("Visitor does not accept object: %s" % obj)
-            #print "Caching " + cls.__name__
-            self.dispatch_table[cls] = handler_method
+            handler_method = self.find_handler(obj)
+            self.dispatch_table[type(obj)] = handler_method
         return handler_method(obj)
 
         return handler_method(obj)
 
+    def find_handler(self, obj):
+        cls = type(obj)
+        #print "Cache miss for class %s in visitor %s" % (
+        #    cls.__name__, type(self).__name__)
+        # Must resolve, try entire hierarchy
+        pattern = "visit_%s"
+        mro = inspect.getmro(cls)
+        handler_method = None
+        for mro_cls in mro:
+            if hasattr(self, pattern % mro_cls.__name__):
+                handler_method = getattr(self, pattern % mro_cls.__name__)
+                break
+        if handler_method is None:
+            print type(self), cls
+            if hasattr(self, 'access_path') and self.access_path:
+                print self.access_path
+                if self.access_path:
+                    print self.access_path[-1][0].pos
+                    print self.access_path[-1][0].__dict__
+            raise RuntimeError("Visitor does not accept object: %s" % obj)
+        #print "Caching " + cls.__name__
+        return handler_method
+
 class TreeVisitor(BasicVisitor):
     """
     Base class for writing visitors for a Cython tree, contains utilities for
 class TreeVisitor(BasicVisitor):
     """
     Base class for writing visitors for a Cython tree, contains utilities for
@@ -144,6 +150,29 @@ class TreeVisitor(BasicVisitor):
             stacktrace = stacktrace.tb_next
         return (last_traceback, nodes)
 
             stacktrace = stacktrace.tb_next
         return (last_traceback, nodes)
 
+    def _raise_compiler_error(self, child, e):
+        import sys
+        trace = ['']
+        for parent, attribute, index in self.access_path:
+            node = getattr(parent, attribute)
+            if index is None:
+                index = ''
+            else:
+                node = node[index]
+                index = u'[%d]' % index
+            trace.append(u'%s.%s%s = %s' % (
+                parent.__class__.__name__, attribute, index,
+                self.dump_node(node)))
+        stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
+        last_node = child
+        for node, method_name, pos in called_nodes:
+            last_node = node
+            trace.append(u"File '%s', line %d, in %s: %s" % (
+                pos[0], pos[1], method_name, self.dump_node(node)))
+        raise Errors.CompilerCrash(
+            last_node.pos, self.__class__.__name__,
+            u'\n'.join(trace), e, stacktrace)
+
     def visitchild(self, child, parent, attrname, idx):
         self.access_path.append((parent, attrname, idx))
         try:
     def visitchild(self, child, parent, attrname, idx):
         self.access_path.append((parent, attrname, idx))
         try:
@@ -151,31 +180,16 @@ class TreeVisitor(BasicVisitor):
         except Errors.CompileError:
             raise
         except Exception, e:
         except Errors.CompileError:
             raise
         except Exception, e:
-            import sys
-            trace = ['']
-            for parent, attribute, index in self.access_path:
-                node = getattr(parent, attribute)
-                if index is None:
-                    index = ''
-                else:
-                    node = node[index]
-                    index = u'[%d]' % index
-                trace.append(u'%s.%s%s = %s' % (
-                    parent.__class__.__name__, attribute, index,
-                    self.dump_node(node)))
-            stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
-            last_node = child
-            for node, method_name, pos in called_nodes:
-                last_node = node
-                trace.append(u"File '%s', line %d, in %s: %s" % (
-                    pos[0], pos[1], method_name, self.dump_node(node)))
-            raise Errors.CompilerCrash(
-                last_node.pos, self.__class__.__name__,
-                u'\n'.join(trace), e, stacktrace)
+            if DebugFlags.debug_no_exception_intercept:
+                raise
+            self._raise_compiler_error(child, e)
         self.access_path.pop()
         return result
 
     def visitchildren(self, parent, attrs=None):
         self.access_path.pop()
         return result
 
     def visitchildren(self, parent, attrs=None):
+        return self._visitchildren(parent, attrs)
+
+    def _visitchildren(self, parent, attrs):
         """
         Visits the children of the given parent. If parent is None, returns
         immediately (returning None).
         """
         Visits the children of the given parent. If parent is None, returns
         immediately (returning None).
@@ -185,14 +199,13 @@ class TreeVisitor(BasicVisitor):
         or a list of return values (in the case of multiple children
         in an attribute)).
         """
         or a list of return values (in the case of multiple children
         in an attribute)).
         """
-
         if parent is None: return None
         result = {}
         for attr in parent.child_attrs:
             if attrs is not None and attr not in attrs: continue
             child = getattr(parent, attr)
             if child is not None:
         if parent is None: return None
         result = {}
         for attr in parent.child_attrs:
             if attrs is not None and attr not in attrs: continue
             child = getattr(parent, attr)
             if child is not None:
-                if isinstance(child, list):
+                if type(child) is list:
                     childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
                 else:
                     childretval = self.visitchild(child, parent, attr, None)
                     childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
                 else:
                     childretval = self.visitchild(child, parent, attr, None)
@@ -221,22 +234,17 @@ class VisitorTransform(TreeVisitor):
     was not, an exception will be raised. (Typically you want to ensure that you
     are within a StatListNode or similar before doing this.)
     """
     was not, an exception will be raised. (Typically you want to ensure that you
     are within a StatListNode or similar before doing this.)
     """
-    def __init__(self):
-        super(VisitorTransform, self).__init__()
-        self._super_visitchildren = super(VisitorTransform, self).visitchildren
-
     def visitchildren(self, parent, attrs=None):
     def visitchildren(self, parent, attrs=None):
-        result = cython.declare(dict)
-        result = self._super_visitchildren(parent, attrs)
+        result = self._visitchildren(parent, attrs)
         for attr, newnode in result.iteritems():
         for attr, newnode in result.iteritems():
-            if not isinstance(newnode, list):
+            if not type(newnode) is list:
                 setattr(parent, attr, newnode)
             else:
                 # Flatten the list one level and remove any None
                 newlist = []
                 for x in newnode:
                     if x is not None:
                 setattr(parent, attr, newnode)
             else:
                 # Flatten the list one level and remove any None
                 newlist = []
                 for x in newnode:
                     if x is not None:
-                        if isinstance(x, list):
+                        if type(x) is list:
                             newlist += x
                         else:
                             newlist.append(x)
                             newlist += x
                         else:
                             newlist.append(x)
@@ -253,6 +261,9 @@ class VisitorTransform(TreeVisitor):
 class CythonTransform(VisitorTransform):
     """
     Certain common conventions and utilitues for Cython transforms.
 class CythonTransform(VisitorTransform):
     """
     Certain common conventions and utilitues for Cython transforms.
+
+     - Sets up the context of the pipeline in self.context
+     - Tracks directives in effect in self.current_directives
     """
     def __init__(self, context):
         super(CythonTransform, self).__init__()
     """
     def __init__(self, context):
         super(CythonTransform, self).__init__()
@@ -275,6 +286,69 @@ class CythonTransform(VisitorTransform):
         self.visitchildren(node)
         return node
 
         self.visitchildren(node)
         return node
 
+class ScopeTrackingTransform(CythonTransform):
+    # Keeps track of type of scopes
+    scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass'
+    scope_node = None
+    
+    def visit_ModuleNode(self, node):
+        self.scope_type = 'module'
+        self.scope_node = node
+        self.visitchildren(node)
+        return node
+
+    def visit_scope(self, node, scope_type):
+        prev = self.scope_type, self.scope_node
+        self.scope_type = scope_type
+        self.scope_node = node
+        self.visitchildren(node)
+        self.scope_type, self.scope_node = prev
+        return node
+    
+    def visit_CClassDefNode(self, node):
+        return self.visit_scope(node, 'cclass')
+
+    def visit_PyClassDefNode(self, node):
+        return self.visit_scope(node, 'pyclass')
+
+    def visit_FuncDefNode(self, node):
+        return self.visit_scope(node, 'function')
+
+    def visit_CStructOrUnionDefNode(self, node):
+        return self.visit_scope(node, 'struct')
+
+
+class EnvTransform(CythonTransform):
+    """
+    This transformation keeps a stack of the environments. 
+    """
+    def __call__(self, root):
+        self.env_stack = [root.scope]
+        return super(EnvTransform, self).__call__(root)        
+    
+    def visit_FuncDefNode(self, node):
+        self.env_stack.append(node.local_scope)
+        self.visitchildren(node)
+        self.env_stack.pop()
+        return node
+
+
+class RecursiveNodeReplacer(VisitorTransform):
+    """
+    Recursively replace all occurrences of a node in a subtree by
+    another node.
+    """
+    def __init__(self, orig_node, new_node):
+        super(RecursiveNodeReplacer, self).__init__()
+        self.orig_node, self.new_node = orig_node, new_node
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        if node is self.orig_node:
+            return self.new_node
+        else:
+            return node
+