Complex number support without c99
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 14 May 2009 07:58:12 +0000 (00:58 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 14 May 2009 07:58:12 +0000 (00:58 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/ModuleNode.py
Cython/Compiler/Nodes.py
Cython/Compiler/Options.py
Cython/Compiler/PyrexTypes.py

index 99c447fbf645356c50b8cad376a9447adc09f821..70fe414fbea12ffe01115e03ded22ca449cea50b 100644 (file)
@@ -549,10 +549,15 @@ class ExprNode(Node):
                     src = PyTypeTestNode(src, dst_type, env)
         elif src.type.is_pyobject:
             src = CoerceFromPyTypeNode(dst_type, src, env)
+        elif (dst_type.is_complex 
+                and src_type != dst_type
+                and dst_type.assignable_from(src_type) 
+                and not env.directives['c99_complex']):
+            src = CoerceToComplexNode(src, dst_type, env)
         else: # neither src nor dst are py types
             # Added the string comparison, since for c types that
             # is enough, but Cython gets confused when the types are
-            # in different files.
+            # in different pxi files.
             if not (str(src.type) == str(dst_type) or dst_type.assignable_from(src_type)):
                 error(self.pos, "Cannot assign type '%s' to '%s'" %
                     (src.type, dst_type))
@@ -843,7 +848,7 @@ class IntNode(ConstNode):
     type = PyrexTypes.c_long_type
 
     def coerce_to(self, dst_type, env):
-        if dst_type.is_numeric:
+        if dst_type.is_numeric and not dst_type.is_complex:
             self.type = PyrexTypes.c_long_type
             return self
         # Arrange for a Python version of the number to be pre-allocated
@@ -1026,6 +1031,8 @@ class ImagNode(AtomicNewTempExprNode):
     #  Imaginary number literal
     #
     #  value   float    imaginary part
+    
+    type = PyrexTypes.c_double_complex_type
 
     def calculate_constant_result(self):
         self.constant_result = complex(0.0, self.value)
@@ -1034,19 +1041,40 @@ class ImagNode(AtomicNewTempExprNode):
         return complex(0.0, self.value)
     
     def analyse_types(self, env):
-        self.type = py_object_type
-        self.gil_check(env)
-        self.is_temp = 1
+        self.type.create_declaration_utility_code(env)
+
+    def coerce_to(self, dst_type, env):
+        # Arrange for a Python version of the number to be pre-allocated
+        # when coercing to a Python type.
+        if dst_type.is_pyobject:
+            self.is_temp = 1
+            self.gil_check(env)
+            self.type = PyrexTypes.py_object_type
+        # We still need to perform normal coerce_to processing on the
+        # result, because we might be coercing to an extension type,
+        # in which case a type test node will be needed.
+        return AtomicNewTempExprNode.coerce_to(self, dst_type, env)
 
     gil_message = "Constructing complex number"
 
+    def calculate_result_code(self):
+        if self.type.is_pyobject:
+            return self.result()
+        elif self.c99_complex:
+            return "%rj" % float(self.value)
+        else:
+            return "%s(0, %r)" % (self.type.from_parts, float(self.value))
+
     def generate_result_code(self, code):
-        code.putln(
-            "%s = PyComplex_FromDoubles(0.0, %r); %s" % (
-                self.result(),
-                float(self.value),
-                code.error_goto_if_null(self.result(), self.pos)))
-        code.put_gotref(self.py_result())
+        if self.type.is_pyobject:
+            code.putln(
+                "%s = PyComplex_FromDoubles(0.0, %r); %s" % (
+                    self.result(),
+                    float(self.value),
+                    code.error_goto_if_null(self.result(), self.pos)))
+            code.put_gotref(self.py_result())
+        else:
+            self.c99_complex = code.globalstate.directives['c99_complex']
         
 
 
@@ -3895,7 +3923,7 @@ class TypecastNode(NewTempExprNode):
             error(self.pos, "Casting temporary Python object to non-numeric non-Python type")
         if to_py and not from_py:
             if (self.operand.type.to_py_function and
-                    self.operand.type.create_convert_utility_code(env)):
+                    self.operand.type.create_to_py_utility_code(env)):
                 self.result_ctype = py_object_type
                 self.operand = self.operand.coerce_to_pyobject(env)
             else:
@@ -4161,6 +4189,11 @@ class NumBinopNode(BinopNode):
         self.type = self.compute_c_result_type(type1, type2)
         if not self.type:
             self.type_error()
+            return
+        self.infix = not self.type.is_complex or env.directives['c99_complex']
+        if not self.infix:
+            self.operand1 = self.operand1.coerce_to(self.type, env)
+            self.operand2 = self.operand2.coerce_to(self.type, env)
     
     def compute_c_result_type(self, type1, type2):
         if self.c_types_okay(type1, type2):
@@ -4174,10 +4207,16 @@ class NumBinopNode(BinopNode):
             and (type2.is_numeric  or type2.is_enum)
 
     def calculate_result_code(self):
-        return "(%s %s %s)" % (
-            self.operand1.result(), 
-            self.operator, 
-            self.operand2.result())
+        if self.infix:
+            return "(%s %s %s)" % (
+                self.operand1.result(), 
+                self.operator, 
+                self.operand2.result())
+        else:
+            return "%s(%s, %s)" % (
+                self.type.binop(self.operator),
+                self.operand1.result(),
+                self.operand2.result())
     
     def py_operation_function(self):
         return self.py_functions[self.operator]
@@ -4380,7 +4419,10 @@ class PowNode(NumBinopNode):
     
     def analyse_c_operation(self, env):
         NumBinopNode.analyse_c_operation(self, env)
-        if self.operand1.type.is_float or self.operand2.type.is_float:
+        if self.type.is_complex:
+            error(self.pos, "complex powers not yet supported")
+            self.pow_func = "<error>"
+        elif self.type.is_float:
             self.pow_func = "pow"
         else:
             self.pow_func = "__Pyx_pow_%s" % self.type.declaration_code('').replace(' ', '_')
@@ -5088,7 +5130,7 @@ class CoerceToPyTypeNode(CoercionNode):
         self.type = py_object_type
         self.gil_check(env)
         self.is_temp = 1
-        if not arg.type.to_py_function or not arg.type.create_convert_utility_code(env):
+        if not arg.type.to_py_function or not arg.type.create_to_py_utility_code(env):
             error(arg.pos,
                 "Cannot convert '%s' to Python object" % arg.type)
         
@@ -5126,7 +5168,7 @@ class CoerceFromPyTypeNode(CoercionNode):
         CoercionNode.__init__(self, arg)
         self.type = result_type
         self.is_temp = 1
-        if not result_type.from_py_function:
+        if not result_type.from_py_function and not result_type.create_from_py_utility_code(env):
             error(arg.pos,
                 "Cannot convert Python object to '%s'" % result_type)
         if self.type.is_string and self.arg.is_ephemeral():
@@ -5181,6 +5223,29 @@ class CoerceToBooleanNode(CoercionNode):
                     self.arg.py_result(), 
                     code.error_goto_if_neg(self.result(), self.pos)))
 
