Moved buffer transform to Buffer.py
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 8 Jul 2008 11:06:48 +0000 (13:06 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 8 Jul 2008 11:06:48 +0000 (13:06 +0200)
Cython/Compiler/Buffer.py [new file with mode: 0644]
Cython/Compiler/Main.py
Cython/Compiler/ParseTreeTransforms.py

diff --git a/Cython/Compiler/Buffer.py b/Cython/Compiler/Buffer.py
new file mode 100644 (file)
index 0000000..7f67206
--- /dev/null
@@ -0,0 +1,188 @@
+from Cython.Compiler.Visitor import VisitorTransform, temp_name_handle, CythonTransform
+from Cython.Compiler.ModuleNode import ModuleNode
+from Cython.Compiler.Nodes import *
+from Cython.Compiler.ExprNodes import *
+from Cython.Compiler.TreeFragment import TreeFragment
+from Cython.Utils import EncodedString
+from Cython.Compiler.Errors import CompileError
+from sets import Set as set
+
+class BufferTransform(CythonTransform):
+    """
+    Run after type analysis. Takes care of the buffer functionality.
+    """
+    scope = None
+
+    def __call__(self, node):
+        cymod = self.context.modules[u'__cython__']
+        self.buffer_type = cymod.entries[u'Py_buffer'].type
+        return super(BufferTransform, self).__call__(node)
+
+    def handle_scope(self, node, scope):
+        # For all buffers, insert extra variables in the scope.
+        # The variables are also accessible from the buffer_info
+        # on the buffer entry
+        bufvars = [(name, entry) for name, entry
+                   in scope.entries.iteritems()
+                   if entry.type.buffer_options is not None]
+                   
+        for name, entry in bufvars:
+            # Variable has buffer opts, declare auxiliary vars
+            bufopts = entry.type.buffer_options
+
+            bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
+                                        self.buffer_type, node.pos)
+
+            temp_var =  scope.declare_var(temp_name_handle(u"%s_tmp" % name),
+                                        entry.type, node.pos)
+            
+            
+            stridevars = []
+            shapevars = []
+            for idx in range(bufopts.ndim):
+                # stride
+                varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
+                var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
+                stridevars.append(var)
+                # shape
+                varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
+                var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
+                shapevars.append(var)
+            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, 
+                                                shapevars)
+            entry.buffer_aux.temp_var = temp_var
+        self.scope = scope
+
+            
+    def visit_ModuleNode(self, node):
+        self.handle_scope(node, node.scope)
+        self.visitchildren(node)
+        return node
+
+    def visit_FuncDefNode(self, node):
+        self.handle_scope(node, node.local_scope)
+        self.visitchildren(node)
+        return node
+
+    acquire_buffer_fragment = TreeFragment(u"""
+        TMP = LHS
+        if TMP is not None:
+            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
+        TMP = RHS
+        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
+        ASSIGN_AUX
+        LHS = TMP
+    """)
+
+    fetch_strides = TreeFragment(u"""
+        TARGET = BUFINFO.strides[IDX]
+    """)
+
+    fetch_shape = TreeFragment(u"""
+        TARGET = BUFINFO.shape[IDX]
+    """)
+
+    def visit_SingleAssignmentNode(self, node):
+        # On assignments, two buffer-related things can happen:
+        # a) A buffer variable is assigned to (reacquisition)
+        # b) Buffer access assignment: arr[...] = ...
+        # Since we don't allow nested buffers, these don't overlap.
+        
+        self.visitchildren(node)
+        # Only acquire buffers on vars (not attributes) for now.
+        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
+            # Is buffer variable
+            return self.reacquire_buffer(node)
+        elif (isinstance(node.lhs, IndexNode) and
+              isinstance(node.lhs.base, NameNode) and
+              node.lhs.base.entry.buffer_aux is not None):
+            return self.assign_into_buffer(node)
+        
+    def reacquire_buffer(self, node):
+        bufaux = node.lhs.entry.buffer_aux
+        auxass = []
+        for idx, entry in enumerate(bufaux.stridevars):
+            entry.used = True
+            ass = self.fetch_strides.substitute({
+                u"TARGET": NameNode(node.pos, name=entry.name),
+                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                u"IDX": IntNode(node.pos, value=EncodedString(idx))
+            })
+            auxass.append(ass)
+
+        for idx, entry in enumerate(bufaux.shapevars):
+            entry.used = True
+            ass = self.fetch_shape.substitute({
+                u"TARGET": NameNode(node.pos, name=entry.name),
+                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                u"IDX": IntNode(node.pos, value=EncodedString(idx))
+            })
+            auxass.append(ass)
+
+        bufaux.buffer_info_var.used = True
+        acq = self.acquire_buffer_fragment.substitute({
+            u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
+            u"LHS" : node.lhs,
+            u"RHS": node.rhs,
+            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
+            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+        }, pos=node.pos)
+        # Note: The below should probably be refactored into something
+        # like fragment.substitute(..., context=self.context), with
+        # TreeFragment getting context.pipeline_until_now() and
+        # applying it on the fragment.
+        acq.analyse_declarations(self.scope)
+        acq.analyse_expressions(self.scope)
+        stats = acq.stats
+        return stats
+
+    def assign_into_buffer(self, node):
+        result = SingleAssignmentNode(node.pos,
+                                      rhs=self.visit(node.rhs),
+                                      lhs=self.buffer_index(node.lhs))
+        result.analyse_expressions(self.scope)
+        return result
+        
+
+    def buffer_index(self, node):
+        bufaux = node.base.entry.buffer_aux
+        assert bufaux is not None
+        # indices * strides...
+        to_sum = [ IntBinopNode(node.pos, operator='*',
+                                operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
+                                operand2=NameNode(node.pos, name=stride.name))
+            for index, stride in zip(node.indices, bufaux.stridevars)]
+
+        # then sum them 
+        expr = to_sum[0]
+        for next in to_sum[1:]:
+            expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
+
+        tmp= self.buffer_access.substitute({
+            'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
+            'OFFSET': expr
+            }, pos=node.pos)
+
+        return tmp.stats[0].expr
+
+    buffer_access = TreeFragment(u"""
+        (<unsigned char*>(BUF.buf + OFFSET))[0]
+    """)
+    def visit_IndexNode(self, node):
+        # Only occurs when the IndexNode is an rvalue
+        if node.is_buffer_access:
+            assert node.index is None
+            assert node.indices is not None
+            result = self.buffer_index(node)
+            result.analyse_expressions(self.scope)
+            return result
+        else:
+            return node
+
+    def visit_CallNode(self, node):
+###        print node.dump()
+        return node
+    
+#    def visit_FuncDefNode(self, node):
+#        print node.dump()
+    
index 8fa6d3bed676a74801f6c08588dd56af9ad560e9..dd2243bacd3a6cf05e5b8f1810124c2d1df37bad 100644 (file)
@@ -354,9 +354,10 @@ def create_generate_code(context, options, result):
     return generate_code
 
 def create_default_pipeline(context, options, result):
