Implemented inplace arithmetic
authorRobert Bradshaw <robertwb@math.washington.edu>
Tue, 16 Jan 2007 01:39:47 +0000 (17:39 -0800)
committerRobert Bradshaw <robertwb@math.washington.edu>
Tue, 16 Jan 2007 01:39:47 +0000 (17:39 -0800)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Lexicon.py
Cython/Compiler/Nodes.py
Cython/Compiler/Parsing.py

index 1569f5e0356a6f9eac10d21b22fd716296d21793..66c00c82cee7437de3010527d1f86d70f96915d1 100644 (file)
@@ -3041,8 +3041,11 @@ class CloneNode(CoercionNode):
     
     def __init__(self, arg):
         CoercionNode.__init__(self, arg)
-        self.type = arg.type
-        self.result_ctype = arg.result_ctype
+        if hasattr(arg, 'type'):
+            self.type = arg.type
+            self.result_ctype = arg.result_ctype
+        if hasattr(arg, 'entry'):
+            self.entry = arg.entry
     
     def calculate_result_code(self):
         return self.arg.result_code
@@ -3051,6 +3054,8 @@ class CloneNode(CoercionNode):
         self.type = self.arg.type
         self.result_ctype = self.arg.result_ctype
         self.is_temp = 1
+        if hasattr(self.arg, 'entry'):
+            self.entry = self.arg.entry
     
     #def result_as_extension_type(self):
     #  return self.arg.result_as_extension_type()
@@ -3060,6 +3065,15 @@ class CloneNode(CoercionNode):
 
     def generate_result_code(self, code):
         pass
+        
+    def generate_disposal_code(self, code):
+        code.putln("// ---- CloneNode.generate_disposal_code() for %s"%self.arg.result_code)
+        
+    def allocate_temps(self, env):
+        self.result_code = self.calculate_result_code()
+        
+    def release_temp(self, env):
+        pass
     
 #------------------------------------------------------------------------------------
 #
index 07ea3af1bd0b903d48d84f5707a34677f8d776bb..479b6cdfb0388f461367ec41e2b27024c7a1b768 100644 (file)
@@ -66,7 +66,7 @@ def make_lexicon():
     bra = Any("([{")
     ket = Any(")]}")
     punct = Any(":,;+-*/|&<>=.%`~^?")
-    diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**")
+    diphthong = Str("==", "<>", "!=", "<=", ">=", "<<", ">>", "**", "+=", "-=", "*=", "/=", "%=", "|=", "^=", "&=")
     spaces = Rep1(Any(" \t\f"))
     comment = Str("#") + Rep(AnyBut("\n"))
     escaped_newline = Str("\\\n")
index b45909f9b5943770b0b5b20b75a7705e405a23c2..3532f0f45e571b5dee6f1d0aae1956da3bd8c24d 100644 (file)
@@ -2708,6 +2708,104 @@ class ParallelAssignmentNode(AssignmentNode):
         for stat in self.stats:
             stat.generate_assignment_code(code)
 