+class CoerceToComplexNode(CoercionNode):
+
+    def __init__(self, arg, dst_type, env):
+        if arg.type.is_complex:
+            arg = arg.coerce_to_simple(env)
+        self.type = dst_type
+        CoercionNode.__init__(self, arg)
+        dst_type.create_declaration_utility_code(env)
+
+    def calculate_result_code(self):
+        if self.arg.type.is_complex:
+            real_part = "__Pyx_REAL_PART(%s)" % self.arg.result()
+            imag_part = "__Pyx_IMAG_PART(%s)" % self.arg.result()
+        else:
+            real_part = self.arg.result()
+            imag_part = "0"
+        return "%s(%s, %s)" % (
+                self.type.from_parts,
+                real_part,
+                imag_part)
+    
+    def generate_result_code(self, code):
+        pass
 
 class CoerceToTempNode(CoercionNode):
     #  This node is used to force the result of another node
index af52468d6d5ca866f9e1ac9b9ab7053edd6bad74..f8012865836682f41ab27444a491d8edb622d83f 100644 (file)
@@ -556,6 +556,11 @@ class ModuleNode(Nodes.Node, Nodes.BlockNode):
         code.putln("#include <math.h>")
         code.putln("#define %s" % Naming.api_guard_prefix + self.api_name(env))
         self.generate_includes(env, cimported_modules, code)