-    from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse, BufferTransform
+    from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse
     from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
     from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor
+    from Buffer import BufferTransform
     from ModuleNode import check_c_classes
     
     return [
index c269a2b2f33685390cde06d019bb58b81662dcc5..dcbe81e94ef14cd38e1dd30c4ab8e45cc8a1ed9b 100644 (file)
@@ -146,250 +146,6 @@ class PostParse(CythonTransform):
         node.keyword_args = None
         return node
 
-class BufferTransform(CythonTransform):
-    """
-    Run after type analysis. Takes care of the buffer functionality.
-    """
-    scope = None
-
-    def __call__(self, node):
-        cymod = self.context.modules[u'__cython__']
-        self.buffer_type = cymod.entries[u'Py_buffer'].type
-        self.lhs = False
-        return super(BufferTransform, self).__call__(node)
-
-    def handle_scope(self, node, scope):
-        # For all buffers, insert extra variables in the scope.
-        # The variables are also accessible from the buffer_info
-        # on the buffer entry
-        bufvars = [(name, entry) for name, entry
-                   in scope.entries.iteritems()
-                   if entry.type.buffer_options is not None]
-                   
-        for name, entry in bufvars:
-            # Variable has buffer opts, declare auxiliary vars
-            bufopts = entry.type.buffer_options
-
-            bufinfo = scope.declare_var(temp_name_handle(u"%s_bufinfo" % name),
-                                        self.buffer_type, node.pos)
-
-            temp_var =  scope.declare_var(temp_name_handle(u"%s_tmp" % name),
-                                        entry.type, node.pos)
-            
-            
-            stridevars = []
-            shapevars = []
-            for idx in range(bufopts.ndim):
-                # stride
-                varname = temp_name_handle(u"%s_%s%d" % (name, "stride", idx))
-                var = scope.declare_var(varname, PyrexTypes.c_int_type, node.pos, is_cdef=True)
-                stridevars.append(var)
-                # shape
-                varname = temp_name_handle(u"%s_%s%d" % (name, "shape", idx))
-                var = scope.declare_var(varname, PyrexTypes.c_uint_type, node.pos, is_cdef=True)
-                shapevars.append(var)
-            entry.buffer_aux = Symtab.BufferAux(bufinfo, stridevars, 
-                                                shapevars)
-            entry.buffer_aux.temp_var = temp_var
-        self.scope = scope
-
-            
-    def visit_ModuleNode(self, node):
-        self.handle_scope(node, node.scope)
-        self.visitchildren(node)
-        return node
-
-    def visit_FuncDefNode(self, node):
-        self.handle_scope(node, node.local_scope)
-        self.visitchildren(node)
-        return node
-
-    acquire_buffer_fragment = TreeFragment(u"""
-        TMP = LHS
-        if TMP is not None:
-            __cython__.PyObject_ReleaseBuffer(<__cython__.PyObject*>TMP, &BUFINFO)
-        TMP = RHS
-        __cython__.PyObject_GetBuffer(<__cython__.PyObject*>TMP, &BUFINFO, 0)
-        ASSIGN_AUX
-        LHS = TMP
-    """)
-
-    fetch_strides = TreeFragment(u"""
-        TARGET = BUFINFO.strides[IDX]
-    """)
-
-    fetch_shape = TreeFragment(u"""
-        TARGET = BUFINFO.shape[IDX]
-    """)
-
-#                ass = SingleAssignmentNode(pos=node.pos,
-#                    lhs=NameNode(node.pos, name=entry.name),
-#                    rhs=IndexNode(node.pos,
-#                        base=AttributeNode(node.pos,
-#                            obj=NameNode(node.pos, name=bufaux.buffer_info_var.name),
-#                            attribute=EncodedString("strides")),
-#                        index=IntNode(node.pos, value=EncodedString(idx))))
-#                print ass.dump()
-
-    def visit_SingleAssignmentNode(self, node):
-        # On assignments, two buffer-related things can happen:
-        # a) A buffer variable is assigned to (reacquisition)
-        # b) Buffer access assignment: arr[...] = ...
-        # Since we don't allow nested buffers, these don't overlap.
-        
-#        self.lhs = True
-        self.visit(node.rhs)
-        self.visit(node.lhs)
-#        self.lhs = False
-#        self.visitchildren(node)
-        # Only acquire buffers on vars (not attributes) for now.
-        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
-            # Is buffer variable
-            return self.reacquire_buffer(node)
-        elif (isinstance(node.lhs, IndexNode) and
-              isinstance(node.lhs.base, NameNode) and
-              node.lhs.base.entry.buffer_aux is not None):
-            return self.assign_into_buffer(node)
-        
-    def reacquire_buffer(self, node):
-        bufaux = node.lhs.entry.buffer_aux
-        auxass = []
-        for idx, entry in enumerate(bufaux.stridevars):
-            entry.used = True
-            ass = self.fetch_strides.substitute({
-                u"TARGET": NameNode(node.pos, name=entry.name),
-                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                u"IDX": IntNode(node.pos, value=EncodedString(idx))
-            })
-            auxass.append(ass)
-
-        for idx, entry in enumerate(bufaux.shapevars):
-            entry.used = True
-            ass = self.fetch_shape.substitute({
-                u"TARGET": NameNode(node.pos, name=entry.name),
-                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                u"IDX": IntNode(node.pos, value=EncodedString(idx))
-            })
-            auxass.append(ass)
-
-        bufaux.buffer_info_var.used = True
-        acq = self.acquire_buffer_fragment.substitute({
-            u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
-            u"LHS" : node.lhs,
-            u"RHS": node.rhs,
-            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
-            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
-        }, pos=node.pos)
-        # Note: The below should probably be refactored into something
-        # like fragment.substitute(..., context=self.context), with
-        # TreeFragment getting context.pipeline_until_now() and
-        # applying it on the fragment.
-        acq.analyse_declarations(self.scope)
-        acq.analyse_expressions(self.scope)
-        stats = acq.stats
-        return stats
-
-    def assign_into_buffer(self, node):
-        result = SingleAssignmentNode(node.pos,
-                                      rhs=self.visit(node.rhs),
-                                      lhs=self.buffer_index(node.lhs))
-        result.analyse_expressions(self.scope)
-        return result
-        
-
-    def buffer_index(self, node):
-        bufaux = node.base.entry.buffer_aux
-        assert bufaux is not None
-        # indices * strides...
-        to_sum = [ IntBinopNode(node.pos, operator='*',
-                                operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
-                                operand2=NameNode(node.pos, name=stride.name))
-            for index, stride in zip(node.indices, bufaux.stridevars)]
-
-        # then sum them 
-        expr = to_sum[0]
-        for next in to_sum[1:]:
-            expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
-
-        tmp= self.buffer_access.substitute({
-            'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
-            'OFFSET': expr
-            }, pos=node.pos)
-
-        return tmp.stats[0].expr
-
-    buffer_access = TreeFragment(u"""
-        (<unsigned char*>(BUF.buf + OFFSET))[0]
-    """)
-    def visit_IndexNode(self, node):
-        # Only occurs when the IndexNode is an rvalue
-        if node.is_buffer_access:
-            assert node.index is None
-            assert node.indices is not None
-            result = self.buffer_index(node)
-            result.analyse_expressions(self.scope)
-            return result
-        else:
-            return node
-
-    def visit_CallNode(self, node):
-###        print node.dump()
-        return node
-    
-#    def visit_FuncDefNode(self, node):
-#        print node.dump()
-    
-
-class PhaseEnvelopeNode(Node):
-    """
-    This node is used if you need to protect a node from reevaluation
-    of a phase. For instance, if you extract...
-
-    Use with care!
-    """
-
-    # Phases
-    PARSED, ANALYSED = range(2)
-
-    def __init__(self, phase, wrapped):
-        self.phase = phase
-        self.wrapped = wrapped
-
-    def get_pos(self): return self.wrapped.pos 
-    def set_pos(self, value): self.wrapped.pos  = value
-    pos = property(get_pos, set_pos)
-
-    def get_subexprs(self): return self.wrapped.subexprs
-    subexprs = property(get_subexprs)
-
-    def analyse_types(self, env):
-        if self.phase < self.ANALYSED:
-            self.wrapped.analyse_types(env)
-
-    def __getattribute__(self, attrname):
-        wrapped = object.__getattribute__(self, "wrapped")
-        phase = object.__getattribute__(self, "phase")
-        if attrname == "wrapped": return wrapped
-        if attrname == "phase": return phase
-        
-        attr = getattr(wrapped, attrname)
-
-        overridden = ("analyse_types",)
-
-
-        
-        print attrname, attr
-        if not isinstance(attr, Node):
-            return attr
-        else:
-            if attr is None: return None
-            else:
-                return PhaseEnvelopeNode(phase, attr)
-
-
-    
-
-
 class WithTransform(CythonTransform):
 
     # EXCINFO is manually set to a variable that contains