From b8782c99db6e659e5c1d8f726efad0973c5a7591 Mon Sep 17 00:00:00 2001 From: Stefan Behnel Date: Sat, 13 Dec 2008 22:23:00 +0100 Subject: [PATCH] initial constant folding transform: calculate constant values in node.constant_result --- Cython/Compiler/ExprNodes.py | 155 +++++++++++++++++++++++++++++++---- Cython/Compiler/Main.py | 3 +- Cython/Compiler/Optimize.py | 48 +++++++++++ 3 files changed, 190 insertions(+), 16 deletions(-) diff --git a/Cython/Compiler/ExprNodes.py b/Cython/Compiler/ExprNodes.py index b84b3556..85506f50 100644 --- a/Cython/Compiler/ExprNodes.py +++ b/Cython/Compiler/ExprNodes.py @@ -22,6 +22,13 @@ from Cython.Debugging import print_call_chain from DebugFlags import debug_disposal_code, debug_temp_alloc, \ debug_coercion +try: + set +except NameError: + from sets import Set as set + +not_a_constant = object() +constant_value_not_set = object() class ExprNode(Node): # subexprs [string] Class var holding names of subexpr node attrs @@ -172,6 +179,8 @@ class ExprNode(Node): is_temp = 0 is_target = 0 + constant_result = constant_value_not_set + def get_child_attrs(self): return self.subexprs child_attrs = property(fget=get_child_attrs) @@ -224,7 +233,17 @@ class ExprNode(Node): # Return the native C type of the result (i.e. the # C type of the result_code expression). return self.result_ctype or self.type - + + def calculate_constant_result(self): + # Calculate the constant result of this expression and store + # it in ``self.constant_result``. Does nothing by default, + # thus leaving ``self.constant_result`` unknown. + # + # This must only be called when it is assured that all + # sub-expressions have a valid constant_result value. The + # ConstantFolding transform will do this. + pass + def compile_time_value(self, denv): # Return value of compile-time expression, or report error. error(self.pos, "Invalid compile-time expression") @@ -736,7 +755,9 @@ class NoneNode(PyConstNode): # The constant value None value = "Py_None" - + + constant_result = None + def compile_time_value(self, denv): return None @@ -745,6 +766,8 @@ class EllipsisNode(PyConstNode): value = "Py_Ellipsis" + constant_result = Ellipsis + def compile_time_value(self, denv): return Ellipsis @@ -775,7 +798,10 @@ class ConstNode(AtomicNewTempExprNode): class BoolNode(ConstNode): type = PyrexTypes.c_bint_type # The constant value True or False - + + def calculate_constant_result(self): + self.constant_result = self.value + def compile_time_value(self, denv): return self.value @@ -785,10 +811,14 @@ class BoolNode(ConstNode): class NullNode(ConstNode): type = PyrexTypes.c_null_ptr_type value = "NULL" + constant_result = 0 class CharNode(ConstNode): type = PyrexTypes.c_char_type + + def calculate_constant_result(self): + self.constant_result = ord(self.value) def compile_time_value(self, denv): return ord(self.value) @@ -830,6 +860,9 @@ class IntNode(ConstNode): else: return str(self.value) + self.unsigned + self.longness + def calculate_constant_result(self): + self.constant_result = int(self.value, 0) + def compile_time_value(self, denv): return int(self.value, 0) @@ -953,6 +986,9 @@ class LongNode(AtomicNewTempExprNode): # Python long integer literal # # value string + + def calculate_constant_result(self): + self.constant_result = long(self.value) def compile_time_value(self, denv): return long(self.value) @@ -978,6 +1014,9 @@ class ImagNode(AtomicNewTempExprNode): # Imaginary number literal # # value float imaginary part + + def calculate_constant_result(self): + self.constant_result = complex(0.0, self.value) def compile_time_value(self, denv): return complex(0.0, self.value) @@ -1350,6 +1389,9 @@ class BackquoteNode(ExprNode): gil_message = "Backquote expression" + def calculate_constant_result(self): + self.constant_result = repr(self.arg.constant_result) + def generate_result_code(self, code): code.putln( "%s = PyObject_Repr(%s); %s" % ( @@ -1582,7 +1624,11 @@ class IndexNode(ExprNode): def __init__(self, pos, index, *args, **kw): ExprNode.__init__(self, pos, index=index, *args, **kw) self._index = index - + + def calculate_constant_result(self): + self.constant_result = \ + self.base.constant_result[self.index.constant_result] + def compile_time_value(self, denv): base = self.base.compile_time_value(denv) index = self.index.compile_time_value(denv) @@ -1881,7 +1927,11 @@ class SliceIndexNode(ExprNode): # stop ExprNode or None subexprs = ['base', 'start', 'stop'] - + + def calculate_constant_result(self): + self.constant_result = self.base.constant_result[ + self.start.constant_result : self.stop.constant_result] + def compile_time_value(self, denv): base = self.base.compile_time_value(denv) if self.start is None: @@ -2055,7 +2105,13 @@ class SliceNode(ExprNode): # start ExprNode # stop ExprNode # step ExprNode - + + def calculate_constant_result(self): + self.constant_result = self.base.constant_result[ + self.start.constant_result : \ + self.stop.constant_result : \ + self.step.constant_result] + def compile_time_value(self, denv): start = self.start.compile_time_value(denv) if self.stop is None: @@ -2452,6 +2508,9 @@ class AsTupleNode(ExprNode): # arg ExprNode subexprs = ['arg'] + + def calculate_constant_result(self): + self.constant_result = tuple(self.base.constant_result) def compile_time_value(self, denv): arg = self.arg.compile_time_value(denv) @@ -2517,7 +2576,13 @@ class AttributeNode(ExprNode): self.analyse_as_python_attribute(env) return self return ExprNode.coerce_to(self, dst_type, env) - + + def calculate_constant_result(self): + attr = self.attribute + if attr.beginswith("__") and attr.endswith("__"): + return + self.constant_result = getattr(self.obj.constant_result, attr) + def compile_time_value(self, denv): attr = self.attribute if attr.beginswith("__") and attr.endswith("__"): @@ -2963,6 +3028,10 @@ class TupleNode(SequenceNode): else: return Naming.empty_tuple + def calculate_constant_result(self): + self.constant_result = tuple([ + arg.constant_result for arg in self.args]) + def compile_time_value(self, denv): values = self.compile_time_value_list(denv) try: @@ -3058,6 +3127,10 @@ class ListNode(SequenceNode): else: SequenceNode.release_temp(self, env) + def calculate_constant_result(self): + self.constant_result = [ + arg.constant_result for arg in self.args] + def compile_time_value(self, denv): return self.compile_time_value_list(denv) @@ -3228,12 +3301,12 @@ class SetNode(NewTempExprNode): self.gil_check(env) self.is_temp = 1 + def calculate_constant_result(self): + self.constant_result = set([ + arg.constant_result for arg in self.args]) + def compile_time_value(self, denv): values = [arg.compile_time_value(denv) for arg in self.args] - try: - set - except NameError: - from sets import Set as set try: return set(values) except Exception, e: @@ -3264,6 +3337,10 @@ class DictNode(ExprNode): # obj_conversion_errors [PyrexError] used internally subexprs = ['key_value_pairs'] + + def calculate_constant_result(self): + self.constant_result = dict([ + item.constant_result for item in self.key_value_pairs]) def compile_time_value(self, denv): pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv)) @@ -3366,6 +3443,10 @@ class DictItemNode(ExprNode): # key ExprNode # value ExprNode subexprs = ['key', 'value'] + + def calculate_constant_result(self): + self.constant_result = ( + self.key.constant_result, self.value.constant_result) def analyse_types(self, env): self.key.analyse_types(env) @@ -3507,6 +3588,10 @@ class UnopNode(ExprNode): # - Allocate temporary for result if needed. subexprs = ['operand'] + + def calculate_constant_result(self): + func = compile_time_unary_operators[self.operator] + self.constant_result = func(self.operand.constant_result) def compile_time_value(self, denv): func = compile_time_unary_operators.get(self.operator) @@ -3566,7 +3651,10 @@ class NotNode(ExprNode): # 'not' operator # # operand ExprNode - + + def calculate_constant_result(self): + self.constant_result = not self.operand.constant_result + def compile_time_value(self, denv): operand = self.operand.compile_time_value(denv) try: @@ -3897,7 +3985,13 @@ class BinopNode(NewTempExprNode): # - Allocate temporary for result if needed. subexprs = ['operand1', 'operand2'] - + + def calculate_constant_result(self): + func = compile_time_binary_operators[self.operator] + self.constant_result = func( + self.operand1.constant_result, + self.operand2.constant_result) + def compile_time_value(self, denv): func = get_compile_time_binop(self) operand1 = self.operand1.compile_time_value(denv) @@ -4137,6 +4231,16 @@ class BoolBinopNode(NewTempExprNode): # operand2 ExprNode subexprs = ['operand1', 'operand2'] + + def calculate_constant_result(self): + if self.operator == 'and': + self.constant_result = \ + self.operand1.constant_result and \ + self.operand2.constant_result + else: + self.constant_result = \ + self.operand1.constant_result or \ + self.operand2.constant_result def compile_time_value(self, denv): if self.operator == 'and': @@ -4261,7 +4365,13 @@ class CondExprNode(ExprNode): false_val = None subexprs = ['test', 'true_val', 'false_val'] - + + def calculate_constant_result(self): + if self.test.constant_result: + self.constant_result = self.true_val.constant_result + else: + self.constant_result = self.false_val.constant_result + def analyse_types(self, env): self.test.analyse_types(env) self.test = self.test.coerce_to_boolean(env) @@ -4350,6 +4460,15 @@ richcmp_constants = { class CmpNode: # Mixin class containing code common to PrimaryCmpNodes # and CascadedCmpNodes. + + def calculate_cascaded_constant_result(self, operand1_result): + func = compile_time_binary_operators[self.operator] + operand2_result = self.operand2.constant_result + result = func(operand1_result, operand2_result) + if result and self.cascade: + result = result and \ + self.cascade.cascaded_compile_time_value(operand2_result) + self.constant_result = result def cascaded_compile_time_value(self, operand1, denv): func = get_compile_time_binop(self) @@ -4362,6 +4481,7 @@ class CmpNode: if result: cascade = self.cascade if cascade: + # FIXME: I bet this must call cascaded_compile_time_value() result = result and cascade.compile_time_value(operand2, denv) return result @@ -4468,6 +4588,10 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode): child_attrs = ['operand1', 'operand2', 'cascade'] cascade = None + + def calculate_constant_result(self): + self.constant_result = self.calculate_cascaded_constant_result( + self.operand1.constant_result) def compile_time_value(self, denv): operand1 = self.operand1.compile_time_value(denv) @@ -4598,7 +4722,8 @@ class CascadedCmpNode(Node, CmpNode): child_attrs = ['operand2', 'cascade'] cascade = None - + constant_result = constant_value_not_set # FIXME: where to calculate this? + def analyse_types(self, env, operand1): self.operand2.analyse_types(env) if self.cascade: diff --git a/Cython/Compiler/Main.py b/Cython/Compiler/Main.py index 14cc6a63..1df959b5 100644 --- a/Cython/Compiler/Main.py +++ b/Cython/Compiler/Main.py @@ -83,7 +83,7 @@ class Context: from ParseTreeTransforms import AlignFunctionDefinitions from AutoDocTransforms import EmbedSignature from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform - from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase + from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase from Buffer import IntroduceBufferAuxiliaryVars from ModuleNode import check_c_declarations @@ -123,6 +123,7 @@ class Context: IntroduceBufferAuxiliaryVars(self), _check_c_declarations, AnalyseExpressionsTransform(self), + ConstantFolding(), FlattenBuiltinTypeCreation(), DictIterTransform(), SwitchTransform(), diff --git a/Cython/Compiler/Optimize.py b/Cython/Compiler/Optimize.py index db9520cb..703f1074 100644 --- a/Cython/Compiler/Optimize.py +++ b/Cython/Compiler/Optimize.py @@ -387,6 +387,54 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform): return node +class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations): + """Calculate the result of constant expressions to store it in + ``expr_node.constant_result``, and replace trivial cases by their + constant result. + """ + def _calculate_const(self, node): + if node.constant_result is not ExprNodes.constant_value_not_set: + return + + # make sure we always set the value + not_a_constant = ExprNodes.not_a_constant + node.constant_result = not_a_constant + + # check if all children are constant + children = self.visitchildren(node) + for child_result in children.itervalues(): + if type(child_result) is list: + for child in child_result: + if child.constant_result is not_a_constant: + return + elif child_result.constant_result is not_a_constant: + return + + # now try to calculate the real constant value + try: + node.calculate_constant_result() +# if node.constant_result is not ExprNodes.not_a_constant: +# print node.__class__.__name__, node.constant_result + except (ValueError, TypeError, IndexError, AttributeError): + # ignore all 'normal' errors here => no constant result + pass + except Exception: + # this looks like a real error + import traceback, sys + traceback.print_exc(file=sys.stdout) + + def visit_ExprNode(self, node): + self._calculate_const(node) + return node + + # in the future, other nodes can have their own handler method here + # that can replace them with a constant result node + + def visit_Node(self, node): + self.visitchildren(node) + return node + + class FinalOptimizePhase(Visitor.CythonTransform): """ This visitor handles several commuting optimizations, and is run -- 2.26.2