+        if env.directives['c99_complex']:
+            code.putln("#ifndef _Complex_I")
+            code.putln("#include <complex.h>")
+            code.putln("#endif")
+        code.putln("#define __PYX_USE_C99_COMPLEX defined(_Complex_I)")
         code.putln('')
         code.put(Nodes.utility_function_predeclarations)
         code.put(PyrexTypes.type_conversion_predeclarations)
index 362947a574f76101ad267bebc735f613c7a502c5..084caf1213558e2d5e92fece2352f30c24a56579 100644 (file)
@@ -719,6 +719,7 @@ class CSimpleBaseTypeNode(CBaseTypeNode):
             if not type.is_numeric or type.is_complex:
                 error(self.pos, "can only complexify c numeric types")
             type = PyrexTypes.CComplexType(type)
+            type.create_declaration_utility_code(env)
         if type:
             return type
         else:
@@ -1950,6 +1951,8 @@ class DefNode(FuncDefNode):
                         error(arg.pos, "Non-default argument following default argument")
                     elif not arg.is_self_arg:
                         positional_args.append(arg)
+                if arg.type.from_py_function is None:
+                    arg.type.create_from_py_utility_code(env)
 
             self.generate_tuple_and_keyword_parsing_code(
                 positional_args, kw_only_args, end_label, code)
@@ -3163,10 +3166,10 @@ class InPlaceAssignmentNode(AssignmentNode):
             if c_op == "//":
                 c_op = "/"
             elif c_op == "**":
-                if self.lhs.type.is_int and self.rhs.type.is_int:
-                    error(self.pos, "** with two C int types is ambiguous")
-                else:
-                    error(self.pos, "No C inplace power operator")
+                error(self.pos, "No C inplace power operator")
+            elif self.lhs.type.is_complex and not code.globalstate.directives['c99_complex']:
+                error(self.pos, "Inplace operators not implemented for complex types.")
+                
             # have to do assignment directly to avoid side-effects
             if isinstance(self.lhs, ExprNodes.IndexNode) and self.lhs.is_buffer_access:
                 self.lhs.generate_buffer_setitem_code(self.rhs, code, c_op)
index 4c1493c02080a35a751629de9d6c77b43deaa8e7..52e1331126fc746f5a2ed2abf6e695ca8ba81269 100644 (file)
@@ -65,8 +65,7 @@ option_defaults = {
     'cdivision_warnings': False,
     'always_allow_keywords': False,
     'wraparound' : True,
-    'c99_complex' : False,
-    'a': 4,
+    'c99_complex' : False, # Don't use macro wrappers for complex arith, not sure what to name this...
 }
 
 # Override types possibilities above, if needed
index da9aeeb0e5876cc6dc18249f8b5bf12997bcbe29..3cf7395cdfde03f9f5765a688b075b9e2b0cad18 100644 (file)
@@ -441,7 +441,10 @@ class CType(PyrexType):
     exception_value = None
     exception_check = 1
 
-    def create_convert_utility_code(self, env):
+    def create_to_py_utility_code(self, env):
+        return True
+        
+    def create_from_py_utility_code(self, env):
         return True
         
     def error_condition(self, result_code):
@@ -702,6 +705,7 @@ class CFloatType(CNumericType):
     def assignable_from_resolved_type(self, src_type):
         return (src_type.is_numeric and not src_type.is_complex) or src_type is error_type
 
+
 class CComplexType(CNumericType):
     
     is_complex = 1
@@ -710,7 +714,7 @@ class CComplexType(CNumericType):
     def __init__(self, real_type):
         self.real_type = real_type
         CNumericType.__init__(self, real_type.rank + 0.5, real_type.signed)
-        self.from_py_function = "__pyx_PyObject_As_" + self.specalization_name()
+        self.binops = {}
     
     def __cmp__(self, other):
         if isinstance(self, CComplexType) and isinstance(other, CComplexType):
@@ -718,51 +722,140 @@ class CComplexType(CNumericType):
         else:
             return 1
     
+    def __hash__(self):
+        return ~hash(self.real_type)
+    
     def sign_and_name(self):
-        return self.real_type.sign_and_name() + " _Complex"
+        return Naming.type_prefix + self.real_type.specalization_name() + "_complex"
 
     def assignable_from_resolved_type(self, src_type):
         return (src_type.is_complex and self.real_type.assignable_from_resolved_type(src_type.real_type)
                     or src_type.is_numeric and self.real_type.assignable_from_resolved_type(src_type) 
                     or src_type is error_type)
 
