initial constant folding transform: calculate constant values in node.constant_result
authorStefan Behnel <scoder@users.berlios.de>
Sat, 13 Dec 2008 21:23:00 +0000 (22:23 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Sat, 13 Dec 2008 21:23:00 +0000 (22:23 +0100)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Main.py
Cython/Compiler/Optimize.py

index b84b355609c9335015f942fd5772451c9ceaa899..85506f50f0d77a839a2653f94f87da2321c51130 100644 (file)
@@ -22,6 +22,13 @@ from Cython.Debugging import print_call_chain
 from DebugFlags import debug_disposal_code, debug_temp_alloc, \
     debug_coercion
 
+try:
+    set
+except NameError:
+    from sets import Set as set
+
+not_a_constant = object()
+constant_value_not_set = object()
 
 class ExprNode(Node):
     #  subexprs     [string]     Class var holding names of subexpr node attrs
@@ -172,6 +179,8 @@ class ExprNode(Node):
     is_temp = 0
     is_target = 0
 
+    constant_result = constant_value_not_set
+
     def get_child_attrs(self):
         return self.subexprs
     child_attrs = property(fget=get_child_attrs)
@@ -224,7 +233,17 @@ class ExprNode(Node):
         #  Return the native C type of the result (i.e. the
         #  C type of the result_code expression).
         return self.result_ctype or self.type
-    
+
+    def calculate_constant_result(self):
+        # Calculate the constant result of this expression and store
+        # it in ``self.constant_result``.  Does nothing by default,
+        # thus leaving ``self.constant_result`` unknown.
+        #
+        # This must only be called when it is assured that all
+        # sub-expressions have a valid constant_result value.  The
+        # ConstantFolding transform will do this.
+        pass
+
     def compile_time_value(self, denv):
         #  Return value of compile-time expression, or report error.
         error(self.pos, "Invalid compile-time expression")
@@ -736,7 +755,9 @@ class NoneNode(PyConstNode):
     #  The constant value None
     
     value = "Py_None"
-    
+
+    constant_result = None
+
     def compile_time_value(self, denv):
         return None
     
@@ -745,6 +766,8 @@ class EllipsisNode(PyConstNode):
     
     value = "Py_Ellipsis"
 
+    constant_result = Ellipsis
+
     def compile_time_value(self, denv):
         return Ellipsis
 
@@ -775,7 +798,10 @@ class ConstNode(AtomicNewTempExprNode):
 class BoolNode(ConstNode):
     type = PyrexTypes.c_bint_type
     #  The constant value True or False
-    
+
+    def calculate_constant_result(self):
+        self.constant_result = self.value
+
     def compile_time_value(self, denv):
         return self.value
     
@@ -785,10 +811,14 @@ class BoolNode(ConstNode):
 class NullNode(ConstNode):
     type = PyrexTypes.c_null_ptr_type
     value = "NULL"
+    constant_result = 0
 
 
 class CharNode(ConstNode):
     type = PyrexTypes.c_char_type
+
+    def calculate_constant_result(self):
+        self.constant_result = ord(self.value)
     
     def compile_time_value(self, denv):
         return ord(self.value)
@@ -830,6 +860,9 @@ class IntNode(ConstNode):
         else:
             return str(self.value) + self.unsigned + self.longness
 
+    def calculate_constant_result(self):
+        self.constant_result = int(self.value, 0)
+
     def compile_time_value(self, denv):
         return int(self.value, 0)
 
@@ -953,6 +986,9 @@ class LongNode(AtomicNewTempExprNode):
     #  Python long integer literal
     #
     #  value   string
+
+    def calculate_constant_result(self):
+        self.constant_result = long(self.value)
     
     def compile_time_value(self, denv):
         return long(self.value)
@@ -978,6 +1014,9 @@ class ImagNode(AtomicNewTempExprNode):
     #  Imaginary number literal
     #
     #  value   float    imaginary part
+
+    def calculate_constant_result(self):
+        self.constant_result = complex(0.0, self.value)
     
     def compile_time_value(self, denv):
         return complex(0.0, self.value)
@@ -1350,6 +1389,9 @@ class BackquoteNode(ExprNode):
 
     gil_message = "Backquote expression"
 
+    def calculate_constant_result(self):
+        self.constant_result = repr(self.arg.constant_result)
+
     def generate_result_code(self, code):
         code.putln(
             "%s = PyObject_Repr(%s); %s" % (
@@ -1582,7 +1624,11 @@ class IndexNode(ExprNode):
     def __init__(self, pos, index, *args, **kw):
         ExprNode.__init__(self, pos, index=index, *args, **kw)
         self._index = index
-    
+
+    def calculate_constant_result(self):
+        self.constant_result = \
+            self.base.constant_result[self.index.constant_result]
+
     def compile_time_value(self, denv):
         base = self.base.compile_time_value(denv)
         index = self.index.compile_time_value(denv)
@@ -1881,7 +1927,11 @@ class SliceIndexNode(ExprNode):
     #  stop      ExprNode or None
     
     subexprs = ['base', 'start', 'stop']
-    
+
+    def calculate_constant_result(self):
+        self.constant_result = self.base.constant_result[
+            self.start.constant_result : self.stop.constant_result]
+
     def compile_time_value(self, denv):
         base = self.base.compile_time_value(denv)
         if self.start is None:
@@ -2055,7 +2105,13 @@ class SliceNode(ExprNode):
     #  start     ExprNode
     #  stop      ExprNode
     #  step      ExprNode
-    
+
+    def calculate_constant_result(self):
+        self.constant_result = self.base.constant_result[
+            self.start.constant_result : \
+                self.stop.constant_result : \
+                self.step.constant_result]
+
     def compile_time_value(self, denv):
         start = self.start.compile_time_value(denv)
         if self.stop is None:
@@ -2452,6 +2508,9 @@ class AsTupleNode(ExprNode):
     #  arg    ExprNode
     
     subexprs = ['arg']
+
+    def calculate_constant_result(self):
+        self.constant_result = tuple(self.base.constant_result)
     
     def compile_time_value(self, denv):
         arg = self.arg.compile_time_value(denv)
@@ -2517,7 +2576,13 @@ class AttributeNode(ExprNode):
                 self.analyse_as_python_attribute(env) 
                 return self
         return ExprNode.coerce_to(self, dst_type, env)
-    
+
+    def calculate_constant_result(self):
+        attr = self.attribute
+        if attr.beginswith("__") and attr.endswith("__"):
+            return
+        self.constant_result = getattr(self.obj.constant_result, attr)
+
     def compile_time_value(self, denv):
         attr = self.attribute
         if attr.beginswith("__") and attr.endswith("__"):
@@ -2963,6 +3028,10 @@ class TupleNode(SequenceNode):
         else:
             return Naming.empty_tuple
 
+    def calculate_constant_result(self):
+        self.constant_result = tuple([
+                arg.constant_result for arg in self.args])
+
     def compile_time_value(self, denv):
         values = self.compile_time_value_list(denv)
         try:
@@ -3058,6 +3127,10 @@ class ListNode(SequenceNode):
         else:
             SequenceNode.release_temp(self, env)
 
+    def calculate_constant_result(self):
+        self.constant_result = [
+            arg.constant_result for arg in self.args]
+
     def compile_time_value(self, denv):
         return self.compile_time_value_list(denv)
 
@@ -3228,12 +3301,12 @@ class SetNode(NewTempExprNode):
         self.gil_check(env)
         self.is_temp = 1
 
+    def calculate_constant_result(self):
+        self.constant_result = set([
+                arg.constant_result for arg in self.args])
+
     def compile_time_value(self, denv):
         values = [arg.compile_time_value(denv) for arg in self.args]
-        try:
-            set
-        except NameError:
-            from sets import Set as set
         try:
             return set(values)
         except Exception, e:
@@ -3264,6 +3337,10 @@ class DictNode(ExprNode):
     # obj_conversion_errors    [PyrexError]   used internally
     
     subexprs = ['key_value_pairs']
+
+    def calculate_constant_result(self):
+        self.constant_result = dict([
+                item.constant_result for item in self.key_value_pairs])
     
     def compile_time_value(self, denv):
         pairs = [(item.key.compile_time_value(denv), item.value.compile_time_value(denv))
@@ -3366,6 +3443,10 @@ class DictItemNode(ExprNode):
     # key          ExprNode
     # value        ExprNode
     subexprs = ['key', 'value']
+
+    def calculate_constant_result(self):
+        self.constant_result = (
+            self.key.constant_result, self.value.constant_result)
             
     def analyse_types(self, env):
         self.key.analyse_types(env)
@@ -3507,6 +3588,10 @@ class UnopNode(ExprNode):
     #      - Allocate temporary for result if needed.
     
     subexprs = ['operand']
+
+    def calculate_constant_result(self):
+        func = compile_time_unary_operators[self.operator]
+        self.constant_result = func(self.operand.constant_result)
     
     def compile_time_value(self, denv):
         func = compile_time_unary_operators.get(self.operator)
@@ -3566,7 +3651,10 @@ class NotNode(ExprNode):
     #  'not' operator
     #
     #  operand   ExprNode
-    
+
+    def calculate_constant_result(self):
+        self.constant_result = not self.operand.constant_result
+
     def compile_time_value(self, denv):
         operand = self.operand.compile_time_value(denv)
         try:
@@ -3897,7 +3985,13 @@ class BinopNode(NewTempExprNode):
     #      - Allocate temporary for result if needed.
     
     subexprs = ['operand1', 'operand2']
-    
+
+    def calculate_constant_result(self):
+        func = compile_time_binary_operators[self.operator]
+        self.constant_result = func(
+            self.operand1.constant_result,
+            self.operand2.constant_result)
+
     def compile_time_value(self, denv):
         func = get_compile_time_binop(self)
         operand1 = self.operand1.compile_time_value(denv)
@@ -4137,6 +4231,16 @@ class BoolBinopNode(NewTempExprNode):
     #  operand2     ExprNode
     
     subexprs = ['operand1', 'operand2']
+
+    def calculate_constant_result(self):
+        if self.operator == 'and':
+            self.constant_result = \
+                self.operand1.constant_result and \
+                self.operand2.constant_result
+        else:
+            self.constant_result = \
+                self.operand1.constant_result or \
+                self.operand2.constant_result
     
     def compile_time_value(self, denv):
         if self.operator == 'and':
@@ -4261,7 +4365,13 @@ class CondExprNode(ExprNode):
     false_val = None
     
     subexprs = ['test', 'true_val', 'false_val']
-    
+
+    def calculate_constant_result(self):
+        if self.test.constant_result:
+            self.constant_result = self.true_val.constant_result
+        else:
+            self.constant_result = self.false_val.constant_result
+
     def analyse_types(self, env):
         self.test.analyse_types(env)
         self.test = self.test.coerce_to_boolean(env)
@@ -4350,6 +4460,15 @@ richcmp_constants = {
 class CmpNode:
     #  Mixin class containing code common to PrimaryCmpNodes
     #  and CascadedCmpNodes.
+
+    def calculate_cascaded_constant_result(self, operand1_result):
+        func = compile_time_binary_operators[self.operator]
+        operand2_result = self.operand2.constant_result
+        result = func(operand1_result, operand2_result)
+        if result and self.cascade:
+            result = result and \
+                self.cascade.cascaded_compile_time_value(operand2_result)
+        self.constant_result = result
     
     def cascaded_compile_time_value(self, operand1, denv):
         func = get_compile_time_binop(self)
@@ -4362,6 +4481,7 @@ class CmpNode:
         if result:
             cascade = self.cascade
             if cascade:
+                # FIXME: I bet this must call cascaded_compile_time_value()
                 result = result and cascade.compile_time_value(operand2, denv)
         return result
 
@@ -4468,6 +4588,10 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
     child_attrs = ['operand1', 'operand2', 'cascade']
     
     cascade = None
+
+    def calculate_constant_result(self):
+        self.constant_result = self.calculate_cascaded_constant_result(
+            self.operand1.constant_result)
     
     def compile_time_value(self, denv):
         operand1 = self.operand1.compile_time_value(denv)
@@ -4598,7 +4722,8 @@ class CascadedCmpNode(Node, CmpNode):
     child_attrs = ['operand2', 'cascade']
 
     cascade = None
-    
+    constant_result = constant_value_not_set # FIXME: where to calculate this?
+
     def analyse_types(self, env, operand1):
         self.operand2.analyse_types(env)
         if self.cascade:
index 14cc6a63fc3ee72305af852cabf15f1b90c978c3..1df959b5a56dd0b9bed04661610865880e8ed1bf 100644 (file)
@@ -83,7 +83,7 @@ class Context:
         from ParseTreeTransforms import AlignFunctionDefinitions
         from AutoDocTransforms import EmbedSignature
         from Optimize import FlattenInListTransform, SwitchTransform, DictIterTransform
-        from Optimize import FlattenBuiltinTypeCreation, FinalOptimizePhase
+        from Optimize import FlattenBuiltinTypeCreation, ConstantFolding, FinalOptimizePhase
         from Buffer import IntroduceBufferAuxiliaryVars
         from ModuleNode import check_c_declarations
 
@@ -123,6 +123,7 @@ class Context:
             IntroduceBufferAuxiliaryVars(self),
             _check_c_declarations,
             AnalyseExpressionsTransform(self),
+            ConstantFolding(),
             FlattenBuiltinTypeCreation(),
             DictIterTransform(),
             SwitchTransform(),
index db9520cb06485974a6c653fbb7d86b24c1b7900c..703f1074e7f8cdd63d422ba6dbf5c54f8141a155 100644 (file)
@@ -387,6 +387,54 @@ class FlattenBuiltinTypeCreation(Visitor.VisitorTransform):
         return node
 
 
+class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
+    """Calculate the result of constant expressions to store it in
+    ``expr_node.constant_result``, and replace trivial cases by their
+    constant result.
+    """
+    def _calculate_const(self, node):
+        if node.constant_result is not ExprNodes.constant_value_not_set:
+            return
+
+        # make sure we always set the value
+        not_a_constant = ExprNodes.not_a_constant
+        node.constant_result = not_a_constant
+
+        # check if all children are constant
+        children = self.visitchildren(node)
+        for child_result in children.itervalues():
+            if type(child_result) is list:
+                for child in child_result:
+                    if child.constant_result is not_a_constant:
+                        return
+            elif child_result.constant_result is not_a_constant:
+                return
+
+        # now try to calculate the real constant value
+        try:
+            node.calculate_constant_result()
+#            if node.constant_result is not ExprNodes.not_a_constant:
+#                print node.__class__.__name__, node.constant_result
+        except (ValueError, TypeError, IndexError, AttributeError):
+            # ignore all 'normal' errors here => no constant result
+            pass
+        except Exception:
+            # this looks like a real error
+            import traceback, sys
+            traceback.print_exc(file=sys.stdout)
+
+    def visit_ExprNode(self, node):
+        self._calculate_const(node)
+        return node
+
+    # in the future, other nodes can have their own handler method here
+    # that can replace them with a constant result node
+
+    def visit_Node(self, node):
+        self.visitchildren(node)
+        return node
+
+
 class FinalOptimizePhase(Visitor.CythonTransform):
     """
     This visitor handles several commuting optimizations, and is run