generic aggregation of a constant BinopNode into a ConstNode (in simple cases)
authorStefan Behnel <scoder@users.berlios.de>
Sun, 14 Dec 2008 14:08:21 +0000 (15:08 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sun, 14 Dec 2008 14:08:21 +0000 (15:08 +0100)
Cython/Compiler/Optimize.py
tests/run/consts.pyx

index a58fb327b314a129c6bb27e76aa658c60db151dd..d985ced0b9476312a6c95915d3579e033743b760 100644 (file)
@@ -427,8 +427,33 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
         self._calculate_const(node)
         return node
 
+#    def visit_NumBinopNode(self, node):
+    def visit_BinopNode(self, node):
+        self._calculate_const(node)
+        if node.type is PyrexTypes.py_object_type:
+            return node
+        if node.constant_result is ExprNodes.not_a_constant:
+            return node
+#        print node.constant_result, node.operand1, node.operand2, node.pos
+        if isinstance(node.operand1, ExprNodes.ConstNode) and \
+                node.type is node.operand1.type:
+            new_node = node.operand1
+        elif isinstance(node.operand2, ExprNodes.ConstNode) and \
+                node.type is node.operand2.type:
+            new_node = node.operand2
+        else:
+            return node
+        new_node.value = new_node.constant_result = node.constant_result
+        new_node = new_node.coerce_to(node.type, self.module_scope)
+        return new_node
+
     # in the future, other nodes can have their own handler method here
     # that can replace them with a constant result node
+    
+    def visit_ModuleNode(self, node):
+        self.module_scope = node.scope
+        self.visitchildren(node)
+        return node
 
     def visit_Node(self, node):
         self.visitchildren(node)
index d9da34370ed743b3308fe789f4fc2c30ec79356f..11b91a606ea2434c0ab996dba073de8eddbae58e 100644 (file)
@@ -1,8 +1,12 @@
 __doc__ = u"""
->>> add()
-10
->>> add_var(10)
-20
+>>> add() == 1+2+3+4
+True
+>>> add_var(10) == 1+2+10+3+4
+True
+>>> mul() == 1*60*1000
+True
+>>> arithm() == 9*2+3*8/6-10
+True
 """
 
 def add():
@@ -10,3 +14,9 @@ def add():
 
 def add_var(a):
     return 1+2 +a+ 3+4
+
+def mul():
+    return 1*60*1000
+
+def arithm():
+    return 9*2+3*8/6-10