Use PyObject_RichCompare rather than PyObject_Cmp
authorRobert Bradshaw <robertwb@math.washington.edu>
Wed, 10 Oct 2007 09:15:39 +0000 (02:15 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Wed, 10 Oct 2007 09:15:39 +0000 (02:15 -0700)
This is what the interpreter does, and allows one to
get at the actual object (rather than just its truth
value).

Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py

index 7d9d2d1a2956ff8e08faca7424f91c98e655d89c..4d93dfa0966533d4740ca92e54b32881116091b2 100644 (file)
@@ -2836,6 +2836,16 @@ class CondExprNode(ExprNode):
         code.putln("}")
         self.test.generate_disposal_code(code)
 
+richcmp_constants = {
+    "<" : "Py_LT",
+    "<=": "Py_LE",
+    "==": "Py_EQ",
+    "!=": "Py_NE",
+    "<>": "Py_NE",
+    ">" : "Py_GT",
+    ">=": "Py_GE",
+}
+
 class CmpNode:
     #  Mixin class containing code common to PrimaryCmpNodes
     #  and CascadedCmpNodes.
@@ -2845,6 +2855,10 @@ class CmpNode:
             or (self.cascade and self.cascade.is_python_comparison())
             or self.operator in ('in', 'not_in'))
 
+    def is_python_result(self):
+        return ((self.has_python_operands() and self.operator not in ('is', 'is_not', 'in', 'not_in'))
+            or (self.cascade and self.cascade.is_python_result()))
+
     def check_types(self, env, operand1, op, operand2):
         if not self.types_okay(operand1, op, operand2):
             error(self.pos, "Invalid types for '%s' (%s, %s)" %
@@ -2871,30 +2885,33 @@ class CmpNode:
 
     def generate_operation_code(self, code, result_code, 
             operand1, op , operand2):
+        if self.type is PyrexTypes.py_object_type:
+            coerce_result = "__Pyx_PyBool_FromLong"
+        else:
+            coerce_result = ""
+        if 'not' in op: negation = "!"
+        else: negation = ""
         if op == 'in' or op == 'not_in':
             code.putln(
-                "%s = PySequence_Contains(%s, %s); %s" % (
+                "%s = %s(%sPySequence_Contains(%s, %s)); %s" % (
                     result_code, 
+                    coerce_result, 
+                    negation,
                     operand2.py_result(), 
                     operand1.py_result(), 
                     code.error_goto_if_neg(result_code, self.pos)))
-            if op == 'not_in':
-                code.putln(
-                    "%s = !%s;" % (
-                        result_code, result_code))
         elif (operand1.type.is_pyobject
             and op not in ('is', 'is_not')):
-                code.put_error_if_neg(self.pos, 
-                    "PyObject_Cmp(%s, %s, &%s)" % (
+                code.putln("%s = PyObject_RichCompare(%s, %s, %s); %s" % (
+                        result_code, 
                         operand1.py_result(), 
                         operand2.py_result(), 
-                        result_code))
-                code.putln(
-                    "%s = %s %s 0;" % (
-                        result_code, result_code, op))
+                        richcmp_constants[op],
+                        code.error_goto_if_null(result_code, self.pos)))
         else:
-            code.putln("%s = %s %s %s;" % (
+            code.putln("%s = %s(%s %s %s);" % (
                 result_code, 
+                coerce_result, 
                 operand1.result_code, 
                 self.c_operator(op), 
                 operand2.result_code))
@@ -2937,7 +2954,14 @@ class PrimaryCmpNode(ExprNode, CmpNode):
             self.operand2 = self.operand2.coerce_to_simple(env)
             self.cascade.coerce_cascaded_operands_to_temp(env)
         self.check_operand_types(env)
-        self.type = PyrexTypes.c_bint_type
+        if self.is_python_result():
+            self.type = PyrexTypes.py_object_type
+        else:
+            self.type = PyrexTypes.c_bint_type
+        cdr = self.cascade
+        while cdr:
+            cdr.type = self.type
+            cdr = cdr.cascade
         if self.is_pycmp or self.cascade:
             self.is_temp = 1
     
@@ -3048,7 +3072,10 @@ class CascadedCmpNode(Node, CmpNode):
             self.cascade.release_subexpr_temps(env)
     
     def generate_evaluation_code(self, code, result, operand1):
-        code.putln("if (%s) {" % result)
+        if self.type.is_pyobject:
+            code.putln("if (__Pyx_PyObject_IsTrue(%s)) {" % result)
+        else:
+            code.putln("if (%s) {" % result)
         self.operand2.generate_evaluation_code(code)
         self.generate_operation_code(code, result, 
             operand1, self.operator, self.operand2)
@@ -3242,7 +3269,7 @@ class CoerceToBooleanNode(CoercionNode):
     def generate_result_code(self, code):
         if self.arg.type.is_pyobject:
             code.putln(
-                "%s = PyObject_IsTrue(%s); %s" % (
+                "%s = __Pyx_PyObject_IsTrue(%s); %s" % (
                     self.result_code, 
                     self.arg.py_result(), 
                     code.error_goto_if_neg(self.result_code, self.pos)))
index 6853f67f0ba1e85829aa33d2510d47b2aba8bf2f..c98fd8011bc6fb0ff0684e104330c247b2ee07dc 100644 (file)
@@ -346,8 +346,6 @@ class CIntType(CNumericType):
 
 class CBIntType(CIntType):
 
-    # TODO: this should be a macro "(__ ? Py_True : Py_False)"
-    #       and no error checking should be needed (just an incref). 
     to_py_function = "__Pyx_PyBool_FromLong"
     from_py_function = "__Pyx_PyObject_IsTrue"
     exception_check = 0