Support for parse tree transformations.
author"Dag Sverre Seljebotn" <dagss@student.matnat.uio.no>
Sun, 9 Mar 2008 08:14:12 +0000 (00:14 -0800)
committer"Dag Sverre Seljebotn" <dagss@student.matnat.uio.no>
Sun, 9 Mar 2008 08:14:12 +0000 (00:14 -0800)
Cython/Compiler/CmdLine.py
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Nodes.py
Cython/Compiler/Transform.py [new file with mode: 0644]

index ce268606df79ff42d60c7bb32998349613042947..6ba73f0b629b1e542c35d3df5295b984c5d1f45a 100644 (file)
@@ -4,6 +4,7 @@
 
 import sys
 import Options
+import Transform
 
 usage = """\
 Cython (http://cython.org) is a compiler for code written in the
@@ -36,11 +37,36 @@ Options:
 #  -+, --cplus      Use C++ compiler for compiling and linking
 #  Additional .o files to link may be supplied when using -X."""
 
+#The following options are very experimental and is used for plugging in code
+#into different transform stages.
+#  -T phase:factory At the phase given, hand off the tree to the transform returned
+#                   when calling factory without arguments. Factory should be fully
+#                   specified (ie Module.SubModule.factory) and the containing module
+#                   will be imported. This option can be repeated to add more transforms,
+#                   transforms for the same phase will be used in the order they are given.
+
 def bad_usage():
     print >>sys.stderr, usage
     sys.exit(1)
 
 def parse_command_line(args):
+
+    def parse_add_transform(transforms, param):
+        def import_symbol(fqn):
+            modsplitpt = fqn.rfind(".")
+            if modsplitpt == -1: bad_usage()
+            modulename = fqn[:modsplitpt]
+            symbolname = fqn[modsplitpt+1:]
+            module = __import__(modulename, fromlist=[symbolname], level=0)
+            return getattr(module, symbolname)
+    
+        stagename, factoryname = param.split(":")
+        if not stagename in Transform.PHASES:
+            bad_usage()
+        factory = import_symbol(factoryname)
+        transform = factory()
+        transforms[stagename].append(transform)
+    
     from Cython.Compiler.Main import \
         CompilationOptions, default_options
 
@@ -93,6 +119,9 @@ def parse_command_line(args):
                 Options.annotate = True
             elif option == "--convert-range":
                 Options.convert_range = True
+            elif option.startswith("-T"):
+                parse_add_transform(options.transforms, get_param(option))
+                # Note: this can occur multiple times, each time appends
             else:
                 bad_usage()
         else:
index efe3baf3f750db879e771d501a0e3d2a7d4ee2ed..f681d87bb2d64839889357f94ecebcda8fe5c7f7 100644 (file)
@@ -31,6 +31,7 @@ class ExprNode(Node):
     #                            Cached result of subexpr_nodes()
     
     result_ctype = None
+    type = None
 
     #  The Analyse Expressions phase for expressions is split
     #  into two sub-phases:
@@ -165,6 +166,14 @@ class ExprNode(Node):
     saved_subexpr_nodes = None
     is_temp = 0
 
+    def get_child_attrs(self):
+        """Automatically provide the contents of subexprs as children, unless child_attr
+        has been declared. See Nodes.Node.get_child_accessors."""
+        if self.child_attrs != None:
+            return self.child_attr
+        elif self.subexprs != None:
+            return self.subexprs
+        
     def not_implemented(self, method_name):
         print_call_chain(method_name, "not implemented") ###
         raise InternalError(
index 9e157332a2e1573ff3e6bc845c4e4c28d12fa030..c50b1a98f4b4a9aba3d4748967040967d06fed8f 100644 (file)
@@ -18,6 +18,7 @@ from Symtab import BuiltinScope, ModuleScope
 import Code
 from Cython.Utils import replace_suffix
 from Cython import Utils
+import Transform
 
 verbose = 0
 
@@ -236,6 +237,7 @@ class CompilationOptions:
     include_path      [string]  Directories to search for include files
     output_file       string    Name of generated .c file
     generate_pxi      boolean   Generate .pxi file for public declarations
+    transforms        Transform.TransformSet Transforms to use on the parse tree
     
     Following options are experimental and only used on MacOSX:
     
@@ -342,7 +344,8 @@ default_options = dict(
     obj_only = 1,
     cplus = 0,
     output_file = None,
-    generate_pxi = 0)
+    generate_pxi = 0,
+    transforms = Transform.TransformSet())
     
 if sys.platform == "mac":
     from Cython.Mac.MacSystem import c_compile, c_link, CCompilerError
index f05562448f40741db5b4a642332d18aeae8c559c..7c8b4aae38b1970c45b9e299341a54917dd705a7 100644 (file)
@@ -26,6 +26,8 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
     #  referenced_modules   [ModuleScope]
     #  module_temp_cname    string
     #  full_module_name     string
+
+    children_attrs = ["body"]
     
     def analyse_declarations(self, env):
         if Options.embed_pos_in_docstring:
@@ -46,7 +48,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         if self.has_imported_c_functions():
             self.module_temp_cname = env.allocate_temp_pyobject()
             env.release_temp(self.module_temp_cname)
-        self.generate_c_code(env, result)
+        self.generate_c_code(env, options, result)
         self.generate_h_code(env, options, result)
         self.generate_api_code(env, result)
     
@@ -199,7 +201,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
             i_code.putln("pass")
         i_code.dedent()
     
-    def generate_c_code(self, env, result):
+    def generate_c_code(self, env, options, result):
         modules = self.referenced_modules
         if Options.annotate:
             code = Annotate.AnnotationCCodeWriter(StringIO())
@@ -216,7 +218,7 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         self.generate_interned_name_decls(env, code)
         self.generate_py_string_decls(env, code)
         self.generate_cached_builtins_decls(env, code)
-        self.body.generate_function_definitions(env, code)
+        self.body.generate_function_definitions(env, code, options.transforms)
         code.mark_pos(None)
         self.generate_interned_name_table(env, code)
         self.generate_py_string_table(env, code)
index ccb3ee1fff7fdbaa645395fe3d843966ff07d594..4146688287f6b948579d354528dac38df2e0492f 100644 (file)
@@ -38,6 +38,21 @@ def relative_position(pos):
     return (pos[0][absolute_path_length+1:], pos[1])
         
 
+class AttributeAccessor:
+    """Used as the result of the Node.get_children_accessors() generator"""
+    def __init__(self, obj, attrname):
+        self.obj = obj
+        self.attrname = attrname
+    def get(self):
+        try:
+            return getattr(self.obj, self.attrname)
+        except AttributeError:
+            return None
+    def set(self, value):
+        setattr(self.obj, self.attrname, value)
+    def name(self):
+        return self.attrname
+
 class Node:
     #  pos         (string, int, int)   Source file position
     #  is_name     boolean              Is a NameNode
@@ -45,11 +60,64 @@ class Node:
     
     is_name = 0
     is_literal = 0
+
+    # All descandants should set child_attrs (see get_child_accessors)    
+    child_attrs = None
     
     def __init__(self, pos, **kw):
         self.pos = pos
         self.__dict__.update(kw)
     
