code cleanup in ConstantFolding transform to make boolean handling less error prone
authorStefan Behnel <scoder@users.berlios.de>
Thu, 25 Nov 2010 06:49:03 +0000 (07:49 +0100)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 25 Nov 2010 06:49:03 +0000 (07:49 +0100)
Cython/Compiler/Optimize.py

index c39c9e7474ba6a6e09dadf93ea45d6e736a0a9dd..c480278e3a497758df262f7e1f54c8d5fe380667 100644 (file)
@@ -2970,12 +2970,23 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
         self._calculate_const(node)
         return node
 
-    def visit_UnaryMinusNode(self, node):
+    def visit_UnopNode(self, node):
         self._calculate_const(node)
         if node.constant_result is ExprNodes.not_a_constant:
             return node
         if not node.operand.is_literal:
             return node
+        if isinstance(node.operand, ExprNodes.BoolNode):
+            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
+                                     type = PyrexTypes.c_int_type,
+                                     constant_result = node.constant_result)
+        if node.operator == '+':
+            return self._handle_UnaryPlusNode(node)
+        elif node.operator == '-':
+            return self._handle_UnaryMinusNode(node)
+        return node
+
+    def _handle_UnaryMinusNode(self, node):
         if isinstance(node.operand, ExprNodes.LongNode):
             return ExprNodes.LongNode(node.pos, value = '-' + node.operand.value,
                                       constant_result = node.constant_result)
@@ -2983,11 +2994,6 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
             # this is a safe operation
             return ExprNodes.FloatNode(node.pos, value = '-' + node.operand.value,
                                        constant_result = node.constant_result)
-        if isinstance(node.operand, ExprNodes.BoolNode):
-            # not important at all, but simplifies the code below
-            return ExprNodes.IntNode(node.pos, value = str(node.constant_result),
-                                     type = PyrexTypes.c_int_type,
-                                     constant_result = node.constant_result)
         node_type = node.operand.type
         if node_type.is_int and node_type.signed or \
                isinstance(node.operand, ExprNodes.IntNode) and node_type.is_pyobject:
@@ -2997,10 +3003,7 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
                                      constant_result = node.constant_result)
         return node
 
-    def visit_UnaryPlusNode(self, node):
-        self._calculate_const(node)
-        if node.constant_result is ExprNodes.not_a_constant:
-            return node
+    def _handle_UnaryPlusNode(self, node):
         if node.constant_result == node.operand.constant_result:
             return node.operand
         return node
@@ -3026,12 +3029,13 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
             return node
         if isinstance(node.constant_result, float):
             return node
-        if not node.operand1.is_literal or not node.operand2.is_literal:
+        operand1, operand2 = node.operand1, node.operand2
+        if not operand1.is_literal or not operand2.is_literal:
             return node
 
         # now inject a new constant node with the calculated value
         try:
-            type1, type2 = node.operand1.type, node.operand2.type
+            type1, type2 = operand1.type, operand2.type
             if type1 is None or type2 is None:
                 return node
         except AttributeError:
@@ -3041,14 +3045,14 @@ class ConstantFolding(Visitor.VisitorTransform, SkipDeclarations):
             widest_type = PyrexTypes.widest_numeric_type(type1, type2)
         else:
             widest_type = PyrexTypes.py_object_type
-        target_class = self._widest_node_class(node.operand1, node.operand2)
+        target_class = self._widest_node_class(operand1, operand2)
         if target_class is None:
             return node
         elif target_class is ExprNodes.IntNode:
-            unsigned = getattr(node.operand1, 'unsigned', '') and \
-                       getattr(node.operand2, 'unsigned', '')
-            longness = "LL"[:max(len(getattr(node.operand1, 'longness', '')),
-                                 len(getattr(node.operand2, 'longness', '')))]
+            unsigned = getattr(operand1, 'unsigned', '') and \
+                       getattr(operand2, 'unsigned', '')
+            longness = "LL"[:max(len(getattr(operand1, 'longness', '')),
+                                 len(getattr(operand2, 'longness', '')))]
             new_node = ExprNodes.IntNode(pos=node.pos,
                                          unsigned = unsigned, longness = longness,
                                          value = str(node.constant_result),