+class InPlaceAssignmentNode(AssignmentNode):
+    #  An in place arithmatic operand:
+    #
+    #    a += b
+    #    a -= b
+    #    ...
+    #
+    #  lhs      ExprNode      Left hand side
+    #  rhs      ExprNode      Right hand side
+    #  op       char          one of "+-*/%^&|"
+    #  dup     (ExprNode)     copy of lhs used for operation (auto-generated)
+    #
+    #  This code is a bit tricky because in order to obey Python 
+    #  semantics the sub-expressions (e.g. indices) of the lhs must 
+    #  not be evaluated twice. So we must re-use the values calculated 
+    #  in evaluation phase for the assignment phase as well. 
+    #  Fortunately, the type of the lhs node is fairly constrained 
+    #  (it must be a NameNode, AttributeNode, or IndexNode).     
+    
+    def analyse_declarations(self, env):
+        self.lhs.analyse_target_declaration(env)
+    
+    def analyse_expressions_1(self, env, use_temp = 0):
+        import ExprNodes
+        self.create_dup_node(env) # re-assigns lhs to a shallow copy
+        self.rhs.analyse_types(env)
+        self.lhs.analyse_target_types(env)
+        if self.lhs.type.is_pyobject or self.rhs.type.is_pyobject:
+            self.rhs = self.rhs.coerce_to(self.lhs.type, env)
+        if self.lhs.type.is_pyobject:
+             self.result = ExprNodes.PyTempNode(self.pos, env)
+             self.result.allocate_temps(env)
+        if use_temp:
+            self.rhs = self.rhs.coerce_to_temp(env)
+        self.dup.allocate_subexpr_temps(env)
+        self.dup.allocate_temp(env)
+        self.rhs.allocate_temps(env)
+    
+    def analyse_expressions_2(self, env):
+        self.lhs.allocate_target_temps(env)
+        self.lhs.release_target_temp(env)
+        self.dup.release_temp(env)
+        if self.dup.is_temp:
+            self.dup.release_subexpr_temps(env)
+        self.rhs.release_temp(env)
+        if self.lhs.type.is_pyobject:
+            self.result.release_temp(env)
+
+    def generate_execution_code(self, code):
+        self.rhs.generate_evaluation_code(code)
+        self.dup.generate_subexpr_evaluation_code(code)
+        self.dup.generate_result_code(code)
+        if self.lhs.type.is_pyobject:
+            code.putln("//---- iadd code");
+            code.putln(
+                "%s = %s(%s, %s); if (!%s) %s" % (
+                    self.result.result_code, 
+                    self.py_operation_function(), 
+                    self.dup.py_result(),
+                    self.rhs.py_result(),
+                    self.result.py_result(),
+                    code.error_goto(self.pos)))
+            self.rhs.generate_disposal_code(code)
+            self.dup.generate_disposal_code(code)
+            self.lhs.generate_assignment_code(self.result, code)
+        else: 
+            # have to do assignment directly to avoid side-effects
+            code.putln("%s %s= %s;" % (self.lhs.result_code, self.operator, self.rhs.result_code) )
+            self.rhs.generate_disposal_code(code)
+        if self.dup.is_temp:
+            self.dup.generate_subexpr_disposal_code(code)
+            
+    def create_dup_node(self, env): 
+        import ExprNodes
+        self.dup = self.lhs
+        self.dup.analyse_types(env)
+        if isinstance(self.lhs, ExprNodes.NameNode):
+            target_lhs = ExprNodes.NameNode(self.dup.pos, name = self.dup.name, is_temp = self.dup.is_temp, entry = self.dup.entry)
+        elif isinstance(self.lhs, ExprNodes.AttributeNode):
+            target_lhs = ExprNodes.AttributeNode(self.dup.pos, obj = ExprNodes.CloneNode(self.lhs.obj), attribute = self.dup.attribute, is_temp = self.dup.is_temp)
+        elif isinstance(self.lhs, ExprNodes.IndexNode):
+            target_lhs = ExprNodes.IndexNode(self.dup.pos, base = ExprNodes.CloneNode(self.dup.base), index = ExprNodes.CloneNode(self.lhs.index), is_temp = self.dup.is_temp)
+        self.lhs = target_lhs
+    
+    def py_operation_function(self):
+        return self.py_functions[self.operator]
+
+    py_functions = {
+        "|":           "PyNumber_InPlaceOr",
+        "^":           "PyNumber_InPlaceXor",
+        "&":           "PyNumber_InPlaceAnd",
+        "+":           "PyNumber_InPlaceAdd",
+        "-":           "PyNumber_InPlaceSubtract",
+        "*":           "PyNumber_InPlaceMultiply",
+        "/":           "PyNumber_InPlaceDivide",
+        "%":           "PyNumber_InPlaceRemainder",
+    }
+
 
 class PrintStatNode(StatNode):
     #  print statement
index 3bbac250bf14f1f861f10fc3f8fd8dbfeeaf65fa..ab1ee1785a82ab1e609e568efa7517cb56d7ceaf 100644 (file)
@@ -700,8 +700,17 @@ def p_expression_or_assignment(s):
         s.next()
         expr_list.append(p_expr(s))
     if len(expr_list) == 1:
-        expr = expr_list[0]
-        return Nodes.ExprStatNode(expr.pos, expr = expr)
+        if re.match("[+*/\%^\&|-]=", s.sy):
+            lhs = expr_list[0]
+            if not isinstance(lhs, (ExprNodes.AttributeNode, ExprNodes.IndexNode, ExprNodes.NameNode) ):
+                error(lhs.pos, "Illegal operand for inplace operation.")
+            operator = s.sy[0]
+            s.next()
+            rhs = p_expr(s)
+            return Nodes.InPlaceAssignmentNode(lhs.pos, operator = operator, lhs = lhs, rhs = rhs)
+        else:
+            expr = expr_list[0]
+            return Nodes.ExprStatNode(expr.pos, expr = expr)
     else:
         expr_list_list = []
         flatten_parallel_assignments(expr_list, expr_list_list)
@@ -1835,6 +1844,7 @@ def p_module(s, pxd, full_module_name):
 #----------------------------------------------
 
 def print_parse_tree(f, node, level, key = None):      
+    from Nodes import Node
     ind = "  " * level
     if node:
         f.write(ind)