Complex numeber comparison, etc.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 14 May 2009 07:58:13 +0000 (00:58 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 14 May 2009 07:58:13 +0000 (00:58 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/Nodes.py
Cython/Compiler/PyrexTypes.py

index 70fe414fbea12ffe01115e03ded22ca449cea50b..7cd13af7f8b0e8b36a06f14a44d7f92a8d60c489 100644 (file)
@@ -4213,8 +4213,11 @@ class NumBinopNode(BinopNode):
                 self.operator, 
                 self.operand2.result())
         else:
+            func = self.type.binary_op(self.operator)
+            if func is None:
+                error(self.pos, "binary operator %s not supported for %s" % (self.operator, self.type))
             return "%s(%s, %s)" % (
-                self.type.binop(self.operator),
+                func,
                 self.operand1.result(),
                 self.operand2.result())
     
@@ -4318,7 +4321,7 @@ class DivNode(NumBinopNode):
             return "float division"
 
     def generate_evaluation_code(self, code):
-        if not self.type.is_pyobject:
+        if not self.type.is_pyobject and not self.type.is_complex:
             if self.cdivision is None:
                 self.cdivision = (code.globalstate.directives['cdivision'] 
                                     or not self.type.signed
@@ -4331,7 +4334,11 @@ class DivNode(NumBinopNode):
     def generate_div_warning_code(self, code):
         if not self.type.is_pyobject:
             if self.zerodivision_check:
-                code.putln("if (unlikely(%s == 0)) {" % self.operand2.result())
+                if not self.infix:
+                    zero_test = "%s(%s)" % (self.type.unary_op('zero'), self.operand2.result())
+                else:
+                    zero_test = "%s == 0" % self.operand2.result()
+                code.putln("if (unlikely(%s)) {" % zero_test)
                 code.putln('PyErr_Format(PyExc_ZeroDivisionError, "%s");' % self.zero_division_message())
                 code.putln(code.error_goto(self.pos))
                 code.putln("}")
@@ -4344,7 +4351,7 @@ class DivNode(NumBinopNode):
                     code.putln('PyErr_Format(PyExc_OverflowError, "value too large to perform division");')
                     code.putln(code.error_goto(self.pos))
                     code.putln("}")
-            if code.globalstate.directives['cdivision_warnings']:
+            if code.globalstate.directives['cdivision_warnings'] and self.operand != '/':
                 code.globalstate.use_utility_code(cdivision_warning_utility_code)
                 code.putln("if ((%s < 0) ^ (%s < 0)) {" % (
                                 self.operand1.result(),
@@ -4355,7 +4362,9 @@ class DivNode(NumBinopNode):
                 code.putln("}")
     
     def calculate_result_code(self):
-        if self.type.is_float and self.operator == '//':
+        if self.type.is_complex:
+            return NumBinopNode.calculate_result_code(self)
+        elif self.type.is_float and self.operator == '//':
             return "floor(%s / %s)" % (
                 self.operand1.result(), 
                 self.operand2.result())
@@ -4705,7 +4714,13 @@ class CmpNode(object):
             or (self.cascade and self.cascade.is_python_result()))
 
     def check_types(self, env, operand1, op, operand2):
-        if not self.types_okay(operand1, op, operand2):
+        if operand1.type.is_complex or operand2.type.is_complex:
+            if op not in ('==', '!='):
+                error(self.pos, "complex types unordered")
+            common_type = PyrexTypes.widest_numeric_type(operand1.type, operand2.type)
+            self.operand1 = operand1.coerce_to(common_type, env)
+            self.operand2 = operand2.coerce_to(common_type, env)
+        elif not self.types_okay(operand1, op, operand2):
             error(self.pos, "Invalid types for '%s' (%s, %s)" %
                 (self.operator, operand1.type, operand2.type))
     
@@ -4754,6 +4769,16 @@ class CmpNode(object):
                         richcmp_constants[op],
                         code.error_goto_if_null(result_code, self.pos)))
                 code.put_gotref(result_code)
+        elif operand1.type.is_complex and not code.globalstate.directives['c99_complex']:
+            if op == "!=": negation = "!"
+            else: negation = ""
+            code.putln("%s = %s(%s%s(%s, %s));" % (
+                result_code, 
+                coerce_result,
+                negation,
+                operand1.type.unary_op('eq'), 
+                operand1.result(), 
+                operand2.result()))
         else:
             type1 = operand1.type
             type2 = operand2.type
@@ -4881,10 +4906,21 @@ class PrimaryCmpNode(NewTempExprNode, CmpNode):
             self.not_const()
 
     def calculate_result_code(self):
-        return "(%s %s %s)" % (
-            self.operand1.result(),
-            self.c_operator(self.operator),
-            self.operand2.result())
+        if self.operand1.type.is_complex:
+            if self.operator == "!=":
+                negation = "!"
+            else:
+                negation = ""
+            return "(%s%s(%s, %s))" % (
+                negation,
+                self.operand1.type.binary_op('=='), 
+                self.operand1.result(), 
+                self.operand2.result())
+        else:
+            return "(%s %s %s)" % (
+                self.operand1.result(),
+                self.c_operator(self.operator),
+                self.operand2.result())
 
     def generate_evaluation_code(self, code):
         self.operand1.generate_evaluation_code(code)
index 084caf1213558e2d5e92fece2352f30c24a56579..2ff86dba7b90bfb12fbcc4e81cc2e224f66ed916 100644 (file)
@@ -1921,6 +1921,10 @@ class DefNode(FuncDefNode):
         has_star_or_kw_args = self.star_arg is not None \
             or self.starstar_arg is not None or has_kwonly_args
 
+        for arg in self.args:
+            if not arg.type.is_pyobject and arg.type.from_py_function is None:
+                arg.type.create_from_py_utility_code(env)
+
         if not self.signature_has_generic_args():
             if has_star_or_kw_args:
                 error(self.pos, "This method cannot have * or keyword arguments")
@@ -1951,8 +1955,6 @@ class DefNode(FuncDefNode):
                         error(arg.pos, "Non-default argument following default argument")
                     elif not arg.is_self_arg:
                         positional_args.append(arg)
-                if arg.type.from_py_function is None:
-                    arg.type.create_from_py_utility_code(env)
 
             self.generate_tuple_and_keyword_parsing_code(
                 positional_args, kw_only_args, end_label, code)
index 3cf7395cdfde03f9f5765a688b075b9e2b0cad18..54b81ad1c09a8792a49a7852d3cfbf5c3d6a2052 100644 (file)
@@ -753,18 +753,33 @@ class CComplexType(CNumericType):
         self.from_py_function = "__pyx_PyObject_As_" + self.specalization_name()
         return True
     
-    def binop(self, op):
+    def lookup_op(self, nargs, op):
         try:
-            return self.binops[op]
+            return self.binops[nargs, op]
         except KeyError:
-            if op in "+-*/":
-                from ExprNodes import compile_time_binary_operators
-                op_name = compile_time_binary_operators[op].__name__
-                self.binops[op] = func_name = "%s_%s" % (self.specalization_name(), op_name)
-                return func_name
-            else:
-                error("Binary '%s' not supported in for %s" % (op, self))
-                return "<error>"
+            pass
+        try:
+            op_name = complex_ops[nargs, op]
+            self.binops[nargs, op] = func_name = "%s_%s" % (self.specalization_name(), op_name)
+            return func_name
+        except KeyError:
+            return None
+
+    def unary_op(self, op):
+        return self.lookup_op(1, op)
+        
+    def binary_op(self, op):
+        return self.lookup_op(2, op)
+        
+complex_ops = {
+    (1, '-'): 'neg',
+    (1, 'zero'): 'is_zero',
+    (2, '+'): 'add',
+    (2, '-') : 'sub',
+    (2, '*'): 'mul',
+    (2, '/'): 'div',
+    (2, '=='): 'eq',
+}
 
 complex_generic_utility_code = UtilityCode(
 proto="""
@@ -804,6 +819,7 @@ proto="""
     #define %(type_name)s_from_parts(x, y) ((x) + (y)*(%(type)s)_Complex_I)
 
     #define %(type_name)s_is_zero(a) ((a) == 0)
+    #define %(type_name)s_eq(a, b) ((a) == (b))
     #define %(type_name)s_add(a, b) ((a)+(b))
     #define %(type_name)s_sub(a, b) ((a)-(b))
     #define %(type_name)s_mul(a, b) ((a)*(b))
@@ -819,6 +835,10 @@ proto="""
        return (a.real == 0) & (a.imag == 0);
     }
 
+    static INLINE int %(type_name)s_eq(%(type)s a, %(type)s b) {
+       return (a.real == b.real) & (a.imag == b.imag);
+    }
+
     static INLINE %(type)s %(type_name)s_add(%(type)s a, %(type)s b) {
         %(type)s z;
         z.real = a.real + b.real;