+    def get_child_accessors(self):
+        """Returns an iterator over the children of the Node. Each member in the
+        iterated list is an object with get(), set(value), and name() methods,
+        which can be used to fetch and replace the child and query the name
+        the relation this node has with the child. For instance, for an
+        assignment node, this code:
+        
+        for child in assignment_node.get_child_accessors():
+            print(child.name())
+            child.set(i_node)
+        
+        will print "lhs", "rhs", and change the assignment statement to "i = i"
+        (assuming that i_node is a node able to represent the variable i in the
+        tree).
+        
+        Any kind of objects can in principle be returned, but the typical
+        candidates are either Node instances or lists of node instances.
+        
+        The object returned in each iteration stage can only be used until the
+        iterator is advanced to the next child attribute. (However, the objects
+        returned by the get() function can be kept).
+        
+        Typically, a Node instance will have other interesting and potentially
+        hierarchical attributes as well. These must be explicitly accessed -- this
+        method only provides access to attributes that are deemed to naturally
+        belong in the parse tree.
+        
+        Descandant classes can either specify child_attrs, override get_child_attrs,
+        or override this method directly in order to provide access to their
+        children. All descendants of Node *must* declare their children -- leaf nodes
+        should simply declare "child_attrs = []".
+        """
+        attrnames = self.get_child_attrs()
+        if attrnames is None:
+            raise InternalError("Children access not implemented for %s" % \
+                self.__class__.__name__)
+        for name in attrnames:
+            a = AttributeAccessor(self, name)
+            yield a
+            # Not really needed, but one wants to enforce clients not to
+            # hold on to iterated objects, in case more advanced iteration
+            # is needed later
+            a.attrname = None
+    
+    def get_child_attrs(self):
+        """Utility method for more easily implementing get_child_accessors.
+        If you override get_child_accessors then this method is not used."""
+        return self.child_attrs
+    
+    
     #
     #  There are 3 phases of parse tree processing, applied in order to
     #  all the statements in a given scope-block:
@@ -145,6 +213,8 @@ class BlockNode:
 class StatListNode(Node):
     # stats     a list of StatNode
     
+    child_attrs = ["stats"]
+    
     def analyse_declarations(self, env):
         #print "StatListNode.analyse_declarations" ###
         for stat in self.stats:
@@ -155,10 +225,10 @@ class StatListNode(Node):
         for stat in self.stats:
             stat.analyse_expressions(env)
     
-    def generate_function_definitions(self, env, code):
+    def generate_function_definitions(self, env, code, transforms):
         #print "StatListNode.generate_function_definitions" ###
         for stat in self.stats:
-            stat.generate_function_definitions(env, code)
+            stat.generate_function_definitions(env, code, transforms)
             
     def generate_execution_code(self, code):
         #print "StatListNode.generate_execution_code" ###
@@ -184,7 +254,7 @@ class StatNode(Node):
     #        Emit C code for executable statements.
     #
     
-    def generate_function_definitions(self, env, code):
+    def generate_function_definitions(self, env, code, transforms):
         pass
     
     def generate_execution_code(self, code):
@@ -196,6 +266,8 @@ class CDefExternNode(StatNode):
     #  include_file   string or None
     #  body           StatNode
     
+    child_attrs = ["body"]
+    
     def analyse_declarations(self, env):
         if self.include_file:
             env.add_include_file(self.include_file)
@@ -227,6 +299,8 @@ class CDeclaratorNode(Node):
     #  calling_convention  string   Calling convention of CFuncDeclaratorNode
     #                               for which this is a base 
 
+    child_attrs = []
+
     calling_convention = ""
 
     def analyse_expressions(self, env):
@@ -240,6 +314,8 @@ class CNameDeclaratorNode(CDeclaratorNode):
     #  name   string           The Pyrex name being declared
     #  cname  string or None   C name, if specified
     
+    child_attrs = []
+
     def analyse(self, base_type, env):
         self.type = base_type
         return self, base_type
@@ -268,6 +344,8 @@ class CNameDeclaratorNode(CDeclaratorNode):
 class CPtrDeclaratorNode(CDeclaratorNode):
     # base     CDeclaratorNode
     
+    child_attrs = ["base"]
+
     def analyse(self, base_type, env):
         if base_type.is_pyobject:
             error(self.pos,
@@ -284,6 +362,8 @@ class CPtrDeclaratorNode(CDeclaratorNode):
 class CArrayDeclaratorNode(CDeclaratorNode):
     # base        CDeclaratorNode
     # dimension   ExprNode
+
+    child_attrs = ["base", "dimension"]
     
     def analyse(self, base_type, env):
         if self.dimension:
@@ -315,6 +395,8 @@ class CFuncDeclaratorNode(CDeclaratorNode):
     # nogil            boolean    Can be called without gil
     # with_gil         boolean    Acquire gil around function body
     
+    child_attrs = ["base", "args", "exception_value"]
+
     overridable = 0
     optional_arg_count = 0
 
@@ -403,6 +485,8 @@ class CArgDeclNode(Node):
     # is_self_arg    boolean            Is the "self" arg of an extension type method
     # is_kw_only     boolean            Is a keyword-only argument
 
+    child_attrs = ["base_type", "declarator", "default"]
+
     is_self_arg = 0
     is_generic = 1
 
@@ -435,6 +519,8 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
     # longness         integer
     # is_self_arg      boolean      Is self argument of C method
 
+    child_attrs = []
+    
     def analyse(self, env):
         # Return type descriptor.
         #print "CSimpleBaseTypeNode.analyse: is_self_arg =", self.is_self_arg ###
@@ -480,6 +566,8 @@ class CComplexBaseTypeNode(CBaseTypeNode):
     # base_type   CBaseTypeNode
     # declarator  CDeclaratorNode
     
+    child_attrs = ["base_type", "declarator"]
+
     def analyse(self, env):
         base = self.base_type.analyse(env)
         _, type = self.declarator.analyse(base, env)
@@ -494,6 +582,8 @@ class CVarDefNode(StatNode):
     #  declarators   [CDeclaratorNode]
     #  in_pxd        boolean
     #  api           boolean
+
+    child_attrs = ["base_type", "declarators"]
     
     def analyse_declarations(self, env, dest_scope = None):
         if not dest_scope:
@@ -543,6 +633,8 @@ class CStructOrUnionDefNode(StatNode):
     #  attributes    [CVarDefNode] or None
     #  entry         Entry
     
+    child_attrs = ["attributes"]
+
     def analyse_declarations(self, env):
         scope = None
         if self.attributes is not None:
@@ -572,6 +664,8 @@ class CEnumDefNode(StatNode):
     #  in_pxd         boolean
     #  entry          Entry
     
+    child_attrs = ["items"]
+    
     def analyse_declarations(self, env):
         self.entry = env.declare_enum(self.name, self.pos,
             cname = self.cname, typedef_flag = self.typedef_flag,
@@ -594,6 +688,8 @@ class CEnumDefItemNode(StatNode):
     #  cname    string or None
     #  value    ExprNode or None
     
+    child_attrs = ["value"]
+
     def analyse_declarations(self, env, enum_entry):
         if self.value:
             self.value.analyse_const_expression(env)
@@ -610,6 +706,8 @@ class CTypeDefNode(StatNode):
     #  declarator   CDeclaratorNode
     #  visibility   "public" or "private"
     #  in_pxd       boolean
+
+    child_attrs = ["base_type", "declarator"]
     
     def analyse_declarations(self, env):
         base = self.base_type.analyse(env)
@@ -663,7 +761,7 @@ class FuncDefNode(StatNode, BlockNode):
     def need_gil_acquisition(self, lenv):
         return 0
                 
-    def generate_function_definitions(self, env, code):
+    def generate_function_definitions(self, env, code, transforms):
         code.mark_pos(self.pos)
         # Generate C code for header and body of function
         genv = env.global_scope()
@@ -671,8 +769,10 @@ class FuncDefNode(StatNode, BlockNode):
         lenv.return_type = self.return_type
         code.init_labels()
         self.declare_arguments(lenv)
+        transforms.run('before_analyse_function', self, env=env, lenv=lenv, genv=genv)
         self.body.analyse_declarations(lenv)
         self.body.analyse_expressions(lenv)
+        transforms.run('after_analyse_function', self, env=env, lenv=lenv, genv=genv)
         # Code for nested function definitions would go here
         # if we supported them, which we probably won't.
         # ----- Top-level constants used by this function
@@ -775,7 +875,7 @@ class FuncDefNode(StatNode, BlockNode):
         code.putln("}")
         # ----- Python version
         if self.py_func:
-            self.py_func.generate_function_definitions(env, code)
+            self.py_func.generate_function_definitions(env, code, transforms)
         self.generate_optarg_wrapper_function(env, code)
 
     def put_stararg_decrefs(self, code):
@@ -834,6 +934,8 @@ class CFuncDefNode(FuncDefNode):
     #  with_gil      boolean    Acquire GIL around body
     #  type          CFuncType
     
+    child_attrs = ["base_type", "declarator", "body"]
+
     def unqualified_name(self):
         return self.entry.name
         
@@ -1032,6 +1134,7 @@ class PyArgDeclNode(Node):
     #
     # name   string
     # entry  Symtab.Entry
+    child_attrs = []
     
     pass
     
@@ -1051,6 +1154,8 @@ class DefNode(FuncDefNode):
     #
     #  assmt   AssignmentNode   Function construction/assignment
     
+    child_attrs = ["args", "star_arg", "starstar_arg", "body"]
+
     assmt = None
     num_kwonly_args = 0
     num_required_kw_args = 0
@@ -1739,6 +1844,8 @@ class PyClassDefNode(StatNode, BlockNode):
     #  dict     DictNode   Class dictionary
     #  classobj ClassNode  Class object
     #  target   NameNode   Variable to assign class object to
+
+    child_attrs = ["body", "dict", "classobj", "target"]
     
     def __init__(self, pos, name, bases, doc, body):
         StatNode.__init__(self, pos)
@@ -1777,10 +1884,10 @@ class PyClassDefNode(StatNode, BlockNode):
         #self.classobj.release_temp(env)
         #self.target.release_target_temp(env)
     
-    def generate_function_definitions(self, env, code):
+    def generate_function_definitions(self, env, code, transforms):
         self.generate_py_string_decls(self.scope, code)
         self.body.generate_function_definitions(
-            self.scope, code)
+            self.scope, code, transforms)
     
     def generate_execution_code(self, code):
         self.dict.generate_evaluation_code(code)
@@ -1809,6 +1916,8 @@ class CClassDefNode(StatNode):
     #  entry              Symtab.Entry
     #  base_type          PyExtensionType or None
     
+    child_attrs = ["body"]
+
     def analyse_declarations(self, env):
         #print "CClassDefNode.analyse_declarations:", self.class_name
         #print "...visibility =", self.visibility
@@ -1880,10 +1989,10 @@ class CClassDefNode(StatNode):
             scope = self.entry.type.scope
             self.body.analyse_expressions(scope)
     
-    def generate_function_definitions(self, env, code):
+    def generate_function_definitions(self, env, code, transforms):
         if self.body:
             self.body.generate_function_definitions(
-                self.entry.type.scope, code)
+                self.entry.type.scope, code, transforms)
     
     def generate_execution_code(self, code):
         # This is needed to generate evaluation code for