-    def create_convert_utility_code(self, env):
-        self.real_type.create_convert_utility_code(env)
-        env.use_utility_code(complex_generic_utility_code)
+    def create_declaration_utility_code(self, env):
+        if not hasattr(self, 'from_parts'):
+            self.from_parts = "%s_from_parts" % self.specalization_name()
+            env.use_utility_code(complex_generic_utility_code)
+            env.use_utility_code(
+                complex_arithmatic_utility_code.specialize(self, 
+                            math_h_modifier = self.real_type.math_h_modifier,
+                            real_type = self.real_type.declaration_code('')))
+        return True
+
+    def create_from_py_utility_code(self, env):
+        self.real_type.create_from_py_utility_code(env)
         env.use_utility_code(
-            complex_utility_code.specialize(self, 
-                        real_type=self.real_type.declaration_code(''),
-                        type_convert=self.real_type.from_py_function))
+            complex_conversion_utility_code.specialize(self, 
+                        math_h_modifier = self.real_type.math_h_modifier,
+                        real_type = self.real_type.declaration_code(''),
+                        type_convert = self.real_type.from_py_function))
+        self.from_py_function = "__pyx_PyObject_As_" + self.specalization_name()
         return True
+    
+    def binop(self, op):
+        try:
+            return self.binops[op]
+        except KeyError:
+            if op in "+-*/":
+                from ExprNodes import compile_time_binary_operators
+                op_name = compile_time_binary_operators[op].__name__
+                self.binops[op] = func_name = "%s_%s" % (self.specalization_name(), op_name)
+                return func_name
+            else:
+                error("Binary '%s' not supported in for %s" % (op, self))
+                return "<error>"
 
