moved constant folding before type analysis, disabled for type casts and float expres...
authorStefan Behnel <scoder@users.berlios.de>
Mon, 23 Mar 2009 10:56:04 +0000 (11:56 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Mon, 23 Mar 2009 10:56:04 +0000 (11:56 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py
tests/run/consts.pyx

index 0f2ae9caa9d5ccc599ef84ac17ebfee83cf7fa5d..359e953c6fec8ff86b89975108b58a38efb03fe7 100644 (file)
@@ -901,7 +901,9 @@ class FloatNode(ConstNode):
     type = PyrexTypes.c_double_type
 
     def calculate_constant_result(self):
-        self.constant_result = float(self.value)
+        # calculating float values is usually not a good idea
+        #self.constant_result = float(self.value)
+        pass
 
     def compile_time_value(self, denv):
         return float(self.value)
@@ -3927,7 +3929,9 @@ class TypecastNode(NewTempExprNode):
         self.operand.check_const()
 
     def calculate_constant_result(self):
-        self.constant_result = self.operand.constant_result
+        # we usually do not know the result of a type cast at code
+        # generation time
+        pass
     
     def calculate_result_code(self):
         opnd = self.operand
@@ -4939,7 +4943,8 @@ class CoercionNode(NewTempExprNode):
             print("%s Coercing %s" % (self, self.arg))
 
     def calculate_constant_result(self):
-        self.constant_result = self.arg.constant_result
+        # constant folding can break type coercion, so this is disabled
+        pass
             
     def annotate(self, code):
         self.arg.annotate(code)
@@ -4986,7 +4991,11 @@ class PyTypeTestNode(CoercionNode):
     
     def is_ephemeral(self):
         return self.arg.is_ephemeral()
-    
+
+    def calculate_constant_result(self):
+        # FIXME
+        pass
+
     def calculate_result_code(self):
         return self.arg.result()
     
index 333d4aba8c6992bc4e4b5d99e74431b9712b7901..1dfed5999458c35f09bcd58896c63e7a25aba5ab 100644 (file)
@@ -115,6 +115,7 @@ class Context(object):
             _specific_post_parse,
             InterpretCompilerDirectives(self, self.pragma_overrides),
             _align_function_definitions,
+            ConstantFolding(),
             FlattenInListTransform(),
             WithTransform(self),
             DecoratorTransform(self),
@@ -125,7 +126,6 @@ class Context(object):
             _check_c_declarations,
             AnalyseExpressionsTransform(self),
             FlattenBuiltinTypeCreation(),
-            ConstantFolding(),
 #            ComprehensionTransform(),
             IterationTransform(),
             SwitchTransform(),
index 7aec738ae84b8da6dceb57eb064a13d0da623b14..8dc06e21015cf9a0cad4723283641a8845379cf0 100644 (file)
@@ -566,44 +566,60 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
             import traceback, sys
             traceback.print_exc(file=sys.stdout)
 
+    NODE_TYPE_ORDER = (ExprNodes.CharNode, ExprNodes.IntNode,
+                       ExprNodes.LongNode, ExprNodes.FloatNode)
+
+    def _widest_node_class(self, *nodes):
+        try:
+            return self.NODE_TYPE_ORDER[
+                max(map(self.NODE_TYPE_ORDER.index, map(type, nodes)))]
+        except ValueError:
+            return None
+
     def visit_ExprNode(self, node):
         self._calculate_const(node)
         return node
 
-#    def visit_NumBinopNode(self, node):
     def visit_BinopNode(self, node):
         self._calculate_const(node)
-        if node.type is PyrexTypes.py_object_type:
-            return node
         if node.constant_result is ExprNodes.not_a_constant:
             return node
-#        print node.constant_result, node.operand1, node.operand2, node.pos
+        try:
+            if node.operand1.type is None or node.operand2.type is None:
+                return node
+        except AttributeError:
+            return node
+
+        type1, type2 = node.operand1.type, node.operand2.type
         if isinstance(node.operand1, ExprNodes.ConstNode) and \
-                node.type is node.operand1.type:
-            new_node = node.operand1
-        elif isinstance(node.operand2, ExprNodes.ConstNode) and \
-                node.type is node.operand2.type:
-            new_node = node.operand2
+               isinstance(node.operand1, ExprNodes.ConstNode):
+            if type1 is type2:
+                new_node = node.operand1
+            else:
+                widest_type = PyrexTypes.widest_numeric_type(type1, type2)
+                if type(node.operand1) is type(node.operand2):
+                    new_node = node.operand1
+                    new_node.type = widest_type
+                elif type1 is widest_type:
+                    new_node = node.operand1
+                elif type2 is widest_type:
+                    new_node = node.operand2
+                else:
+                    target_class = self._widest_node_class(
+                        node.operand1, node.operand2)
+                    if target_class is None:
+                        return node
+                    new_node = target_class(type = widest_type)
         else:
             return node
-        new_node.value = new_node.constant_result = node.constant_result
-        new_node = new_node.coerce_to(node.type, self.current_scope)
+
+        new_node.constant_result = node.constant_result
+        new_node.value = str(node.constant_result)
+        #new_node = new_node.coerce_to(node.type, self.current_scope)
         return new_node
 
     # in the future, other nodes can have their own handler method here
     # that can replace them with a constant result node
-    
-    def visit_ModuleNode(self, node):
-        self.current_scope = node.scope
-        self.visitchildren(node)
-        return node
-
-    def visit_FuncDefNode(self, node):
-        old_scope = self.current_scope
-        self.current_scope = node.entry.scope
-        self.visitchildren(node)
-        self.current_scope = old_scope
-        return node
 
     visit_Node = Visitor.VisitorTransform.recurse_to_children
 
index d6c500f4e0d53f5b2cf5c7cd57c4b45dc97dece3..873fc000ec52873527db9bc88c89bcefe13152ee 100644 (file)
@@ -5,6 +5,13 @@ True
 True
 >>> neg() == -1 -2 - (-3+4)
 True
+>>> int_mix() == 1 + (2 * 3) // 2
+True
+>>> if IS_PY3: type(int_mix()) is int
+... else:      type(int_mix()) is long
+True
+>>> int_cast() == 1 + 2 * 6000
+True
 >>> mul() == 1*60*1000
 True
 >>> arithm() == 9*2+3*8/6-10
@@ -15,6 +22,9 @@ True
 True
 """
 
+import sys
+IS_PY3 = sys.version_info[0] >= 3
+
 def _func(a,b,c):
     return a+b+c
 
@@ -27,6 +37,12 @@ def add_var(a):
 def neg():
     return -1 -2 - (-3+4)
 
+def int_mix():
+    return 1L + (2 * 3L) // 2
+
+def int_cast():
+    return <int>(1 + 2 * 6000)
+
 def mul():
     return 1*60*1000