Handle inplace arithmatic via parse tree transform.
authorRobert Bradshaw <robertwb@math.washington.edu>
Fri, 12 Nov 2010 07:58:11 +0000 (23:58 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Fri, 12 Nov 2010 07:58:11 +0000 (23:58 -0800)
Excludes buffers and C++, which have their own code.
This is in preparation for #591 (inline vs. cdivision) and
support for inline complex arithamtic.

Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Nodes.py
Cython/Compiler/ParseTreeTransforms.py
Cython/Compiler/UtilNodes.py

index 94dca417bc0c04e6a04efe6333c6a2729d9164f5..233036553d5a18abb233dcd12a7c5bce1b9e8115 100755 (executable)
@@ -2746,6 +2746,7 @@ class SimpleCallNode(CallNode):
     wrapper_call = False
     has_optional_args = False
     nogil = False
+    analysed = False
     
     def compile_time_value(self, denv):
         function = self.function.compile_time_value(denv)
@@ -2799,6 +2800,9 @@ class SimpleCallNode(CallNode):
     def analyse_types(self, env):
         if self.analyse_as_type_constructor(env):
             return
+        if self.analysed:
+            return
+        self.analysed = True
         function = self.function
         function.is_called = 1
         self.function.analyse_types(env)
index 700f43f20c7eec58ff833626f8d0725a02875531..2e283d8d92a0d0ac36bcdeecd2eed1c6eabf2919 100644 (file)
@@ -98,6 +98,7 @@ class Context(object):
         from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform
         from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform
         from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods
+        from ParseTreeTransforms import ExpandInplaceOperators
         from TypeInference import MarkAssignments, MarkOverflowingArithmetic
         from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck
         from AnalysedTreeTransforms import AutoTestDictTransform
@@ -143,6 +144,7 @@ class Context(object):
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
             AnalyseExpressionsTransform(self),
+            ExpandInplaceOperators(self),
             OptimizeBuiltinCalls(self),  ## Necessary?
             IterationTransform(),
             SwitchTransform(),
index 1b6ab907afb452d09cac324107f7bf8b803e3eea..a6953e2e58030381fa02e9ffd38fcdebdb164204 100644 (file)
@@ -3520,15 +3520,15 @@ class InPlaceAssignmentNode(AssignmentNode):
     #  (it must be a NameNode, AttributeNode, or IndexNode).     
     
     child_attrs = ["lhs", "rhs"]
-    dup = None
 
     def analyse_declarations(self, env):
         self.lhs.analyse_target_declaration(env)
         
     def analyse_types(self, env):
-        self.dup = self.create_dup_node(env) # re-assigns lhs to a shallow copy
         self.rhs.analyse_types(env)
         self.lhs.analyse_target_types(env)
+        return
+        
         import ExprNodes
         if self.lhs.type.is_pyobject:
             self.rhs = self.rhs.coerce_to_pyobject(env)
@@ -3539,6 +3539,28 @@ class InPlaceAssignmentNode(AssignmentNode):
             self.result_value = self.result_value_temp.coerce_to(self.lhs.type, env)
         
     def generate_execution_code(self, code):
+        import ExprNodes
+        self.rhs.generate_evaluation_code(code)
+        self.lhs.generate_subexpr_evaluation_code(code)
+        c_op = self.operator
+        if c_op == "//":
+            c_op = "/"
+        elif c_op == "**":
+            error(self.pos, "No C inplace power operator")
+        if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
+            if self.lhs.type.is_pyobject:
+                error(self.pos, "In-place operators not allowed on object buffers in this release.")
+            self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
+        else:
+            # C++
+            # TODO: make sure overload is declared
+            code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()))
+        self.lhs.generate_subexpr_disposal_code(code)
+        self.lhs.free_subexpr_temps(code)
+        self.rhs.generate_disposal_code(code)
+        self.rhs.free_temps(code)
+
+        return
         import ExprNodes
         self.rhs.generate_evaluation_code(code)
         self.dup.generate_subexpr_evaluation_code(code)
@@ -3581,10 +3603,15 @@ class InPlaceAssignmentNode(AssignmentNode):
                 
             # have to do assignment directly to avoid side-effects
             if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