-complex_utility_code = UtilityCode(
+complex_generic_utility_code = UtilityCode(
+proto="""
+#if __PYX_USE_C99_COMPLEX
+    #define __Pyx_REAL_PART(z) __real__(z)
+    #define __Pyx_IMAG_PART(z) __imag__(z)
+#else
+    #define __Pyx_REAL_PART(z) ((z).real)
+    #define __Pyx_IMAG_PART(z) ((z).imag)
+#endif
+
+#define __pyx_PyObject_from_complex(z) PyComplex_FromDoubles((double)__Pyx_REAL_PART(z), (double)__Pyx_IMAG_PART(z))
+""")
+
+complex_conversion_utility_code = UtilityCode(
 proto="""
-static INLINE %(type)s __pyx_%(type_name)s_from_parts(%(real_type)s real, %(real_type)s imag); /* proto */
 static %(type)s __pyx_PyObject_As_%(type_name)s(PyObject* o); /* proto */
 """, 
 impl="""
-static INLINE %(type)s __pyx_%(type_name)s_from_parts(%(real_type)s real, %(real_type)s imag) {
-    %(type)s z;
-    __real__(z) = real;
-    __imag__(z) = imag;
-    return z;
-}
-
 static %(type)s __pyx_PyObject_As_%(type_name)s(PyObject* o) {
     if (PyComplex_Check(o)) {
-        return __pyx_%(type_name)s_from_parts(
+        return %(type_name)s_from_parts(
             (%(real_type)s)((PyComplexObject *)o)->cval.real,
             (%(real_type)s)((PyComplexObject *)o)->cval.imag);
     }
     else {
-        return __pyx_%(type_name)s_from_parts(%(type_convert)s(o), 0);
+        return %(type_name)s_from_parts(%(type_convert)s(o), 0);
     }
 }
 """)
 
-complex_generic_utility_code = UtilityCode(
+complex_arithmatic_utility_code = UtilityCode(
 proto="""
-#define __pyx_PyObject_from_complex(z) PyComplex_FromDoubles((double)__real__(z), (double)__imag__(z))
+#if __PYX_USE_C99_COMPLEX
+
+    typedef %(real_type)s _Complex %(type_name)s;
+    #define %(type_name)s_from_parts(x, y) ((x) + (y)*(%(type)s)_Complex_I)
+
+    #define %(type_name)s_is_zero(a) ((a) == 0)
+    #define %(type_name)s_add(a, b) ((a)+(b))
+    #define %(type_name)s_sub(a, b) ((a)-(b))
+    #define %(type_name)s_mul(a, b) ((a)*(b))
+    #define %(type_name)s_div(a, b) ((a)/(b))
+    #define %(type_name)s_neg(a) (-(a))
+
+#else
+
+    typedef struct { %(real_type)s real, imag; } %(type_name)s;
+    #define %(type_name)s_from_parts(x, y) ((%(type_name)s){(%(real_type)s)x, (%(real_type)s)y})
+
+    static INLINE int %(type_name)s_is_zero(%(type)s a) {
+       return (a.real == 0) & (a.imag == 0);
+    }
+
+    static INLINE %(type)s %(type_name)s_add(%(type)s a, %(type)s b) {
+        %(type)s z;
+        z.real = a.real + b.real;
+        z.imag = a.imag + b.imag;
+        return z;
+    }
+
+    static INLINE %(type)s %(type_name)s_sub(%(type)s a, %(type)s b) {
+        %(type)s z;
+        z.real = a.real - b.real;
+        z.imag = a.imag - b.imag;
+        return z;
+    }
+
+    static INLINE %(type)s %(type_name)s_mul(%(type)s a, %(type)s b) {
+        %(type)s z;
+        z.real = a.real * b.real - a.imag * b.imag;
+        z.imag = a.real * b.imag + a.imag * b.real;
+        return z;
+    }
+
+    static INLINE %(type)s %(type_name)s_div(%(type)s a, %(type)s b) {
+        %(type)s z;
+        %(real_type)s denom = b.real*b.real + b.imag*b.imag;
+        z.real = (a.real * b.real + a.imag * b.imag) / denom;
+        z.imag = (a.imag * b.real - a.real * b.imag) / denom;
+        return z;
+    }
+
+    static INLINE %(type)s %(type_name)s_neg(%(type)s a) {
+        %(type)s z;
+        z.real = -a.real;
+        z.imag = -a.imag;
+        return z;
+    }
+
+#endif
 """)
 
 
@@ -1120,7 +1213,7 @@ class CStructOrUnionType(CType):
         self._convert_code = None
         self.packed = packed
         
-    def create_convert_utility_code(self, env):
+    def create_to_py_utility_code(self, env):
         if env.outer_scope is None:
             return False
         if self._convert_code is None:
@@ -1133,7 +1226,7 @@ class CStructOrUnionType(CType):
             code.putln("PyObject* member;")
             code.putln("res = PyDict_New(); if (res == NULL) return NULL;")
             for member in self.scope.var_entries:
-                if member.type.to_py_function and member.type.create_convert_utility_code(env):
+                if member.type.to_py_function and member.type.create_to_py_utility_code(env):
                     interned_name = env.get_string_const(member.name, identifier=True)
                     env.add_py_string(interned_name)
                     code.putln("member = %s(s.%s); if (member == NULL) goto bad;" % (
@@ -1303,7 +1396,10 @@ class ErrorType(PyrexType):
     to_py_function = "dummy"
     from_py_function = "dummy"
     
-    def create_convert_utility_code(self, env):
+    def create_to_py_utility_code(self, env):
+        return True
+    
+    def create_from_py_utility_code(self, env):
         return True
     
     def declaration_code(self, entity_code, 
@@ -1362,6 +1458,8 @@ c_float_type =       CFloatType(7, "T_FLOAT", math_h_modifier='f')
 c_double_type =      CFloatType(8, "T_DOUBLE")
 c_longdouble_type =  CFloatType(9, math_h_modifier='l')
 
+c_double_complex_type = CComplexType(c_double_type)
+
 c_null_ptr_type =     CNullPtrType(c_void_type)
 c_char_array_type =   CCharArrayType(None)
 c_char_ptr_type =     CCharPtrType()
@@ -1457,7 +1555,10 @@ def widest_numeric_type(type1, type2):
     if type1 == type2:
         return type1
     if type1.is_complex:
-        return CComplexType(widest_numeric_type(type1.real_type, type2))
+        if type2.is_complex:
+            return CComplexType(widest_numeric_type(type1.real_type, type2.real_type))
+        else:
+            return CComplexType(widest_numeric_type(type1.real_type, type2))
     elif type2.is_complex:
         return CComplexType(widest_numeric_type(type1, type2.real_type))
     if type1.is_enum and type2.is_enum: