+# cython: infer_types=True
+
#
# Tree visitor and transform framework
#
import ExprNodes
import Naming
import Errors
-from StringEncoding import EncodedString
+import DebugFlags
class BasicVisitor(object):
"""A generic visitor base class which can be used for visiting any kind of object."""
self.dispatch_table = {}
def visit(self, obj):
- cls = type(obj)
try:
- handler_method = self.dispatch_table[cls]
+ handler_method = self.dispatch_table[type(obj)]
except KeyError:
- #print "Cache miss for class %s in visitor %s" % (
- # cls.__name__, type(self).__name__)
- # Must resolve, try entire hierarchy
- pattern = "visit_%s"
- mro = inspect.getmro(cls)
- handler_method = None
- for mro_cls in mro:
- if hasattr(self, pattern % mro_cls.__name__):
- handler_method = getattr(self, pattern % mro_cls.__name__)
- break
- if handler_method is None:
- print type(self), type(obj)
- if hasattr(self, 'access_path') and self.access_path:
- print self.access_path
- if self.access_path:
- print self.access_path[-1][0].pos
- print self.access_path[-1][0].__dict__
- raise RuntimeError("Visitor does not accept object: %s" % obj)
- #print "Caching " + cls.__name__
- self.dispatch_table[cls] = handler_method
+ handler_method = self.find_handler(obj)
+ self.dispatch_table[type(obj)] = handler_method
return handler_method(obj)
+ def find_handler(self, obj):
+ cls = type(obj)
+ #print "Cache miss for class %s in visitor %s" % (
+ # cls.__name__, type(self).__name__)
+ # Must resolve, try entire hierarchy
+ pattern = "visit_%s"
+ mro = inspect.getmro(cls)
+ handler_method = None
+ for mro_cls in mro:
+ if hasattr(self, pattern % mro_cls.__name__):
+ handler_method = getattr(self, pattern % mro_cls.__name__)
+ break
+ if handler_method is None:
+ print type(self), cls
+ if hasattr(self, 'access_path') and self.access_path:
+ print self.access_path
+ if self.access_path:
+ print self.access_path[-1][0].pos
+ print self.access_path[-1][0].__dict__
+ raise RuntimeError("Visitor does not accept object: %s" % obj)
+ #print "Caching " + cls.__name__
+ return handler_method
+
class TreeVisitor(BasicVisitor):
"""
Base class for writing visitors for a Cython tree, contains utilities for
stacktrace = stacktrace.tb_next
return (last_traceback, nodes)
+ def _raise_compiler_error(self, child, e):
+ import sys
+ trace = ['']
+ for parent, attribute, index in self.access_path:
+ node = getattr(parent, attribute)
+ if index is None:
+ index = ''
+ else:
+ node = node[index]
+ index = u'[%d]' % index
+ trace.append(u'%s.%s%s = %s' % (
+ parent.__class__.__name__, attribute, index,
+ self.dump_node(node)))
+ stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
+ last_node = child
+ for node, method_name, pos in called_nodes:
+ last_node = node
+ trace.append(u"File '%s', line %d, in %s: %s" % (
+ pos[0], pos[1], method_name, self.dump_node(node)))
+ raise Errors.CompilerCrash(
+ last_node.pos, self.__class__.__name__,
+ u'\n'.join(trace), e, stacktrace)
+
def visitchild(self, child, parent, attrname, idx):
self.access_path.append((parent, attrname, idx))
try:
except Errors.CompileError:
raise
except Exception, e:
- import sys
- trace = ['']
- for parent, attribute, index in self.access_path:
- node = getattr(parent, attribute)
- if index is None:
- index = ''
- else:
- node = node[index]
- index = u'[%d]' % index
- trace.append(u'%s.%s%s = %s' % (
- parent.__class__.__name__, attribute, index,
- self.dump_node(node)))
- stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2])
- last_node = child
- for node, method_name, pos in called_nodes:
- last_node = node
- trace.append(u"File '%s', line %d, in %s: %s" % (
- pos[0], pos[1], method_name, self.dump_node(node)))
- raise Errors.CompilerCrash(
- last_node.pos, self.__class__.__name__,
- u'\n'.join(trace), e, stacktrace)
+ if DebugFlags.debug_no_exception_intercept:
+ raise
+ self._raise_compiler_error(child, e)
self.access_path.pop()
return result
def visitchildren(self, parent, attrs=None):
+ return self._visitchildren(parent, attrs)
+
+ def _visitchildren(self, parent, attrs):
"""
Visits the children of the given parent. If parent is None, returns
immediately (returning None).
or a list of return values (in the case of multiple children
in an attribute)).
"""
-
if parent is None: return None
result = {}
for attr in parent.child_attrs:
if attrs is not None and attr not in attrs: continue
child = getattr(parent, attr)
if child is not None:
- if isinstance(child, list):
+ if type(child) is list:
childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)]
else:
childretval = self.visitchild(child, parent, attr, None)
was not, an exception will be raised. (Typically you want to ensure that you
are within a StatListNode or similar before doing this.)
"""
- def __init__(self):
- super(VisitorTransform, self).__init__()
- self._super_visitchildren = super(VisitorTransform, self).visitchildren
-
def visitchildren(self, parent, attrs=None):
- result = cython.declare(dict)
- result = self._super_visitchildren(parent, attrs)
+ result = self._visitchildren(parent, attrs)
for attr, newnode in result.iteritems():
- if not isinstance(newnode, list):
+ if not type(newnode) is list:
setattr(parent, attr, newnode)
else:
# Flatten the list one level and remove any None
newlist = []
for x in newnode:
if x is not None:
- if isinstance(x, list):
+ if type(x) is list:
newlist += x
else:
newlist.append(x)
class CythonTransform(VisitorTransform):
"""
Certain common conventions and utilitues for Cython transforms.
+
+ - Sets up the context of the pipeline in self.context
+ - Tracks directives in effect in self.current_directives
"""
def __init__(self, context):
super(CythonTransform, self).__init__()
self.visitchildren(node)
return node
+class ScopeTrackingTransform(CythonTransform):
+ # Keeps track of type of scopes
+ scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass'
+ scope_node = None
+
+ def visit_ModuleNode(self, node):
+ self.scope_type = 'module'
+ self.scope_node = node
+ self.visitchildren(node)
+ return node
+
+ def visit_scope(self, node, scope_type):
+ prev = self.scope_type, self.scope_node
+ self.scope_type = scope_type
+ self.scope_node = node
+ self.visitchildren(node)
+ self.scope_type, self.scope_node = prev
+ return node
+
+ def visit_CClassDefNode(self, node):
+ return self.visit_scope(node, 'cclass')
+
+ def visit_PyClassDefNode(self, node):
+ return self.visit_scope(node, 'pyclass')
+
+ def visit_FuncDefNode(self, node):
+ return self.visit_scope(node, 'function')
+
+ def visit_CStructOrUnionDefNode(self, node):
+ return self.visit_scope(node, 'struct')
+
+
+class EnvTransform(CythonTransform):
+ """
+ This transformation keeps a stack of the environments.
+ """
+ def __call__(self, root):
+ self.env_stack = [root.scope]
+ return super(EnvTransform, self).__call__(root)
+
+ def visit_FuncDefNode(self, node):
+ self.env_stack.append(node.local_scope)
+ self.visitchildren(node)
+ self.env_stack.pop()
+ return node
+
+
+class RecursiveNodeReplacer(VisitorTransform):
+ """
+ Recursively replace all occurrences of a node in a subtree by
+ another node.
+ """
+ def __init__(self, orig_node, new_node):
+ super(RecursiveNodeReplacer, self).__init__()
+ self.orig_node, self.new_node = orig_node, new_node
+
+ def visit_Node(self, node):
+ self.visitchildren(node)
+ if node is self.orig_node:
+ return self.new_node
+ else:
+ return node
+