@@ -1903,6 +2012,8 @@ class PropertyNode(StatNode):
     #  doc    string or None    Doc string
     #  body   StatListNode
     
+    child_attrs = ["body"]
+
     def analyse_declarations(self, env):
         entry = env.declare_property(self.name, self.doc, self.pos)
         if entry:
@@ -1914,8 +2025,8 @@ class PropertyNode(StatNode):
     def analyse_expressions(self, env):
         self.body.analyse_expressions(env)
     
-    def generate_function_definitions(self, env, code):
-        self.body.generate_function_definitions(env, code)
+    def generate_function_definitions(self, env, code, transforms):
+        self.body.generate_function_definitions(env, code, transforms)
 
     def generate_execution_code(self, code):
         pass
@@ -1929,6 +2040,8 @@ class GlobalNode(StatNode):
     #
     # names    [string]
     
+    child_attrs = []
+
     def analyse_declarations(self, env):
         for name in self.names:
             env.declare_global(name, self.pos)
@@ -1944,6 +2057,8 @@ class ExprStatNode(StatNode):
     #  Expression used as a statement.
     #
     #  expr   ExprNode
+
+    child_attrs = ["expr"]
     
     def analyse_expressions(self, env):
         self.expr.analyse_expressions(env)
@@ -1989,6 +2104,8 @@ class SingleAssignmentNode(AssignmentNode):
     #
     #  lhs      ExprNode      Left hand side
     #  rhs      ExprNode      Right hand side
