Complex powers.
authorRobert Bradshaw <robertwb@math.washington.edu>
Thu, 23 Sep 2010 08:34:58 +0000 (01:34 -0700)
committerRobert Bradshaw <robertwb@math.washington.edu>
Thu, 23 Sep 2010 08:34:58 +0000 (01:34 -0700)
Cython/Compiler/ExprNodes.py
Cython/Compiler/PyrexTypes.py
tests/run/complex_numbers_T305.pyx

index 073d7ddc57de97ae7c4d30283422f5e13d47a3f9..8946ec0c725b72eb10135fa9e620815c251c19a8 100755 (executable)
@@ -5655,8 +5655,13 @@ class PowNode(NumBinopNode):
     def analyse_c_operation(self, env):
         NumBinopNode.analyse_c_operation(self, env)
         if self.type.is_complex:
-            error(self.pos, "complex powers not yet supported")
-            self.pow_func = "<error>"
+            if self.type.real_type.is_float:
+                self.operand1 = self.operand1.coerce_to(self.type, env)
+                self.operand2 = self.operand2.coerce_to(self.type, env)
+                self.pow_func = "__Pyx_c_pow" + self.type.real_type.math_h_modifier
+            else:
+                error(self.pos, "complex int powers not supported")
+                self.pow_func = "<error>"
         elif self.type.is_float:
             self.pow_func = "pow" + self.type.math_h_modifier
         else:
index 993d784af3fe185a145be5fda459ee3adb2ec9a8..3509bd85b94775edbfc68c852f126b91daa050cb 100755 (executable)
@@ -1095,7 +1095,8 @@ class CComplexType(CNumericType):
                 utility_code.specialize(
                     self, 
                     real_type = self.real_type.declaration_code(''),
-                    m = self.funcsuffix))
+                    m = self.funcsuffix,
+                    is_float = self.real_type.is_float))
         return True
 
     def create_to_py_utility_code(self, env):
@@ -1112,7 +1113,8 @@ class CComplexType(CNumericType):
                 utility_code.specialize(
                     self, 
                     real_type = self.real_type.declaration_code(''),
-                    m = self.funcsuffix))
+                    m = self.funcsuffix,
+                    is_float = self.real_type.is_float))
         self.from_py_function = "__Pyx_PyComplex_As_" + self.specialization_name()
         return True
     
@@ -1271,11 +1273,17 @@ proto="""
   #ifdef __cplusplus
     #define __Pyx_c_is_zero%(m)s(z) ((z)==(%(real_type)s)0)
     #define __Pyx_c_conj%(m)s(z)    (::std::conj(z))
-    /*#define __Pyx_c_abs%(m)s(z)     (::std::abs(z))*/
+    #if %(is_float)s
+        #define __Pyx_c_abs%(m)s(z)     (::std::abs(z))
+        #define __Pyx_c_pow%(m)s(a, b)  (::std::pow(a, b))
+    #endif
   #else
     #define __Pyx_c_is_zero%(m)s(z) ((z)==0)
     #define __Pyx_c_conj%(m)s(z)    (conj%(m)s(z))
-    /*#define __Pyx_c_abs%(m)s(z)     (cabs%(m)s(z))*/
+    #if %(is_float)s
+        #define __Pyx_c_abs%(m)s(z)     (cabs%(m)s(z))
+        #define __Pyx_c_pow%(m)s(a, b)  (cpow%(m)s(a, b))
+    #endif
  #endif
 #else
     static CYTHON_INLINE int __Pyx_c_eq%(m)s(%(type)s, %(type)s);
@@ -1286,7 +1294,10 @@ proto="""
     static CYTHON_INLINE %(type)s __Pyx_c_neg%(m)s(%(type)s);
     static CYTHON_INLINE int __Pyx_c_is_zero%(m)s(%(type)s);
     static CYTHON_INLINE %(type)s __Pyx_c_conj%(m)s(%(type)s);
-    /*static CYTHON_INLINE %(real_type)s __Pyx_c_abs%(m)s(%(type)s);*/
+    #if %(is_float)s
+        static CYTHON_INLINE %(real_type)s __Pyx_c_abs%(m)s(%(type)s);
+        static CYTHON_INLINE %(type)s __Pyx_c_pow%(m)s(%(type)s, %(type)s);
+    #endif
 #endif
 """,
 impl="""
