From 6d2ccb85d1fcde98944b9ef8b95dd063bcb7ed0d Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Tue, 27 Apr 2010 17:25:42 +0200 Subject: [PATCH] fix constant folding in PrimaryCmpNode/CascadedCmpNode --- Cython/Compiler/ExprNodes.py | 21 ++++++---- tests/run/consts.pyx | 75 ++++++++++++++++++++++++++++++++++++ 2 files changed, 88 insertions(+), 8 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index 9631ad6a..778ef394 100755 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -5475,11 +5475,13 @@ class CmpNode(object): func = compile_time_binary_operators[self.operator] operand2_result = self.operand2.constant_result result = func(operand1_result, operand2_result) - if result and self.cascade: - result = result and \ - self.cascade.cascaded_compile_time_value(operand2_result) - self.constant_result = result - + if self.cascade: + self.cascade.calculate_cascaded_constant_result(operand2_result) + if self.cascade.constant_result: + self.constant_result = result and self.cascade.constant_result + else: + self.constant_result = result + def cascaded_compile_time_value(self, operand1, denv): func = get_compile_time_binop(self) operand2 = self.operand2.compile_time_value(denv) @@ -5492,7 +5494,7 @@ class CmpNode(object): cascade = self.cascade if cascade: # FIXME: I bet this must call cascaded_compile_time_value() - result = result and cascade.compile_time_value(operand2, denv) + result = result and cascade.cascaded_compile_time_value(operand2, denv) return result def is_cpp_comparison(self): @@ -5787,8 +5789,7 @@ class PrimaryCmpNode(ExprNode, CmpNode): return () def calculate_constant_result(self): - self.constant_result = self.calculate_cascaded_constant_result( - self.operand1.constant_result) + self.calculate_cascaded_constant_result(self.operand1.constant_result) def compile_time_value(self, denv): operand1 = self.operand1.compile_time_value(denv) @@ -5966,6 +5967,10 @@ class CascadedCmpNode(Node, CmpNode): def type_dependencies(self, env): return () + def has_constant_result(self): + return self.constant_result is not constant_value_not_set and \ + self.constant_result is not not_a_constant + def analyse_types(self, env): self.operand2.analyse_types(env) if self.cascade: diff --git a/tests/run/consts.pyx b/tests/run/consts.pyx index 2511cec8..c1a461c2 100644 --- a/tests/run/consts.pyx +++ b/tests/run/consts.pyx @@ -1,6 +1,10 @@ import sys IS_PY3 = sys.version_info[0] >= 3 +cimport cython + +DEF INT_VAL = 1 + def _func(a,b,c): return a+b+c @@ -76,3 +80,74 @@ def lists(): True """ return [1,2,3] + [4,5,6] + +def int_bool_result(): + """ + >>> int_bool_result() + True + """ + if 5: + return True + else: + return False + +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def if_compare_true(): + """ + >>> if_compare_true() + True + """ + if 0 == 0: + return True + else: + return False + +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def if_compare_false(): + """ + >>> if_compare_false() + False + """ + if 0 == 1 or 1 == 0: + return True + else: + return False + +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def if_compare_cascaded(): + """ + >>> if_compare_cascaded() + True + """ + if 0 < 1 < 2 < 3: + return True + else: + return False + +def list_bool_result(): + """ + >>> list_bool_result() + True + """ + if [1,2,3]: + return True + else: + return False + +def compile_time_DEF(): + """ + >>> compile_time_DEF() + (1, False, True, True, False) + """ + return INT_VAL, INT_VAL == 0, INT_VAL != 0, INT_VAL == 1, INT_VAL != 1 + +@cython.test_fail_if_path_exists("//PrimaryCmpNode") +def compile_time_DEF_if(): + """ + >>> compile_time_DEF_if() + True + """ + if INT_VAL != 0: + return True + else: + return False -- 2.26.2