faster way to test certain builtin types for truth: list, tuple, bytes, unicode;...
authorStefan Behnel <scoder@users.berlios.de>
Thu, 15 Apr 2010 12:52:40 +0000 (14:52 +0200)
committerStefan Behnel <scoder@users.berlios.de>
Thu, 15 Apr 2010 12:52:40 +0000 (14:52 +0200)
Cython/Compiler/ExprNodes.py
tests/run/builtins_truth_test.pyx [new file with mode: 0644]

index 54d097addad24a9d988d724d235e1788dfe15221..4c16d1c05a1d9386b281f74a715a3ee3d187538f 100755 (executable)
@@ -272,6 +272,10 @@ class ExprNode(Node):
         # ConstantFolding transform will do this.
         pass
 
+    def has_constant_result(self):
+        return self.constant_result is not constant_value_not_set and \
+               self.constant_result is not not_a_constant
+
     def compile_time_value(self, denv):
         #  Return value of compile-time expression, or report error.
         error(self.pos, "Invalid compile-time expression")
@@ -589,6 +593,13 @@ class ExprNode(Node):
     def coerce_to_boolean(self, env):
         #  Coerce result to something acceptable as
         #  a boolean value.
+
+        # if it's constant, calculate the result now
+        if self.has_constant_result():
+            bool_value = bool(self.constant_result)
+            return BoolNode(self.pos, value=bool_value,
+                            constant_result=bool_value)
+
         type = self.type
         if type.is_pyobject or type.is_ptr or type.is_float:
             return CoerceToBooleanNode(self, env)
@@ -841,6 +852,11 @@ class BytesNode(ConstNode):
     def can_coerce_to_char_literal(self):
         return len(self.value) == 1
 
+    def coerce_to_boolean(self, env):
+        # This is special because we start off as a C char*.  Testing
+        # that for truth directly would yield the wrong result.
+        return BoolNode(self.pos, value=bool(self.value))
+
     def coerce_to(self, dst_type, env):
         if dst_type.is_int:
             if not self.can_coerce_to_char_literal():
@@ -5009,7 +5025,7 @@ class DivNode(NumBinopNode):
         if not self.type.is_pyobject:
             self.zerodivision_check = (
                 self.cdivision is None and not env.directives['cdivision']
-                and (self.operand2.constant_result is not_a_constant or
+                and (not self.operand2.has_constant_result() or
                      self.operand2.constant_result == 0))
             if self.zerodivision_check or env.directives['cdivision_warnings']:
                 # Need to check ahead of time to warn or raise zero division error
@@ -6087,7 +6103,14 @@ class CoerceToBooleanNode(CoercionNode):
     #  in a boolean context.
     
     type = PyrexTypes.c_bint_type
-    
+
+    _special_builtins = {
+        Builtin.list_type    : 'PyList_GET_SIZE',
+        Builtin.tuple_type   : 'PyTuple_GET_SIZE',
+        Builtin.bytes_type   : 'PyBytes_GET_SIZE',
+        Builtin.unicode_type : 'PyUnicode_GET_SIZE',
+        }
+
     def __init__(self, arg, env):
         CoercionNode.__init__(self, arg)
         if arg.type.is_pyobject:
@@ -6109,7 +6132,16 @@ class CoerceToBooleanNode(CoercionNode):
         return "(%s != 0)" % self.arg.result()
 
     def generate_result_code(self, code):
-        if self.arg.type.is_pyobject:
+        if not self.is_temp:
+            return
+        test_func = self._special_builtins.get(self.arg.type)
+        if test_func is not None:
+            code.putln("%s = (%s != Py_None) & (%s(%s) != 0);" % (
+                       self.result(),
+                       self.arg.py_result(),
+                       test_func,
+                       self.arg.py_result()))
+        else:
             code.putln(
                 "%s = __Pyx_PyObject_IsTrue(%s); %s" % (
                     self.result(), 
diff --git a/tests/run/builtins_truth_test.pyx b/tests/run/builtins_truth_test.pyx
new file mode 100644 (file)
index 0000000..6d4b5b1
--- /dev/null
@@ -0,0 +1,134 @@
+
+def if_list(list obj):
+    """
+    >>> if_list( [] )
+    False
+    >>> if_list( [1] )
+    True
+    >>> if_list(None)
+    False
+    """
+    if obj:
+        return True
+    else:
+        return False
+
+def if_list_literal(t):
+    """
+    >>> if_list_literal(True)
+    True
+    >>> if_list_literal(False)
+    False
+    """
+    if t:
+        if [1,2,3]:
+            return True
+        else:
+            return False
+    else:
+        if []:
+            return True
+        else:
+            return False
+
+def if_tuple(tuple obj):
+    """
+    >>> if_tuple( () )
+    False
+    >>> if_tuple( (1,) )
+    True
+    >>> if_tuple(None)
+    False
+    """
+    if obj:
+        return True
+    else:
+        return False
+
+def if_tuple_literal(t):
+    """
+    >>> if_tuple_literal(True)
+    True
+    >>> if_tuple_literal(False)
+    False
+    """
+    if t:
+        if (1,2,3):
+            return True
+        else:
+            return False
+    else:
+        if ():
+            return True
+        else:
+            return False
+
+b0 = b''
+b1 = b'abc'
+
+def if_bytes(bytes obj):
+    """
+    >>> if_bytes(b0)
+    False
+    >>> if_bytes(b1)
+    True
+    >>> if_bytes(None)
+    False
+    """
+    if obj:
+        return True
+    else:
+        return False
+
+def if_bytes_literal(t):
+    """
+    >>> if_bytes_literal(True)
+    True
+    >>> if_bytes_literal(False)
+    False
+    """
+    if t:
+        if b'abc':
+            return True
+        else:
+            return False
+    else:
+        if b'':
+            return True
+        else:
+            return False
+
+u0 = u''
+u1 = u'abc'
+
+def if_unicode(unicode obj):
+    """
+    >>> if_unicode(u0)
+    False
+    >>> if_unicode(u1)
+    True
+    >>> if_unicode(None)
+    False
+    """
+    if obj:
+        return True
+    else:
+        return False
+
+def if_unicode_literal(t):
+    """
+    >>> if_unicode_literal(True)
+    True
+    >>> if_unicode_literal(False)
+    False
+    """
+    if t:
+        if u'abc':
+            return True
+        else:
+            return False
+    else:
+        if u'':
+            return True
+        else:
+            return False