+                if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']:
+                    error(self.pos, "Inplace non-c division not implemented for buffer types. (Use cdivision=False for now.)")
                 self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
             else:
                 self.dup.generate_result_code(code)
-                code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) )
+                if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']:
+                    error(self.pos, "bad")
+                else:
+                    code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) )
             self.rhs.generate_disposal_code(code)
             self.rhs.free_temps(code)
         if self.dup.is_temp:
@@ -3645,7 +3672,6 @@ class InPlaceAssignmentNode(AssignmentNode):
     def annotate(self, code):
         self.lhs.annotate(code)
         self.rhs.annotate(code)
-        self.dup.annotate(code)
     
     def create_binop_node(self):
         import ExprNodes
index 2be6f7a42164cb1448e7e1ad164508bb69d8434a..6692ca08fee2f68914ef36a7d4651c8d8021a5f8 100644 (file)
@@ -1194,7 +1194,74 @@ class AnalyseExpressionsTransform(CythonTransform):
             node.analyse_scoped_expressions(node.expr_scope)
         self.visitchildren(node)
         return node
+
+class ExpandInplaceOperators(CythonTransform):
+
+    def __call__(self, root):
+        self.env_stack = [root.scope]
+        return super(ExpandInplaceOperators, self).__call__(root)
+
+    def visit_FuncDefNode(self, node):
+        self.env_stack.append(node.local_scope)
+        self.visitchildren(node)
+        self.env_stack.pop()
+        return node
+    
         
+
+    def visit_InPlaceAssignmentNode(self, node):
+        lhs = node.lhs
+        rhs = node.rhs
+        if lhs.type.is_cpp_class:
+            # No getting around this exact operator here.
+            return node
+        if isinstance(lhs, IndexNode) and lhs.is_buffer_access:
+            # There is code to handle this case.
+            return node
+
+        def side_effect_free_reference(node, setting=False):
+            if node.type.is_pyobject and not setting:
+                node = LetRefNode(node)
+                return node, [node]
+            elif isinstance(node, IndexNode):
+                if node.is_buffer_access:
+                    raise ValueError, "Buffer access"
+                base, temps = side_effect_free_reference(node.base)
+                index = LetRefNode(node.index)
+                return IndexNode(node.pos, base=base, index=index), temps + [index]
+            elif isinstance(node, AttributeNode):
+                obj, temps = side_effect_free_reference(node.obj)
+                return AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps
+            elif isinstance(node, NameNode):
+                return node, []
+            else:
+                node = LetRefNode(node)
+                return node, [node]
+        try:
+            lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True)
+        except ValueError:
+            return node
+        dup = lhs.__class__(**lhs.__dict__)
+        binop = binop_node(node.pos, 
+                           operator = node.operator,
+                           operand1 = dup,
+                           operand2 = rhs)
+        node = SingleAssignmentNode(node.pos, lhs=lhs, rhs=binop) #, inplace=True)
+        # Use LetRefNode to avoid side effects.
+        let_ref_nodes.reverse()
+        for t in let_ref_nodes:
+            node = LetNode(t, node)
+        node.analyse_expressions(self.env_stack[-1])
+        return node
+
+    def visit_ExprNode(self, node):
+        # In-place assignments can't happen within an expression.
+        return node
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+
 class AlignFunctionDefinitions(CythonTransform):
     """
     This class takes the signatures from a .pxd file and applies them to 
index 3ac02cfac37c09f31ad1b30f366169cfdd4e0708..821f2143b8ec2ec01706fb9d6b1adac355c562a1 100644 (file)
@@ -8,6 +8,7 @@ import Nodes
 import ExprNodes
 from Nodes import Node
 from ExprNodes import AtomicExprNode
+from PyrexTypes import c_ptr_type
 
 class TempHandle(object):
     # THIS IS DEPRECATED, USE LetRefNode instead
@@ -196,6 +197,8 @@ class LetNodeMixin:
     def setup_temp_expr(self, code):
         self.temp_expression.generate_evaluation_code(code)
         self.temp_type = self.temp_expression.type
+        if self.temp_type.is_array:
+            self.temp_type = c_ptr_type(self.temp_type.base_type)
         self._result_in_temp = self.temp_expression.result_in_temp()
         if self._result_in_temp:
             self.temp = self.temp_expression.result()