From efa0d3cb335d2dc3915831de339ef703ba197fb9 Mon Sep 17 00:00:00 2001 From: Robert Bradshaw Date: Fri, 11 Jul 2008 04:10:50 -0700 Subject: [PATCH] Case statements and "x in [...]" flattening. --- Cython/Compiler/ExprNodes.py | 50 +++++++++++++++++++++ Cython/Compiler/Main.py | 3 ++ Cython/Compiler/Optimize.py | 84 ++++++++++++++++++++++++++++++------ Demos/Setup.py | 1 + 4 files changed, 126 insertions(+), 12 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index ed770ef7..60367465 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -4018,6 +4018,56 @@ class CloneNode(CoercionNode): def release_temp(self, env): pass + +class PersistentNode(ExprNode): + # A PersistentNode is like a CloneNode except it handles the temporary + # allocation itself by keeping track of the number of times it has been + # used. + + subexprs = ["arg"] + temp_counter = 0 + generate_counter = 0 + result_code = None + + def __init__(self, arg, uses): + self.pos = arg.pos + self.arg = arg + self.uses = uses + + def analyse_types(self, env): + self.arg.analyse_types(env) + self.type = self.arg.type + self.result_ctype = self.arg.result_ctype + self.is_temp = 1 + + def generate_evaluation_code(self, code): + if self.generate_counter == 0: + self.arg.generate_evaluation_code(code) + code.putln("%s = %s;" % ( + self.result_code, self.arg.result_as(self.ctype()))) + if self.type.is_pyobject: + code.put_incref(self.result_code, self.ctype()) + self.arg.generate_disposal_code(code) + self.generate_counter += 1 + + def generate_disposal_code(self, code): + if self.generate_counter == self.uses: + if self.type.is_pyobject: + code.put_decref_clear(self.result_code, self.ctype()) + + def allocate_temps(self, env, result=None): + if self.temp_counter == 0: + self.arg.allocate_temps(env) + if result is None: + self.result_code = env.allocate_temp(self.type) + else: + self.result_code = result + self.arg.release_temp(env) + self.temp_counter += 1 + + def release_temp(self, env): + if self.temp_counter == self.uses: + env.release_temp(self.result_code) #------------------------------------------------------------------------------------ # diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 3cea71dd..c52b46ec 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -357,6 +357,7 @@ def create_default_pipeline(context, options, result): from ParseTreeTransforms import WithTransform, NormalizeTree, PostParse from ParseTreeTransforms import AnalyseDeclarationsTransform, AnalyseExpressionsTransform from ParseTreeTransforms import CreateClosureClasses, MarkClosureVisitor, DecoratorTransform + from Optimize import FlattenInListTransform, SwitchTransform from Buffer import BufferTransform from ModuleNode import check_c_classes @@ -364,12 +365,14 @@ def create_default_pipeline(context, options, result): create_parse(context), NormalizeTree(context), PostParse(context), + FlattenInListTransform(), WithTransform(context), DecoratorTransform(context), AnalyseDeclarationsTransform(context), check_c_classes, AnalyseExpressionsTransform(context), BufferTransform(context), + SwitchTransform(), # CreateClosureClasses(context), create_generate_code(context, options, result) ] diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index 7b8a73ad..a62f1088 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -1,9 +1,16 @@ import Nodes import ExprNodes +import PyrexTypes import Visitor +def unwrap_node(node): + while isinstance(node, ExprNodes.PersistentNode): + node = node.arg + return node def is_common_value(a, b): + a = unwrap_node(a) + b = unwrap_node(b) if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): return a.name == b.name if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): @@ -11,13 +18,20 @@ def is_common_value(a, b): return False -class SwitchTransformVisitor(Visitor.VisitorTransform): - +class SwitchTransform(Visitor.VisitorTransform): + """ + This transformation tries to turn long if statements into C switch statements. + The requirement is that every clause be an (or of) var == value, where the var + is common among all clauses and both var and value are not Python objects. + """ def extract_conditions(self, cond): if isinstance(cond, ExprNodes.CoerceToTempNode): cond = cond.arg - + + if isinstance(cond, ExprNodes.TypecastNode): + cond = cond.operand + if (isinstance(cond, ExprNodes.PrimaryCmpNode) and cond.cascade is None and cond.operator == '==' @@ -40,14 +54,8 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): return t1, c1+c2 return None, None - def is_common_value(self, a, b): - if isinstance(a, ExprNodes.NameNode) and isinstance(b, ExprNodes.NameNode): - return a.name == b.name - if isinstance(a, ExprNodes.AttributeNode) and isinstance(b, ExprNodes.AttributeNode): - return not a.is_py_attr and is_common_value(a.obj, b.obj) - return False - def visit_IfStatNode(self, node): + self.visitchildren(node) if len(node.if_clauses) < 3: return node common_var = None @@ -56,7 +64,7 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): var, conditions = self.extract_conditions(if_clause.condition) if var is None: return node - elif common_var is not None and not self.is_common_value(var, common_var): + elif common_var is not None and not is_common_value(var, common_var): return node else: common_var = var @@ -67,8 +75,60 @@ class SwitchTransformVisitor(Visitor.VisitorTransform): test = common_var, cases = cases, else_clause = node.else_clause) - + + def visit_Node(self, node): self.visitchildren(node) return node + +class FlattenInListTransform(Visitor.VisitorTransform): + """ + This transformation flattens "x in [val1, ..., valn]" into a sequential list + of comparisons. + """ + + def visit_PrimaryCmpNode(self, node): + self.visitchildren(node) + if node.cascade is not None: + return node + elif node.operator == 'in': + conjunction = 'or' + eq_or_neq = '==' + elif node.operator == 'not_in': + conjunction = 'and' + eq_or_neq = '!=' + else: + return node + + args = node.operand2.args + if isinstance(node.operand2, ExprNodes.TupleNode) or isinstance(node.operand2, ExprNodes.ListNode): + if len(args) == 0: + return ExprNodes.BoolNode(pos = node.pos, value = node.operator == 'not_in') + else: + lhs = ExprNodes.PersistentNode(node.operand1, len(args)) + conds = [] + for arg in args: + cond = ExprNodes.PrimaryCmpNode( + pos = node.pos, + operand1 = lhs, + operator = eq_or_neq, + operand2 = arg, + cascade = None) + conds.append(ExprNodes.TypecastNode( + pos = node.pos, + operand = cond, + type = PyrexTypes.c_bint_type)) + def concat(left, right): + return ExprNodes.BoolBinopNode( + pos = node.pos, + operator = conjunction, + operand1 = left, + operand2 = right) + return reduce(concat, conds) + else: + return node + + def visit_Node(self, node): + self.visitchildren(node) + return node diff --git a/Demos/Setup.py b/Demos/Setup.py index 75b05af0..8980cf06 100644 --- a/Demos/Setup.py +++ b/Demos/Setup.py @@ -7,6 +7,7 @@ from Cython.Distutils import build_ext ext_modules=[ Extension("primes", ["primes.pyx"]), Extension("spam", ["spam.pyx"]), +# Extension("optargs", ["optargs.pyx"], language = "c++"), ] for file in glob.glob("*.pyx"): -- 2.26.2