Works with some assignment expressions
authorDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 8 Jul 2008 11:00:55 +0000 (13:00 +0200)
committerDag Sverre Seljebotn <dagss@student.matnat.uio.no>
Tue, 8 Jul 2008 11:00:55 +0000 (13:00 +0200)
Cython/Compiler/ExprNodes.py
Cython/Compiler/ParseTreeTransforms.py

index b469e502d2c096bb312b0a0c526043dfc7022974..d342cc1a72704254e7b009c81dcd1010d32b2596 100644 (file)
@@ -1339,8 +1339,11 @@ class IndexNode(ExprNode):
         return 1
     
     def calculate_result_code(self):
-        return "(%s[%s])" % (
-            self.base.result_code, self.index.result_code)
+        if self.is_buffer_access:
+            return "<not needed>"
+        else:
+            return "(%s[%s])" % (
+                self.base.result_code, self.index.result_code)
             
     def index_unsigned_parameter(self):
         if self.index.type.is_int:
@@ -3842,6 +3845,10 @@ class CoerceToPyTypeNode(CoercionNode):
 
     gil_message = "Converting to Python object"
 
+    def analyse_types(self, env):
+        # The arg is always already analysed
+        pass
+
     def generate_result_code(self, code):
         function = self.arg.type.to_py_function
         code.putln('%s = %s(%s); %s' % (
@@ -3866,6 +3873,10 @@ class CoerceFromPyTypeNode(CoercionNode):
             error(arg.pos,
                 "Obtaining char * from temporary Python value")
     
+    def analyse_types(self, env):
+        # The arg is always already analysed
+        pass
+
     def generate_result_code(self, code):
         function = self.type.from_py_function
         operand = self.arg.py_result()
index dc25c5d8e455536999821baac307169b064d661f..c269a2b2f33685390cde06d019bb58b81662dcc5 100644 (file)
@@ -155,6 +155,7 @@ class BufferTransform(CythonTransform):
     def __call__(self, node):
         cymod = self.context.modules[u'__cython__']
         self.buffer_type = cymod.entries[u'Py_buffer'].type
+        self.lhs = False
         return super(BufferTransform, self).__call__(node)
 
     def handle_scope(self, node, scope):
@@ -229,75 +230,105 @@ class BufferTransform(CythonTransform):
 #                            attribute=EncodedString("strides")),
 #                        index=IntNode(node.pos, value=EncodedString(idx))))
 #                print ass.dump()
+
     def visit_SingleAssignmentNode(self, node):
-        self.visitchildren(node)
+        # On assignments, two buffer-related things can happen:
+        # a) A buffer variable is assigned to (reacquisition)
+        # b) Buffer access assignment: arr[...] = ...
+        # Since we don't allow nested buffers, these don't overlap.
+        
+#        self.lhs = True
+        self.visit(node.rhs)
+        self.visit(node.lhs)
+#        self.lhs = False
+#        self.visitchildren(node)
+        # Only acquire buffers on vars (not attributes) for now.
+        if isinstance(node.lhs, NameNode) and node.lhs.entry.buffer_aux:
+            # Is buffer variable
+            return self.reacquire_buffer(node)
+        elif (isinstance(node.lhs, IndexNode) and
+              isinstance(node.lhs.base, NameNode) and
+              node.lhs.base.entry.buffer_aux is not None):
+            return self.assign_into_buffer(node)
+        
+    def reacquire_buffer(self, node):
         bufaux = node.lhs.entry.buffer_aux
-        if bufaux is not None:
-            auxass = []
-            for idx, entry in enumerate(bufaux.stridevars):
-                entry.used = True
-                ass = self.fetch_strides.substitute({
-                    u"TARGET": NameNode(node.pos, name=entry.name),
-                    u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                    u"IDX": IntNode(node.pos, value=EncodedString(idx))
-                })
-                auxass.append(ass)
-
-            for idx, entry in enumerate(bufaux.shapevars):
-                entry.used = True
-                ass = self.fetch_shape.substitute({
-                    u"TARGET": NameNode(node.pos, name=entry.name),
-                    u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                    u"IDX": IntNode(node.pos, value=EncodedString(idx))
-                })
-                auxass.append(ass)
-                
-            bufaux.buffer_info_var.used = True
-            acq = self.acquire_buffer_fragment.substitute({
-                u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
-                u"LHS" : node.lhs,
-                u"RHS": node.rhs,
-                u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
-                u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+        auxass = []
+        for idx, entry in enumerate(bufaux.stridevars):
+            entry.used = True
+            ass = self.fetch_strides.substitute({
+                u"TARGET": NameNode(node.pos, name=entry.name),
+                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                u"IDX": IntNode(node.pos, value=EncodedString(idx))
+            })
+            auxass.append(ass)
+
+        for idx, entry in enumerate(bufaux.shapevars):
+            entry.used = True
+            ass = self.fetch_shape.substitute({
+                u"TARGET": NameNode(node.pos, name=entry.name),
+                u"BUFINFO": NameNode(node.pos, name=bufaux.buffer_info_var.name),
+                u"IDX": IntNode(node.pos, value=EncodedString(idx))
+            })
+            auxass.append(ass)
+
+        bufaux.buffer_info_var.used = True
+        acq = self.acquire_buffer_fragment.substitute({
+            u"TMP" : NameNode(pos=node.pos, name=bufaux.temp_var.name),
+            u"LHS" : node.lhs,
+            u"RHS": node.rhs,
+            u"ASSIGN_AUX": StatListNode(node.pos, stats=auxass),
+            u"BUFINFO": NameNode(pos=node.pos, name=bufaux.buffer_info_var.name)
+        }, pos=node.pos)
+        # Note: The below should probably be refactored into something
+        # like fragment.substitute(..., context=self.context), with
+        # TreeFragment getting context.pipeline_until_now() and
+        # applying it on the fragment.
+        acq.analyse_declarations(self.scope)
+        acq.analyse_expressions(self.scope)
+        stats = acq.stats
+        return stats
+
+    def assign_into_buffer(self, node):
+        result = SingleAssignmentNode(node.pos,
+                                      rhs=self.visit(node.rhs),
+                                      lhs=self.buffer_index(node.lhs))
+        result.analyse_expressions(self.scope)
+        return result
+        
+
+    def buffer_index(self, node):
+        bufaux = node.base.entry.buffer_aux
+        assert bufaux is not None
+        # indices * strides...
+        to_sum = [ IntBinopNode(node.pos, operator='*',
+                                operand1=index, #PhaseEnvelopeNode(PhaseEnvelopeNode.ANALYSED, index),
+                                operand2=NameNode(node.pos, name=stride.name))
+            for index, stride in zip(node.indices, bufaux.stridevars)]
+
+        # then sum them 
+        expr = to_sum[0]
+        for next in to_sum[1:]:
+            expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
+
+        tmp= self.buffer_access.substitute({
+            'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
+            'OFFSET': expr
             }, pos=node.pos)
-            # Note: The below should probably be refactored into something
-            # like fragment.substitute(..., context=self.context), with
-            # TreeFragment getting context.pipeline_until_now() and
-            # applying it on the fragment.
-            acq.analyse_declarations(self.scope)
-            acq.analyse_expressions(self.scope)
-            stats = acq.stats
-#            stats += [node] # Do assignment after successful buffer acquisition
-         #   print acq.dump()
-            return stats
-        else:
-            return node
+
+        return tmp.stats[0].expr
 
     buffer_access = TreeFragment(u"""
         (<unsigned char*>(BUF.buf + OFFSET))[0]
     """)
     def visit_IndexNode(self, node):
+        # Only occurs when the IndexNode is an rvalue
         if node.is_buffer_access:
             assert node.index is None
             assert node.indices is not None
-            bufaux = node.base.entry.buffer_aux
-            assert bufaux is not None
-            to_sum = [ IntBinopNode(node.pos, operator='*', operand1=index,
-                                    operand2=NameNode(node.pos, name=stride.name))
-                for index, stride in zip(node.indices, bufaux.stridevars)]
-            print to_sum
-
-            indices = node.indices
-            # reduce * on indices
-            expr = to_sum[0]
-            for next in to_sum[1:]:
-                expr = IntBinopNode(node.pos, operator='+', operand1=expr, operand2=next)
-            tmp= self.buffer_access.substitute({
-                'BUF': NameNode(node.pos, name=bufaux.buffer_info_var.name),
-                'OFFSET': expr
-                })
-            tmp.analyse_expressions(self.scope)
-            return tmp.stats[0].expr
+            result = self.buffer_index(node)
+            result.analyse_expressions(self.scope)
+            return result
         else:
             return node
 
@@ -309,6 +340,56 @@ class BufferTransform(CythonTransform):
 #        print node.dump()
     
 
+class PhaseEnvelopeNode(Node):
+    """
+    This node is used if you need to protect a node from reevaluation
+    of a phase. For instance, if you extract...
+
+    Use with care!
+    """
+
+    # Phases
+    PARSED, ANALYSED = range(2)
+
+    def __init__(self, phase, wrapped):
+        self.phase = phase
+        self.wrapped = wrapped
+
+    def get_pos(self): return self.wrapped.pos 
+    def set_pos(self, value): self.wrapped.pos  = value
+    pos = property(get_pos, set_pos)
+
+    def get_subexprs(self): return self.wrapped.subexprs
+    subexprs = property(get_subexprs)
+
+    def analyse_types(self, env):
+        if self.phase < self.ANALYSED:
+            self.wrapped.analyse_types(env)
+
+    def __getattribute__(self, attrname):
+        wrapped = object.__getattribute__(self, "wrapped")
+        phase = object.__getattribute__(self, "phase")
+        if attrname == "wrapped": return wrapped
+        if attrname == "phase": return phase
+        
+        attr = getattr(wrapped, attrname)
+
+        overridden = ("analyse_types",)
+
+
+        
+        print attrname, attr
+        if not isinstance(attr, Node):
+            return attr
+        else:
+            if attr is None: return None
+            else:
+                return PhaseEnvelopeNode(phase, attr)
+
+
+    
+
+
 class WithTransform(CythonTransform):
 
     # EXCINFO is manually set to a variable that contains