if is_common_value(t1, t2):
return t1, c1+c2
return None, None
-
+
+ def extract_common_conditions(self, common_var, condition):
+ var, conditions = self.extract_conditions(condition)
+ if var is None:
+ return None, None
+ elif common_var is not None and not is_common_value(var, common_var):
+ return None, None
+ elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
+ return None, None
+ return var, conditions
+
def visit_IfStatNode(self, node):
- self.visitchildren(node)
common_var = None
- case_count = 0
cases = []
for if_clause in node.if_clauses:
- var, conditions = self.extract_conditions(if_clause.condition)
- if var is None:
- return node
- elif common_var is not None and not is_common_value(var, common_var):
+ common_var, conditions = self.extract_common_conditions(
+ common_var, if_clause.condition)
+ if common_var is None:
return node
- elif not var.type.is_int or sum([not cond.type.is_int for cond in conditions]):
- return node
- else:
- common_var = var
- case_count += len(conditions)
- cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
- conditions = conditions,
- body = if_clause.body))
- if case_count < 2:
- return node
-
+ cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos,
+ conditions = conditions,
+ body = if_clause.body))
+
+ if sum([ len(case.conditions) for case in cases ]) < 2:
+ return node
+
+ common_var = unwrap_node(common_var)
+ switch_node = Nodes.SwitchStatNode(pos = node.pos,
+ test = common_var,
+ cases = cases,
+ else_clause = node.else_clause)
+ self.visitchildren(switch_node)
+ return switch_node
+
+ def visit_CondExprNode(self, node):
+ common_var, conditions = self.extract_common_conditions(None, node.test)
+ if common_var is None:
+ return node
+ if len(conditions) < 2:
+ return node
+
+ result_ref = UtilNodes.ResultRefNode(node)
+ true_body = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs = result_ref,
+ rhs = node.true_val,
+ first = True)
+ false_body = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs = result_ref,
+ rhs = node.false_val,
+ first = True)
+
+ cases = [Nodes.SwitchCaseNode(pos = node.pos,
+ conditions = conditions,
+ body = true_body)]
+
common_var = unwrap_node(common_var)
- return Nodes.SwitchStatNode(pos = node.pos,
- test = common_var,
- cases = cases,
- else_clause = node.else_clause)
+ switch_node = Nodes.SwitchStatNode(pos = node.pos,
+ test = common_var,
+ cases = cases,
+ else_clause = false_body)
+ self.visitchildren(switch_node)
+ return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
+
+ def visit_BoolBinopNode(self, node):
+ common_var, conditions = self.extract_common_conditions(None, node)
+ if common_var is None:
+ return node
+ if len(conditions) < 2:
+ return node
+
+ result_ref = UtilNodes.ResultRefNode(node)
+ true_body = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs = result_ref,
+ rhs = ExprNodes.BoolNode(node.pos, value=True),
+ first = True)
+ false_body = Nodes.SingleAssignmentNode(
+ node.pos,
+ lhs = result_ref,
+ rhs = ExprNodes.BoolNode(node.pos, value=False),
+ first = True)
+
+ cases = [Nodes.SwitchCaseNode(pos = node.pos,
+ conditions = conditions,
+ body = true_body)]
+
+ common_var = unwrap_node(common_var)
+ switch_node = Nodes.SwitchStatNode(pos = node.pos,
+ test = common_var,
+ cases = cases,
+ else_clause = false_body)
+ self.visitchildren(switch_node)
+ return UtilNodes.TempResultFromStatNode(result_ref, switch_node)
visit_Node = Visitor.VisitorTransform.recurse_to_children
return node
return node.arg
+ def visit_TypecastNode(self, node):
+ """
+ Drop redundant type casts.
+ """
+ self.visitchildren(node)
+ if node.type == node.operand.type:
+ return node.operand
+ return node
+
def visit_CoerceToBooleanNode(self, node):
"""Drop redundant conversion nodes after tree changes.
"""
def generate_assignment_code(self, rhs, code):
if self.type.is_pyobject:
rhs.make_owned_reference(code)
- code.put_decref(self.result(), self.ctype())
+ if not self.lhs_of_first_assignment:
+ code.put_decref(self.result(), self.ctype())
code.putln('%s = %s;' % (self.result(), rhs.result_as(self.ctype())))
rhs.generate_post_assignment_code(code)
rhs.free_temps(code)
self.setup_temp_expr(code)
self.body.generate_execution_code(code)
self.teardown_temp_expr(code)
+
+class TempResultFromStatNode(ExprNodes.ExprNode):
+ # An ExprNode wrapper around a StatNode that executes the StatNode
+ # body. Requires a ResultRefNode that it sets up to refer to its
+ # own temp result. The StatNode must assign a value to the result
+ # node, which then becomes the result of this node.
+ #
+ # This can only be used in/after type analysis.
+ #
+
+ subexprs = []
+ child_attrs = ['body']
+
+ def __init__(self, result_ref, body):
+ self.result_ref = result_ref
+ self.pos = body.pos
+ self.body = body
+ self.type = result_ref.type
+ self.is_temp = 1
+
+ def generate_result_code(self, code):
+ self.result_ref.result_code = self.result()
+ self.body.generate_execution_code(code)
+
+cimport cython
+
def f(a,b):
"""
>>> f(1,[1,2,3])
cdef int result = 2 in b
return result
+@cython.test_fail_if_path_exists("//SwitchStatNode")
def k(a):
"""
>>> k(1)
cdef int result = a in [1,2,3,4]
return result
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_list(int a):
"""
>>> m_list(2)
cdef int result = a in [1,2,3,4]
return result
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_tuple(int a):
"""
>>> m_tuple(2)
cdef int result = a in (1,2,3,4)
return result
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
def m_set(int a):
"""
>>> m_set(2)
cdef int result = a in {1,2,3,4}
return result
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_int(int a):
+ """
+ >>> conditional_int(1)
+ 1
+ >>> conditional_int(0)
+ 2
+ >>> conditional_int(5)
+ 2
+ """
+ return 1 if a in (1,2,3,4) else 2
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_object(int a):
+ """
+ >>> conditional_object(1)
+ 1
+ >>> conditional_object(0)
+ '2'
+ >>> conditional_object(5)
+ '2'
+ """
+ return 1 if a in (1,2,3,4) else '2'
+
+@cython.test_assert_path_exists("//SwitchStatNode")
+@cython.test_fail_if_path_exists("//BoolBinopNode", "//PrimaryCmpNode")
+def conditional_none(int a):
+ """
+ >>> conditional_none(1)
+ >>> conditional_none(0)
+ 1
+ >>> conditional_none(5)
+ 1
+ """
+ return None if a in {1,2,3,4} else 1
+
def n(a):
"""
>>> n('d *')