Adapt to handle scipy.optimize.leastsq for scipy >= 0.8.0.
authorW. Trevor King <wking@drexel.edu>
Thu, 28 Apr 2011 11:49:26 +0000 (07:49 -0400)
committerW. Trevor King <wking@drexel.edu>
Thu, 28 Apr 2011 11:49:26 +0000 (07:49 -0400)
For single-parameter fitting, leastsq() used to return the fitted
parameter as a float.  Now it returns it as a length-one array.  With
this commit, Hooke should work with either case.

hooke/plugin/polymer_fit.py
hooke/util/fit.py

index 7105e570841a1fca6c3c073778c71e03fb46436c..4447ca76e21571aac8f5623554e7f26bfc55493d 100644 (file)
@@ -1,6 +1,6 @@
 # -*- coding: utf-8 -*-
 #
-# Copyright (C) 2008-2010 Alberto Gomez-Casado
+# Copyright (C) 2008-2011 Alberto Gomez-Casado
 #                         Fabrizio Benedetti
 #                         Massimo Sandal <devicerandom@gmail.com>
 #                         W. Trevor King <wking@drexel.edu>
@@ -244,7 +244,7 @@ class FJC (ModelFitter):
 
     >>> info['Kuhn length (m)'] = 2*a
     >>> model = FJC(d_data, info=info, rescale=True)
-    >>> L = model.fit(outqueue=outqueue)
+    >>> L, = model.fit(outqueue=outqueue)
     >>> fit_info = outqueue.get(block=False)
     >>> print L  # doctest: +ELLIPSIS
     3.199...e-08
@@ -340,11 +340,9 @@ class FJC (ModelFitter):
 
     def fit(self, *args, **kwargs):
         params = super(FJC, self).fit(*args, **kwargs)
-        if is_iterable(params):
-            params[0] = self.L(params[0])  # convert Lp -> L
+        params[0] = self.L(params[0])  # convert Lp -> L
+        if len(params) > 1:
             params[1] = abs(params[1])  # take the absolute value of `a`
-        else:  # params is a float
-            params = self.L(params)  # convert Lp -> L
         return params
 
     def guess_initial_params(self, outqueue=None):
@@ -507,7 +505,7 @@ class FJC_PEG (ModelFitter):
 
     >>> info['Kuhn length (m)'] = 2*kwargs['a']
     >>> model = FJC_PEG(d_data, info=info, rescale=True)
-    >>> N = model.fit(outqueue=outqueue)
+    >>> N, = model.fit(outqueue=outqueue)
     >>> fit_info = outqueue.get(block=False)
     >>> print N  # doctest: +ELLIPSIS
     96.931...
@@ -613,11 +611,9 @@ class FJC_PEG (ModelFitter):
 
     def fit(self, *args, **kwargs):
         params = super(FJC_PEG, self).fit(*args, **kwargs)
-        if is_iterable(params):
-            params[0] = self.L(params[0])  # convert Nr -> N
+        params[0] = self.L(params[0])  # convert Nr -> N
+        if len(params) > 1:
             params[1] = abs(params[1])  # take the absolute value of `a`
-        else:  # params is a float
-            params = self.L(params)  # convert Nr -> N
         return params
 
     def guess_initial_params(self, outqueue=None):
@@ -726,7 +722,7 @@ class WLC (ModelFitter):
 
     >>> info['persistence length (m)'] = 2*p
     >>> model = WLC(d_data, info=info, rescale=True)
-    >>> L = model.fit(outqueue=outqueue)
+    >>> L, = model.fit(outqueue=outqueue)
     >>> fit_info = outqueue.get(block=False)
     >>> print L  # doctest: +ELLIPSIS
     3.318...e-08
@@ -822,11 +818,9 @@ class WLC (ModelFitter):
 
     def fit(self, *args, **kwargs):
         params = super(WLC, self).fit(*args, **kwargs)
-        if is_iterable(params):
-            params[0] = self.L(params[0])  # convert Lp -> L
+        params[0] = self.L(params[0])  # convert Lp -> L
+        if len(params) > 1:
             params[1] = abs(params[1])  # take the absolute value of `p`
-        else:  # params is a float
-            params = self.L(params)  # convert Lp -> L
         return params
 
     def guess_initial_params(self, outqueue=None):
index 0a97135aeba2a31cc3b76965ddbeca5bf563daf8..6267341849f450e757ba1d64193898e3b5c4903c 100644 (file)
@@ -1,4 +1,4 @@
-# Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
+# Copyright (C) 2010-2011 W. Trevor King <wking@drexel.edu>
 #
 # This file is part of Hooke.
 #
 """
 
 from numpy import arange, ndarray
+from scipy import __version__ as _scipy_version
 from scipy.optimize import leastsq
-import scipy.optimize
+
+_strings = _scipy_version.split('.')
+# Don't convert third string to an integer in case of (for example) '0.7.2rc3'
+_SCIPY_VERSION = (int(_strings[0]), int(_strings[1]), _strings[2])
+del _strings
 
 
 class PoorFit (ValueError):
@@ -128,6 +133,25 @@ class ModelFitter (object):
     7.000
     >>> print '%.3f' % offset
     -32.890
+
+    Test single-parameter models:
+
+    >>> class SingleParameterModel (LinearModel):
+    ...     '''Simple linear model.
+    ...     '''
+    ...     def model(self, params):
+    ...         return super(SingleParameterModel, self).model([params[0], 0.])
+    ...     def guess_initial_params(self, outqueue=None):
+    ...         return super(SingleParameterModel, self
+    ...             ).guess_initial_params(outqueue)[:1]
+    ...     def guess_scale(self, params, outqueue=None):
+    ...         return super(SingleParameterModel, self
+    ...             ).guess_scale([params[0], 0.], outqueue)[:1]
+    >>> data = 20*numpy.sin(arange(1000)) + 7.*arange(1000)
+    >>> m = SingleParameterModel(data)
+    >>> slope, = m.fit(outqueue=outqueue)
+    >>> print '%.3f' % slope
+    7.000
     """
     def __init__(self, *args, **kwargs):
         self.set_data(*args, **kwargs)
@@ -233,13 +257,13 @@ class ModelFitter (object):
         params,cov,info,mesg,ier = leastsq(
             func=self.residual, x0=active_params, full_output=True,
             diag=scale, **kwargs)
+        if len(initial_params) == 1 and _SCIPY_VERSION < (0, 8, '0'):
+            # params is a float for scipy < 0.8.0.  Convert to list.
+            params = [params]
         if self._rescale == True:
             active_params = params
-            if len(initial_params) == 1:  # params is a float
-                params = params * self._param_scale_factors[0]
-            else:
-                params = [p*s for p,s in zip(params,
-                                             self._param_scale_factors)]
+            params = [p*s for p,s in zip(params,
+                                         self._param_scale_factors)]
         else:
             active_params = params
         if outqueue != None: