fix constant folding in PrimaryCmpNode/CascadedCmpNode
authorStefan Behnel <scoder@users.berlios.de>
Tue, 27 Apr 2010 15:25:42 +0000 (17:25 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Tue, 27 Apr 2010 15:25:42 +0000 (17:25 +0200)
Cython/Compiler/ExprNodes.py
tests/run/consts.pyx

index 9631ad6ab2367ac131b20ab1c9c3224be0fc4a12..778ef394be7a68fa8ecf2c2d8f53b1e6bd962ccc 100755 (executable)
@@ -5475,11 +5475,13 @@ class CmpNode(object):
         func = compile_time_binary_operators[self.operator]
         operand2_result = self.operand2.constant_result
         result = func(operand1_result, operand2_result)
-        if result and self.cascade:
-            result = result and \
-                self.cascade.cascaded_compile_time_value(operand2_result)
-        self.constant_result = result
-    
+        if self.cascade:
+            self.cascade.calculate_cascaded_constant_result(operand2_result)
+            if self.cascade.constant_result:
+                self.constant_result = result and self.cascade.constant_result
+        else:
+            self.constant_result = result
+
     def cascaded_compile_time_value(self, operand1, denv):
         func = get_compile_time_binop(self)
         operand2 = self.operand2.compile_time_value(denv)
@@ -5492,7 +5494,7 @@ class CmpNode(object):
             cascade = self.cascade
             if cascade:
                 # FIXME: I bet this must call cascaded_compile_time_value()
-                result = result and cascade.compile_time_value(operand2, denv)
+                result = result and cascade.cascaded_compile_time_value(operand2, denv)
         return result
 
     def is_cpp_comparison(self):
@@ -5787,8 +5789,7 @@ class PrimaryCmpNode(ExprNode, CmpNode):
         return ()
 
     def calculate_constant_result(self):
-        self.constant_result = self.calculate_cascaded_constant_result(
-            self.operand1.constant_result)
+        self.calculate_cascaded_constant_result(self.operand1.constant_result)
     
     def compile_time_value(self, denv):
         operand1 = self.operand1.compile_time_value(denv)
@@ -5966,6 +5967,10 @@ class CascadedCmpNode(Node, CmpNode):
     def type_dependencies(self, env):
         return ()
 
+    def has_constant_result(self):
+        return self.constant_result is not constant_value_not_set and \
+               self.constant_result is not not_a_constant
+
     def analyse_types(self, env):
         self.operand2.analyse_types(env)
         if self.cascade:
index 2511cec87346c02fdf1fca3239c989410d0e0678..c1a461c2e6db1dd78a2e58b5ff55d816ad808cb7 100644 (file)
@@ -1,6 +1,10 @@
 import sys
 IS_PY3 = sys.version_info[0] >= 3
 
+cimport cython
+
+DEF INT_VAL = 1
+
 def _func(a,b,c):
     return a+b+c
 
@@ -76,3 +80,74 @@ def lists():
     True
     """
     return [1,2,3] + [4,5,6]
+
+def int_bool_result():
+    """
+    >>> int_bool_result()
+    True
+    """
+    if 5:
+        return True
+    else:
+        return False
+
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def if_compare_true():
+    """
+    >>> if_compare_true()
+    True
+    """
+    if 0 == 0:
+        return True
+    else:
+        return False
+
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def if_compare_false():
+    """
+    >>> if_compare_false()
+    False
+    """
+    if 0 == 1 or 1 == 0:
+        return True
+    else:
+        return False
+
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def if_compare_cascaded():
+    """
+    >>> if_compare_cascaded()
+    True
+    """
+    if 0 < 1 < 2 < 3:
+        return True
+    else:
+        return False
+
+def list_bool_result():
+    """
+    >>> list_bool_result()
+    True
+    """
+    if [1,2,3]:
+        return True
+    else:
+        return False
+
+def compile_time_DEF():
+    """
+    >>> compile_time_DEF()
+    (1, False, True, True, False)
+    """
+    return INT_VAL, INT_VAL == 0, INT_VAL != 0, INT_VAL == 1, INT_VAL != 1
+
+@cython.test_fail_if_path_exists("//PrimaryCmpNode")
+def compile_time_DEF_if():
+    """
+    >>> compile_time_DEF_if()
+    True
+    """
+    if INT_VAL != 0:
+        return True
+    else:
+        return False