@@ -1335,15 +1346,60 @@ impl="""
         z.imag = -a.imag;
         return z;
     }
-/*
-    static CYTHON_INLINE %(real_type)s __Pyx_c_abs%(m)s(%(type)s z) {
-#if HAVE_HYPOT
-        return hypot%(m)s(z.real, z.imag);
-#else
-        return sqrt%(m)s(z.real*z.real + z.imag*z.imag);
-#endif
-    }
-*/
+    #if %(is_float)s
+        static CYTHON_INLINE %(real_type)s __Pyx_c_abs%(m)s(%(type)s z) {
+          #if HAVE_HYPOT
+            return hypot%(m)s(z.real, z.imag);
+          #else
+            return sqrt%(m)s(z.real*z.real + z.imag*z.imag);
+          #endif
+        }
+        static CYTHON_INLINE %(type)s __Pyx_c_pow%(m)s(%(type)s a, %(type)s b) {
+            %(type)s z;
+            %(real_type)s r, lnr, theta, z_r, z_theta;
+            if (b.imag == 0 && b.real == (int)b.real) {
+                if (b.real < 0) {
+                    %(real_type)s denom = a.real * a.real + a.imag * a.imag;
+                    a.real = a.real / denom;
+                    a.imag = -a.imag / denom;
+                    b.real = -b.real;
+                }
+                switch ((int)b.real) {
+                    case 0:
+                        z.real = 1;
+                        z.imag = 0;
+                        return z;
+                    case 1:
+                        return a;
+                    case 2:
+                        z = __Pyx_c_prod%(m)s(a, a);
+                        return __Pyx_c_prod%(m)s(a, a);
+                    case 3:
+                        z = __Pyx_c_prod%(m)s(a, a);
+                        return __Pyx_c_prod%(m)s(z, a);
+                    case 4:
+                        z = __Pyx_c_prod%(m)s(a, a);
+                        return __Pyx_c_prod%(m)s(z, z);
+                }
+            }
+            if (a.imag == 0) {
+                if (a.real == 0) {
+                    return a;
+                }
+                r = a.real;
+                theta = 0;
+            } else {
+                r = __Pyx_c_abs%(m)s(a);
+                theta = atan2%(m)s(a.imag, a.real);
+            }
+            lnr = log%(m)s(r);
+            z_r = exp%(m)s(lnr * b.real - theta * b.imag);
+            z_theta = theta * b.real + lnr * b.imag;
+            z.real = z_r * cos%(m)s(z_theta);
+            z.imag = z_r * sin%(m)s(z_theta);
+            return z;
+        }
+    #endif
 #endif
 """)
 
index 0df26f42788c23eec042a41bf895cdee9a7411f0..6343ceb88e5032ee441050db81a3d24a3941851c 100644 (file)
@@ -23,6 +23,39 @@ def test_arithmetic(double complex z, double complex w):
     """
     return +z, -z+0, z+w, z-w, z*w, z/w
 
+def test_pow(double complex z, double complex w, tol=None):
+    """
+    Various implementations produce slightly different results...
+    
+    >>> a = complex(3, 1)
+    >>> test_pow(a, 1)
+    (3+1j)
+    >>> test_pow(a, 2, 1e-15)
+    True
+    >>> test_pow(a, a, 1e-15)
+    True
+    >>> test_pow(complex(0.5, -.25), complex(3, 4), 1e-15)
+    True
+    """
+    if tol is None:
+        return z**w
+    else:
+        return abs(z**w / <object>z ** <object>w - 1) < tol
+
+def test_int_pow(double complex z, int n, tol=None):
+    """
+    >>> [test_int_pow(complex(0, 1), k, 1e-15) for k in range(-4, 5)]
+    [True, True, True, True, True, True, True, True, True]
+    >>> [test_int_pow(complex(0, 2), k, 1e-15) for k in range(-4, 5)]
+    [True, True, True, True, True, True, True, True, True]
+    >>> [test_int_pow(complex(2, 0.5), k, 1e-15) for k in range(0, 10)]
+    [True, True, True, True, True, True, True, True, True, True]
+    """
+    if tol is None:
+        return z**n + <object>0 # add zero to normalize zero sign
+    else:
+        return abs(z**n / <object>z ** <object>n - 1) < tol
+
 @cython.cdivision(False)
 def test_div_by_zero(double complex z):
     """