X-Git-Url: http://git.tremily.us/?a=blobdiff_plain;f=Cython%2FCompiler%2FVisitor.py;h=441716c2617d856433bbfd1f082d2af1a397ec61;hb=73da9b353950aa68b4e059ca86ab58076af2103d;hp=c6add7cbdc0db9e9f21ee2d9689934d5df51b8f1;hpb=4df01a09a927c422fa102d2462a4a812f617f558;p=cython.git diff --git a/Cython/Compiler/Visitor.py b/Cython/Compiler/Visitor.py index c6add7cb..441716c2 100644 --- a/Cython/Compiler/Visitor.py +++ b/Cython/Compiler/Visitor.py @@ -1,3 +1,5 @@ +# cython: infer_types=True + # # Tree visitor and transform framework # @@ -7,7 +9,7 @@ import Nodes import ExprNodes import Naming import Errors -from StringEncoding import EncodedString +import DebugFlags class BasicVisitor(object): """A generic visitor base class which can be used for visiting any kind of object.""" @@ -18,32 +20,36 @@ class BasicVisitor(object): self.dispatch_table = {} def visit(self, obj): - cls = type(obj) try: - handler_method = self.dispatch_table[cls] + handler_method = self.dispatch_table[type(obj)] except KeyError: - #print "Cache miss for class %s in visitor %s" % ( - # cls.__name__, type(self).__name__) - # Must resolve, try entire hierarchy - pattern = "visit_%s" - mro = inspect.getmro(cls) - handler_method = None - for mro_cls in mro: - if hasattr(self, pattern % mro_cls.__name__): - handler_method = getattr(self, pattern % mro_cls.__name__) - break - if handler_method is None: - print type(self), type(obj) - if hasattr(self, 'access_path') and self.access_path: - print self.access_path - if self.access_path: - print self.access_path[-1][0].pos - print self.access_path[-1][0].__dict__ - raise RuntimeError("Visitor does not accept object: %s" % obj) - #print "Caching " + cls.__name__ - self.dispatch_table[cls] = handler_method + handler_method = self.find_handler(obj) + self.dispatch_table[type(obj)] = handler_method return handler_method(obj) + def find_handler(self, obj): + cls = type(obj) + #print "Cache miss for class %s in visitor %s" % ( + # cls.__name__, type(self).__name__) + # Must resolve, try entire hierarchy + pattern = "visit_%s" + mro = inspect.getmro(cls) + handler_method = None + for mro_cls in mro: + if hasattr(self, pattern % mro_cls.__name__): + handler_method = getattr(self, pattern % mro_cls.__name__) + break + if handler_method is None: + print type(self), cls + if hasattr(self, 'access_path') and self.access_path: + print self.access_path + if self.access_path: + print self.access_path[-1][0].pos + print self.access_path[-1][0].__dict__ + raise RuntimeError("Visitor does not accept object: %s" % obj) + #print "Caching " + cls.__name__ + return handler_method + class TreeVisitor(BasicVisitor): """ Base class for writing visitors for a Cython tree, contains utilities for @@ -144,6 +150,29 @@ class TreeVisitor(BasicVisitor): stacktrace = stacktrace.tb_next return (last_traceback, nodes) + def _raise_compiler_error(self, child, e): + import sys + trace = [''] + for parent, attribute, index in self.access_path: + node = getattr(parent, attribute) + if index is None: + index = '' + else: + node = node[index] + index = u'[%d]' % index + trace.append(u'%s.%s%s = %s' % ( + parent.__class__.__name__, attribute, index, + self.dump_node(node))) + stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2]) + last_node = child + for node, method_name, pos in called_nodes: + last_node = node + trace.append(u"File '%s', line %d, in %s: %s" % ( + pos[0], pos[1], method_name, self.dump_node(node))) + raise Errors.CompilerCrash( + last_node.pos, self.__class__.__name__, + u'\n'.join(trace), e, stacktrace) + def visitchild(self, child, parent, attrname, idx): self.access_path.append((parent, attrname, idx)) try: @@ -151,31 +180,16 @@ class TreeVisitor(BasicVisitor): except Errors.CompileError: raise except Exception, e: - import sys - trace = [''] - for parent, attribute, index in self.access_path: - node = getattr(parent, attribute) - if index is None: - index = '' - else: - node = node[index] - index = u'[%d]' % index - trace.append(u'%s.%s%s = %s' % ( - parent.__class__.__name__, attribute, index, - self.dump_node(node))) - stacktrace, called_nodes = self._find_node_path(sys.exc_info()[2]) - last_node = child - for node, method_name, pos in called_nodes: - last_node = node - trace.append(u"File '%s', line %d, in %s: %s" % ( - pos[0], pos[1], method_name, self.dump_node(node))) - raise Errors.CompilerCrash( - last_node.pos, self.__class__.__name__, - u'\n'.join(trace), e, stacktrace) + if DebugFlags.debug_no_exception_intercept: + raise + self._raise_compiler_error(child, e) self.access_path.pop() return result def visitchildren(self, parent, attrs=None): + return self._visitchildren(parent, attrs) + + def _visitchildren(self, parent, attrs): """ Visits the children of the given parent. If parent is None, returns immediately (returning None). @@ -185,14 +199,13 @@ class TreeVisitor(BasicVisitor): or a list of return values (in the case of multiple children in an attribute)). """ - if parent is None: return None result = {} for attr in parent.child_attrs: if attrs is not None and attr not in attrs: continue child = getattr(parent, attr) if child is not None: - if isinstance(child, list): + if type(child) is list: childretval = [self.visitchild(x, parent, attr, idx) for idx, x in enumerate(child)] else: childretval = self.visitchild(child, parent, attr, None) @@ -221,22 +234,17 @@ class VisitorTransform(TreeVisitor): was not, an exception will be raised. (Typically you want to ensure that you are within a StatListNode or similar before doing this.) """ - def __init__(self): - super(VisitorTransform, self).__init__() - self._super_visitchildren = super(VisitorTransform, self).visitchildren - def visitchildren(self, parent, attrs=None): - result = cython.declare(dict) - result = self._super_visitchildren(parent, attrs) + result = self._visitchildren(parent, attrs) for attr, newnode in result.iteritems(): - if not isinstance(newnode, list): + if not type(newnode) is list: setattr(parent, attr, newnode) else: # Flatten the list one level and remove any None newlist = [] for x in newnode: if x is not None: - if isinstance(x, list): + if type(x) is list: newlist += x else: newlist.append(x) @@ -253,6 +261,9 @@ class VisitorTransform(TreeVisitor): class CythonTransform(VisitorTransform): """ Certain common conventions and utilitues for Cython transforms. + + - Sets up the context of the pipeline in self.context + - Tracks directives in effect in self.current_directives """ def __init__(self, context): super(CythonTransform, self).__init__() @@ -275,6 +286,69 @@ class CythonTransform(VisitorTransform): self.visitchildren(node) return node +class ScopeTrackingTransform(CythonTransform): + # Keeps track of type of scopes + scope_type = None # can be either of 'module', 'function', 'cclass', 'pyclass' + scope_node = None + + def visit_ModuleNode(self, node): + self.scope_type = 'module' + self.scope_node = node + self.visitchildren(node) + return node + + def visit_scope(self, node, scope_type): + prev = self.scope_type, self.scope_node + self.scope_type = scope_type + self.scope_node = node + self.visitchildren(node) + self.scope_type, self.scope_node = prev + return node + + def visit_CClassDefNode(self, node): + return self.visit_scope(node, 'cclass') + + def visit_PyClassDefNode(self, node): + return self.visit_scope(node, 'pyclass') + + def visit_FuncDefNode(self, node): + return self.visit_scope(node, 'function') + + def visit_CStructOrUnionDefNode(self, node): + return self.visit_scope(node, 'struct') + + +class EnvTransform(CythonTransform): + """ + This transformation keeps a stack of the environments. + """ + def __call__(self, root): + self.env_stack = [root.scope] + return super(EnvTransform, self).__call__(root) + + def visit_FuncDefNode(self, node): + self.env_stack.append(node.local_scope) + self.visitchildren(node) + self.env_stack.pop() + return node + + +class RecursiveNodeReplacer(VisitorTransform): + """ + Recursively replace all occurrences of a node in a subtree by + another node. + """ + def __init__(self, orig_node, new_node): + super(RecursiveNodeReplacer, self).__init__() + self.orig_node, self.new_node = orig_node, new_node + + def visit_Node(self, node): + self.visitchildren(node) + if node is self.orig_node: + return self.new_node + else: + return node +