From: Stefan Behnel Date: Sat, 27 Mar 2010 09:15:37 +0000 (+0100) Subject: extend switch statement transformation to arbitrary 'in' tests: shorter, more readabl... X-Git-Tag: 0.13.beta0~251^2 X-Git-Url: http://git.tremily.us/?a=commitdiff_plain;h=ab7bd1948073c6296a81463afc6d906e95b691ed;p=cython.git extend switch statement transformation to arbitrary 'in' tests: shorter, more readable C code --- diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index c6e65bfa..83616355 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -540,34 +540,101 @@ class SwitchTransform(Visitor.VisitorTransform): if is_common_value(t1, t2): return t1, c1+c2 return None, None - + + def extract_common_conditions(self, common_var, condition): + var, conditions = self.extract_conditions(condition) + if var is None: + return None, None + elif common_var is not None and not is_common_value(var, common_var): + return None, None + elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]): + return None, None + return var, conditions + def visit_IfStatNode(self, node): - self.visitchildren(node) common_var = None - case_count = 0 cases = [] for if_clause in node.if_clauses: - var, conditions = self.extract_conditions(if_clause.condition) - if var is None: - return node - elif common_var is not None and not is_common_value(var, common_var): + common_var, conditions = self.extract_common_conditions( + common_var, if_clause.condition) + if common_var is None: return node - elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]): - return node - else: - common_var = var - case_count += len(conditions) - cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, - conditions = conditions, - body = if_clause.body)) - if case_count < 2: - return node - + cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, + conditions = conditions, + body = if_clause.body)) + + if sum([ len(case.conditions) for case in cases ]) < 2: + return node + + common_var = unwrap_node(common_var) + switch_node = Nodes.SwitchStatNode(pos = node.pos, + test = common_var, + cases = cases, + else_clause = node.else_clause) + self.visitchildren(switch_node) + return switch_node + + def visit_CondExprNode(self, node): + common_var, conditions = self.extract_common_conditions(None, node.test) + if common_var is None: + return node + if len(conditions) < 2: + return node + + result_ref = UtilNodes.ResultRefNode(node) + true_body = Nodes.SingleAssignmentNode( + node.pos, + lhs = result_ref, + rhs = node.true_val, + first = True) + false_body = Nodes.SingleAssignmentNode( + node.pos, + lhs = result_ref, + rhs = node.false_val, + first = True) + + cases = [Nodes.SwitchCaseNode(pos = node.pos, + conditions = conditions, + body = true_body)] + common_var = unwrap_node(common_var) - return Nodes.SwitchStatNode(pos = node.pos, - test = common_var, - cases = cases, - else_clause = node.else_clause) + switch_node = Nodes.SwitchStatNode(pos = node.pos, + test = common_var, + cases = cases, + else_clause = false_body) + self.visitchildren(switch_node) + return UtilNodes.TempResultFromStatNode(result_ref, switch_node) + + def visit_BoolBinopNode(self, node): + common_var, conditions = self.extract_common_conditions(None, node) + if common_var is None: + return node + if len(conditions) < 2: + return node + + result_ref = UtilNodes.ResultRefNode(node) + true_body = Nodes.SingleAssignmentNode( + node.pos, + lhs = result_ref, + rhs = ExprNodes.BoolNode(node.pos, value=True), + first = True) + false_body = Nodes.SingleAssignmentNode( + node.pos, + lhs = result_ref, + rhs = ExprNodes.BoolNode(node.pos, value=False), + first = True) + + cases = [Nodes.SwitchCaseNode(pos = node.pos, + conditions = conditions, + body = true_body)] + + common_var = unwrap_node(common_var) + switch_node = Nodes.SwitchStatNode(pos = node.pos, + test = common_var, + cases = cases, + else_clause = false_body) + self.visitchildren(switch_node) + return UtilNodes.TempResultFromStatNode(result_ref, switch_node) visit_Node = Visitor.VisitorTransform.recurse_to_children @@ -941,6 +1008,15 @@ class OptimizeBuiltinCalls(Visitor.EnvTransform): return node return node.arg + def visit_TypecastNode(self, node): + """ + Drop redundant type casts. + """ + self.visitchildren(node) + if node.type == node.operand.type: + return node.operand + return node + def visit_CoerceToBooleanNode(self, node): """Drop redundant conversion nodes after tree changes. """ diff --git a/Cython/Compiler/UtilNodes.py b/Cython/Compiler/UtilNodes.py index 27adab2e..7d1a963b 100644 --- a/Cython/Compiler/UtilNodes.py +++ b/Cython/Compiler/UtilNodes.py @@ -148,7 +148,8 @@ class ResultRefNode(AtomicExprNode): def generate_assignment_code(self, rhs, code): if self.type.is_pyobject: rhs.make_owned_reference(code) - code.put_decref(self.result(), self.ctype()) + if not self.lhs_of_first_assignment: + code.put_decref(self.result(), self.ctype()) code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype()))) rhs.generate_post_assignment_code(code) rhs.free_temps(code) @@ -250,3 +251,26 @@ class LetNode(Nodes.StatNode, LetNodeMixin): self.setup_temp_expr(code) self.body.generate_execution_code(code) self.teardown_temp_expr(code) + +class TempResultFromStatNode(ExprNodes.ExprNode): + # An ExprNode wrapper around a StatNode that executes the StatNode + # body. Requires a ResultRefNode that it sets up to refer to its + # own temp result. The StatNode must assign a value to the result + # node, which then becomes the result of this node. + # + # This can only be used in/after type analysis. + # + + subexprs = [] + child_attrs = ['body'] + + def __init__(self, result_ref, body): + self.result_ref = result_ref + self.pos = body.pos + self.body = body + self.type = result_ref.type + self.is_temp = 1 + + def generate_result_code(self, code): + self.result_ref.result_code = self.result() + self.body.generate_execution_code(code) diff --git a/tests/run/inop.pyx b/tests/run/inop.pyx index e836f4cc..0719516f 100644 --- a/tests/run/inop.pyx +++ b/tests/run/inop.pyx @@ -1,3 +1,6 @@ + +cimport cython + def f(a,b): """ >>> f(1,[1,2,3]) @@ -42,6 +45,7 @@ def j(b): cdef int result = 2 in b return result +@cython.test_fail_if_path_exists("//SwitchStatNode") def k(a): """ >>> k(1) @@ -52,6 +56,8 @@ def k(a): cdef int result = a in [1,2,3,4] return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") def m_list(int a): """ >>> m_list(2) @@ -62,6 +68,8 @@ def m_list(int a): cdef int result = a in [1,2,3,4] return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") def m_tuple(int a): """ >>> m_tuple(2) @@ -72,6 +80,8 @@ def m_tuple(int a): cdef int result = a in (1,2,3,4) return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") def m_set(int a): """ >>> m_set(2) @@ -82,6 +92,44 @@ def m_set(int a): cdef int result = a in {1,2,3,4} return result +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def conditional_int(int a): + """ + >>> conditional_int(1) + 1 + >>> conditional_int(0) + 2 + >>> conditional_int(5) + 2 + """ + return 1 if a in (1,2,3,4) else 2 + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def conditional_object(int a): + """ + >>> conditional_object(1) + 1 + >>> conditional_object(0) + '2' + >>> conditional_object(5) + '2' + """ + return 1 if a in (1,2,3,4) else '2' + +@cython.test_assert_path_exists("//SwitchStatNode") +@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode") +def conditional_none(int a): + """ + >>> conditional_none(1) + >>> conditional_none(0) + 1 + >>> conditional_none(5) + 1 + """ + return None if a in {1,2,3,4} else 1 + def n(a): """ >>> n('d *')