+    
+    child_attrs = ["lhs", "rhs"]
 
     def analyse_declarations(self, env):
         self.lhs.analyse_target_declaration(env)
@@ -2044,6 +2161,8 @@ class CascadedAssignmentNode(AssignmentNode):
     #
     #  coerced_rhs_list   [ExprNode]   RHS coerced to type of each LHS
     
+    child_attrs = ["lhs_list", "rhs", "coerced_rhs_list"]
+
     def analyse_declarations(self, env):
         for lhs in self.lhs_list:
             lhs.analyse_target_declaration(env)
@@ -2128,6 +2247,8 @@ class ParallelAssignmentNode(AssignmentNode):
     #
     #  stats     [AssignmentNode]   The constituent assignments
     
+    child_attrs = ["stats"]
+
     def analyse_declarations(self, env):
         for stat in self.stats:
             stat.analyse_declarations(env)
@@ -2175,6 +2296,8 @@ class InPlaceAssignmentNode(AssignmentNode):
     #  Fortunately, the type of the lhs node is fairly constrained 
     #  (it must be a NameNode, AttributeNode, or IndexNode).     
     
+    child_attrs = ["lhs", "rhs", "dup"]
+
     def analyse_declarations(self, env):
         self.lhs.analyse_target_declaration(env)
         
@@ -2272,6 +2395,8 @@ class PrintStatNode(StatNode):
     #  args              [ExprNode]
     #  ends_with_comma   boolean
     
+    child_attrs = ["args"]
+    
     def analyse_expressions(self, env):
         for i in range(len(self.args)):
             arg = self.args[i]
@@ -2306,6 +2431,8 @@ class DelStatNode(StatNode):
     #
     #  args     [ExprNode]
     
+    child_attrs = ["args"]
+
     def analyse_declarations(self, env):
         for arg in self.args:
             arg.analyse_target_declaration(env)
@@ -2330,6 +2457,8 @@ class DelStatNode(StatNode):
 
 class PassStatNode(StatNode):
     #  pass statement
+
+    child_attrs = []
     
     def analyse_expressions(self, env):
         pass
@@ -2340,6 +2469,8 @@ class PassStatNode(StatNode):
 
 class BreakStatNode(StatNode):
 
+    child_attrs = []
+
     def analyse_expressions(self, env):
         pass
     
@@ -2355,6 +2486,8 @@ class BreakStatNode(StatNode):
 
 class ContinueStatNode(StatNode):
 
+    child_attrs = []
+
     def analyse_expressions(self, env):
         pass
     
@@ -2377,6 +2510,8 @@ class ReturnStatNode(StatNode):
     #  return_type   PyrexType
     #  temps_in_use  [Entry]            Temps in use at time of return
     
+    child_attrs = ["value"]
+
     def analyse_expressions(self, env):
         return_type = env.return_type
         self.return_type = return_type
@@ -2439,6 +2574,8 @@ class RaiseStatNode(StatNode):
     #  exc_value   ExprNode or None
     #  exc_tb      ExprNode or None
     
+    child_attrs = ["exc_type", "exc_value", "exc_tb"]
+
     def analyse_expressions(self, env):
         if self.exc_type:
             self.exc_type.analyse_types(env)
@@ -2508,6 +2645,8 @@ class RaiseStatNode(StatNode):
 
 class ReraiseStatNode(StatNode):
 
+    child_attrs = []
+
     def analyse_expressions(self, env):
         env.use_utility_code(raise_utility_code)
 
@@ -2526,6 +2665,8 @@ class AssertStatNode(StatNode):
     #  cond    ExprNode
     #  value   ExprNode or None
     
+    child_attrs = ["cond", "value"]
+
     def analyse_expressions(self, env):
         self.cond = self.cond.analyse_boolean_expression(env)
         if self.value:
@@ -2570,6 +2711,8 @@ class IfStatNode(StatNode):
     #
     #  if_clauses   [IfClauseNode]
     #  else_clause  StatNode or None
+
+    child_attrs = ["if_clauses", "else_clause"]
     
     def analyse_declarations(self, env):
         for if_clause in self.if_clauses:
@@ -2607,6 +2750,8 @@ class IfClauseNode(Node):
     #  condition   ExprNode
     #  body        StatNode
     
+    child_attrs = ["condition", "body"]
+
     def analyse_declarations(self, env):
         self.condition.analyse_declarations(env)
         self.body.analyse_declarations(env)
@@ -2641,6 +2786,8 @@ class WhileStatNode(StatNode):
     #  condition    ExprNode
     #  body         StatNode
     #  else_clause  StatNode
+
+    child_attrs = ["condition", "body", "else_clause"]
     
     def analyse_declarations(self, env):
         self.body.analyse_declarations(env)
@@ -2697,6 +2844,8 @@ class ForInStatNode(StatNode):
     #  else_clause   StatNode
     #  item          NextNode       used internally
     
+    child_attrs = ["target", "iterator", "body", "else_clause", "item"]
+    
     def analyse_declarations(self, env):
         self.target.analyse_target_declaration(env)
         self.body.analyse_declarations(env)
@@ -2811,6 +2960,7 @@ class ForFromStatNode(StatNode):
     #  is_py_target       bool
     #  loopvar_name       string
     #  py_loopvar_node    PyTempNode or None
+    child_attrs = ["target", "bound1", "bound2", "step", "body", "else_clause", "py_loopvar_node"]
     
     def analyse_declarations(self, env):
         self.target.analyse_target_declaration(env)
@@ -2933,6 +3083,8 @@ class TryExceptStatNode(StatNode):
     #  except_clauses   [ExceptClauseNode]
     #  else_clause      StatNode or None
     #  cleanup_list     [Entry]            temps to clean up on error
+
+    child_attrs = ["body", "except_clauses", "else_clause", "cleanup_list"]
     
     def analyse_declarations(self, env):
         self.body.analyse_declarations(env)
@@ -2999,6 +3151,8 @@ class ExceptClauseNode(Node):
     #  function_name  string             qualified name of enclosing function
     #  exc_vars       (string * 3)       local exception variables
     
+    child_attrs = ["pattern", "target", "body", "exc_value"]
+
     def analyse_declarations(self, env):
         if self.target:
             self.target.analyse_target_declaration(env)
@@ -3084,6 +3238,8 @@ class TryFinallyStatNode(StatNode):
     #  In addition, if we're doing an error, we save the
     #  exception on entry to the finally block and restore
     #  it on exit.
+
+    child_attrs = ["body", "finally_clause", "cleanup_list"]
     
     preserve_exception = 1
     
@@ -3243,6 +3399,8 @@ class GILStatNode(TryFinallyStatNode):
     #
     #   state   string   'gil' or 'nogil'
         
+    child_attrs = []
+    
     preserve_exception = 0
 
     def __init__(self, pos, state, body):
@@ -3280,6 +3438,8 @@ class GILExitNode(StatNode):
     #
     #  state   string   'gil' or 'nogil'
 
+    child_attrs = []
+
     def analyse_expressions(self, env):
         pass
 
@@ -3295,6 +3455,8 @@ class CImportStatNode(StatNode):
     #
     #  module_name   string           Qualified name of module being imported
     #  as_name       string or None   Name specified in "as" clause, if any
+
+    child_attrs = []
     
     def analyse_declarations(self, env):
         if not env.is_module_scope:
@@ -3331,6 +3493,8 @@ class FromCImportStatNode(StatNode):
     #  module_name     string                  Qualified name of module
     #  imported_names  [(pos, name, as_name)]  Names to be imported
     
+    child_attrs = []
+
     def analyse_declarations(self, env):
         if not env.is_module_scope:
             error(self.pos, "cimport only allowed at module level")
@@ -3357,6 +3521,8 @@ class FromImportStatNode(StatNode):
     #  items            [(string, NameNode)]
     #  interned_items   [(string, NameNode)]
     #  item             PyTempNode            used internally
+
+    child_attrs = ["module"]
     
     def analyse_declarations(self, env):
         for _, target in self.items:
diff --git a/Cython/Compiler/Transform.py b/Cython/Compiler/Transform.py
new file mode 100644 (file)
index 0000000..8397a47
--- /dev/null
@@ -0,0 +1,115 @@
+#
+#   Tree transform framework
+#
+import Nodes
+import ExprNodes
+
+class Transform(object):
+    #  parent_stack [Node]       A stack providing information about where in the tree
+    #                            we currently are. Nodes here should be considered
+    #                            read-only.
+
+    # Transforms for the parse tree should usually extend this class for convenience.
+    # The caller of a transform will only first call initialize and then process_node on
+    # the root node, the rest are utility functions and conventions.
+    
+    # Transformations usually happens by recursively filtering through the stream.
+    # process_node is always expected to return a new node, however it is ok to simply
+    # return the input node untouched. Returning None will remove the node from the
+    # parent.
+    
+    def __init__(self):
+        self.parent_stack = []
+    
+    def initialize(self, phase, **options):
+        pass
+
+    def process_children(self, node):
+        """For all children of node, either process_list (if isinstance(node, list))
+        or process_node (otherwise) is called."""
+        if node == None: return
+        
+        self.parent_stack.append(node)
+        for childacc in node.get_child_accessors():
+            child = childacc.get()
+            if isinstance(child, list):
+                newchild = self.process_list(child, childacc.name())
+                if not isinstance(newchild, list): raise Exception("Cannot replace list with non-list!")
+            else:
+                newchild = self.process_node(child, childacc.name())
+                if newchild is not None and not isinstance(newchild, Nodes.Node):
+                    raise Exception("Cannot replace Node with non-Node!")
+            childacc.set(newchild)
+        self.parent_stack.pop()
+
+    def process_list(self, l, name):
+        """Calls process_node on all the items in l, using the name one gets when appending
+        [idx] to the name. Each item in l is transformed in-place by the item process_node
+        returns, then l is returned."""
+        # Comment: If moving to a copying strategy, it might makes sense to return a
+        # new list instead.
+        for idx in xrange(len(l)):
+            l[idx] = self.process_node(l[idx], "%s[%d]" % (name, idx))
+        return l
+
+    def process_node(self, node, name):
+        """Override this method to process nodes. name specifies which kind of relation the
+        parent has with child. This method should always return the node which the parent
+        should use for this relation, which can either be the same node, None to remove
+        the node, or a different node."""
+        raise InternalError("Not implemented")
+
+class PrintTree(Transform):
+    """Prints a representation of the tree to standard output.
+    Subclass and override repr_of to provide more information
+    about nodes. """
+    def __init__(self):
+        Transform.__init__(self)
+        self._indent = ""
+
+    def indent(self):
+        self._indent += "  "
+    def unindent(self):
+        self._indent = self._indent[:-2]
+
+    def initialize(self, phase, **options):
+        print("Parse tree dump at phase '%s'" % phase)
+
+    # Don't do anything about process_list, the defaults gives
+    # nice-looking name[idx] nodes which will visually appear
+    # under the parent-node, not displaying the list itself in
+    # the hierarchy.
+    
+    def process_node(self, node, name):
+        print("%s- %s: %s" % (self._indent, name, self.repr_of(node)))
+        self.indent()
+        self.process_children(node)
+        self.unindent()
+        return node
+
+    def repr_of(self, node):
+        if node is None:
+            return "(none)"
+        else:
+            result = node.__class__.__name__
+            if isinstance(node, ExprNodes.ExprNode):
+                t = node.type
+                result += "(type=%s)" % repr(t)
+            return result
+
+
+PHASES = [
+    'before_analyse_function', # run in FuncDefNode.generate_function_definitions
+    'after_analyse_function'   # run in FuncDefNode.generate_function_definitions
+]
+
+class TransformSet(dict):
+    def __init__(self):
+        self.update([(name, []) for name in PHASES])
+    def run(self, name, node, **options):
+        assert name in self
+        for transform in self[name]:
+            transform.initialize(phase=name, **options)
+            transform.process_node(node, "(root)")
+
+