From ad1bed63e67ba3e6309e7aee092cf6849dd05207 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 6 Jun 2008 17:26:24 -0700 Subject: [PATCH] Switch statement omptimization --- Cython/Compiler/Nodes.py | 45 +++++++++++++++++++++- Cython/Compiler/Optimize.py | 74 +++++++++++++++++++++++++++++++++++++ 2 files changed, 118 insertions(+), 1 deletion(-) create mode 100644 Cython/Compiler/Optimize.py diff --git a/Cython/Compiler/Nodes.py b/Cython/Compiler/Nodes.py index 058fb888..c280caa5 100644 --- a/Cython/Compiler/Nodes.py +++ b/Cython/Compiler/Nodes.py @@ -149,7 +149,7 @@ class Node(object): except AttributeError: flat = [] for attr in self.child_attrs: - child = getattr(parent, attr) + child = getattr(self, attr) # Sometimes lists, sometimes nodes if child is None: pass @@ -2850,7 +2850,50 @@ class IfClauseNode(Node): self.condition.annotate(code) self.body.annotate(code) + +class SwitchCaseNode(StatNode): + # Generated in the optimization of an if-elif-else node + # + # conditions [ExprNode] + # body StatNode + + child_attrs = ['conditions', 'body'] + + def generate_execution_code(self, code): + for cond in self.conditions: + code.putln("case %s:" % cond.calculate_result_code()) + self.body.generate_execution_code(code) + code.putln("break;") + def annotate(self, code): + for cond in self.conditions: + cond.annotate(code) + body.annotate(code) + +class SwitchStatNode(StatNode): + # Generated in the optimization of an if-elif-else node + # + # test ExprNode + # cases [SwitchCaseNode] + # else_clause StatNode or None + + child_attrs = ['test', 'cases', 'else_clause'] + + def generate_execution_code(self, code): + code.putln("switch (%s) {" % self.test.calculate_result_code()) + for case in self.cases: + case.generate_execution_code(code) + if self.else_clause is not None: + code.putln("default:") + self.else_clause.generate_execution_code(code) + code.putln("}") + + def annotate(self, code): + self.test.annotate(code) + for case in self.cases: + case.annotate(code) + self.else_clause.annotate(code) + class LoopNode: def analyse_control_flow(self, env): diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py new file mode 100644 index 00000000..7b8a73ad --- /dev/null +++ b/Cython/Compiler/Optimize.py @@ -0,0 +1,74 @@ +import Nodes +import ExprNodes +import Visitor + + +def is_common_value(a, b): + if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): + return a.name == b.name + if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): + return not a.is_py_attr and is_common_value(a.obj, b.obj) + return False + + +class SwitchTransformVisitor(Visitor.VisitorTransform): + + def extract_conditions(self, cond): + + if isinstance(cond, ExprNodes.CoerceToTempNode): + cond = cond.arg + + if (isinstance(cond, ExprNodes.PrimaryCmpNode) + and cond.cascade is None + and cond.operator == '==' + and not cond.is_python_comparison()): + if is_common_value(cond.operand1, cond.operand1): + if isinstance(cond.operand2, ExprNodes.ConstNode): + return cond.operand1, [cond.operand2] + elif hasattr(cond.operand2, 'entry') and cond.operand2.entry.is_const: + return cond.operand1, [cond.operand2] + if is_common_value(cond.operand2, cond.operand2): + if isinstance(cond.operand1, ExprNodes.ConstNode): + return cond.operand2, [cond.operand1] + elif hasattr(cond.operand1, 'entry') and cond.operand1.entry.is_const: + return cond.operand2, [cond.operand1] + elif (isinstance(cond, ExprNodes.BoolBinopNode) + and cond.operator == 'or'): + t1, c1 = self.extract_conditions(cond.operand1) + t2, c2 = self.extract_conditions(cond.operand2) + if is_common_value(t1, t2): + return t1, c1+c2 + return None, None + + def is_common_value(self, a, b): + if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): + return a.name == b.name + if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): + return not a.is_py_attr and is_common_value(a.obj, b.obj) + return False + + def visit_IfStatNode(self, node): + if len(node.if_clauses) < 3: + return node + common_var = None + cases = [] + for if_clause in node.if_clauses: + var, conditions = self.extract_conditions(if_clause.condition) + if var is None: + return node + elif common_var is not None and not self.is_common_value(var, common_var): + return node + else: + common_var = var + cases.append(Nodes.SwitchCaseNode(pos = if_clause.pos, + conditions = conditions, + body = if_clause.body)) + return Nodes.SwitchStatNode(pos = node.pos, + test = common_var, + cases = cases, + else_clause = node.else_clause) + + def visit_Node(self, node): + self.visitchildren(node) + return node + -- 2.26.2