From 5668a4fb73f61c6acd26c004e89ad979c61965ac Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Thu, 11 Nov 2010 23:58:11 -0800 Subject: [PATCH] Handle inplace arithmatic via parse tree transform. Excludes buffers and C++, which have their own code. This is in preparation for #591 (inline vs. cdivision) and support for inline complex arithamtic. --- Cython/Compiler/ExprNodes.py | 4 ++ Cython/Compiler/Main.py | 2 + Cython/Compiler/Nodes.py | 34 +++++++++++-- Cython/Compiler/ParseTreeTransforms.py | 67 ++++++++++++++++++++++++++ Cython/Compiler/UtilNodes.py | 3 ++ 5 files changed, 106 insertions(+), 4 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 94dca417..23303655 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -2746,6 +2746,7 @@ class SimpleCallNode(CallNode): wrapper_call = False has_optional_args = False nogil = False + analysed = False def compile_time_value(self, denv): function = self.function.compile_time_value(denv) @@ -2799,6 +2800,9 @@ class SimpleCallNode(CallNode): def analyse_types(self, env): if self.analyse_as_type_constructor(env): return + if self.analysed: + return + self.analysed = True function = self.function function.is_called = 1 self.function.analyse_types(env) diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 700f43f2..2e283d8d 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -98,6 +98,7 @@ class Context(object): from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform from ParseTreeTransforms import InterpretCompilerDirectives, TransformBuiltinMethods + from ParseTreeTransforms import ExpandInplaceOperators from TypeInference import MarkAssignments, MarkOverflowingArithmetic from ParseTreeTransforms import AlignFunctionDefinitions, GilCheck from AnalysedTreeTransforms import AutoTestDictTransform @@ -143,6 +144,7 @@ class Context(object): IntroduceBufferAuxiliaryVars(self), _check_c_declarations, AnalyseExpressionsTransform(self), + ExpandInplaceOperators(self), OptimizeBuiltinCalls(self), ## Necessary? IterationTransform(), SwitchTransform(), diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 1b6ab907..a6953e2e 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -3520,15 +3520,15 @@ class InPlaceAssignmentNode(AssignmentNode): # (it must be a NameNode, AttributeNode, or IndexNode). child_attrs = ["lhs", "rhs"] - dup = None def analyse_declarations(self, env): self.lhs.analyse_target_declaration(env) def analyse_types(self, env): - self.dup = self.create_dup_node(env) # re-assigns lhs to a shallow copy self.rhs.analyse_types(env) self.lhs.analyse_target_types(env) + return + import ExprNodes if self.lhs.type.is_pyobject: self.rhs = self.rhs.coerce_to_pyobject(env) @@ -3539,6 +3539,28 @@ class InPlaceAssignmentNode(AssignmentNode): self.result_value = self.result_value_temp.coerce_to(self.lhs.type, env) def generate_execution_code(self, code): + import ExprNodes + self.rhs.generate_evaluation_code(code) + self.lhs.generate_subexpr_evaluation_code(code) + c_op = self.operator + if c_op == "//": + c_op = "/" + elif c_op == "**": + error(self.pos, "No C inplace power operator") + if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access: + if self.lhs.type.is_pyobject: + error(self.pos, "In-place operators not allowed on object buffers in this release.") + self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op) + else: + # C++ + # TODO: make sure overload is declared + code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result())) + self.lhs.generate_subexpr_disposal_code(code) + self.lhs.free_subexpr_temps(code) + self.rhs.generate_disposal_code(code) + self.rhs.free_temps(code) + + return import ExprNodes self.rhs.generate_evaluation_code(code) self.dup.generate_subexpr_evaluation_code(code) @@ -3581,10 +3603,15 @@ class InPlaceAssignmentNode(AssignmentNode): # have to do assignment directly to avoid side-effects if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access: + if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']: + error(self.pos, "Inplace non-c division not implemented for buffer types. (Use cdivision=False for now.)") self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op) else: self.dup.generate_result_code(code) - code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) ) + if self.lhs.type.is_int and c_op == "/" and not code.globalstate.directives['cdivision']: + error(self.pos, "bad") + else: + code.putln("%s %s= %s;" % (self.lhs.result(), c_op, self.rhs.result()) ) self.rhs.generate_disposal_code(code) self.rhs.free_temps(code) if self.dup.is_temp: @@ -3645,7 +3672,6 @@ class InPlaceAssignmentNode(AssignmentNode): def annotate(self, code): self.lhs.annotate(code) self.rhs.annotate(code) - self.dup.annotate(code) def create_binop_node(self): import ExprNodes diff --git a/Cython/Compiler/ParseTreeTransforms.py b/Cython/Compiler/ParseTreeTransforms.py index 2be6f7a4..6692ca08 100644 --- a/Cython/Compiler/ParseTreeTransforms.py +++ b/Cython/Compiler/ParseTreeTransforms.py @@ -1194,7 +1194,74 @@ class AnalyseExpressionsTransform(CythonTransform): node.analyse_scoped_expressions(node.expr_scope) self.visitchildren(node) return node + +class ExpandInplaceOperators(CythonTransform): + + def __call__(self, root): + self.env_stack = [root.scope] + return super(ExpandInplaceOperators, self).__call__(root) + + def visit_FuncDefNode(self, node): + self.env_stack.append(node.local_scope) + self.visitchildren(node) + self.env_stack.pop() + return node + + + def visit_InPlaceAssignmentNode(self, node): + lhs = node.lhs + rhs = node.rhs + if lhs.type.is_cpp_class: + # No getting around this exact operator here. + return node + if isinstance(lhs, IndexNode) and lhs.is_buffer_access: + # There is code to handle this case. + return node + + def side_effect_free_reference(node, setting=False): + if node.type.is_pyobject and not setting: + node = LetRefNode(node) + return node, [node] + elif isinstance(node, IndexNode): + if node.is_buffer_access: + raise ValueError, "Buffer access" + base, temps = side_effect_free_reference(node.base) + index = LetRefNode(node.index) + return IndexNode(node.pos, base=base, index=index), temps + [index] + elif isinstance(node, AttributeNode): + obj, temps = side_effect_free_reference(node.obj) + return AttributeNode(node.pos, obj=obj, attribute=node.attribute), temps + elif isinstance(node, NameNode): + return node, [] + else: + node = LetRefNode(node) + return node, [node] + try: + lhs, let_ref_nodes = side_effect_free_reference(lhs, setting=True) + except ValueError: + return node + dup = lhs.__class__(**lhs.__dict__) + binop = binop_node(node.pos, + operator = node.operator, + operand1 = dup, + operand2 = rhs) + node = SingleAssignmentNode(node.pos, lhs=lhs, rhs=binop) #, inplace=True) + # Use LetRefNode to avoid side effects. + let_ref_nodes.reverse() + for t in let_ref_nodes: + node = LetNode(t, node) + node.analyse_expressions(self.env_stack[-1]) + return node + + def visit_ExprNode(self, node): + # In-place assignments can't happen within an expression. + return node + + def visit_Node(self, node): + self.visitchildren(node) + return node + class AlignFunctionDefinitions(CythonTransform): """ This class takes the signatures from a .pxd file and applies them to diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index 3ac02cfa..821f2143 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -8,6 +8,7 @@ import Nodes import ExprNodes from Nodes import Node from ExprNodes import AtomicExprNode +from PyrexTypes import c_ptr_type class TempHandle(object): # THIS IS DEPRECATED, USE LetRefNode instead @@ -196,6 +197,8 @@ class LetNodeMixin: def setup_temp_expr(self, code): self.temp_expression.generate_evaluation_code(code) self.temp_type = self.temp_expression.type + if self.temp_type.is_array: + self.temp_type = c_ptr_type(self.temp_type.base_type) self._result_in_temp = self.temp_expression.result_in_temp() if self._result_in_temp: self.temp = self.temp_expression.result() -- 2.26.2