+++ /dev/null
-#!/usr/bin/python
-#
-# Copyright (C) 2010 W. Trevor King
-
-from scipy.optimize import leastsq
-from math import sqrt
-
-def _param_char(index):
- """
- >>> _param_char(0)
- 'A'
- >>> _param_char(1)
- 'B'
- >>> _param_char(25)
- 'Z'
- """
- if index < 0 or index > 25 or index != int(index):
- raise ValueError(index)
- return chr(ord('A')+index)
-
-def parse_modules(module_strings):
- """
- >>> m = parse_modules(['math:sin,cos', 'sys'])
- >>> import math, sys
- >>> if m == {'sin':math.sin, 'cos':math.cos, 'sys':sys}:
- ... print True
- ... else:
- ... print m
- True
- """
- ms = []
- for string in module_strings:
- fields = string.split(':')
- assert len(fields) in [1, 2], '%d fields in %s: %s' % (len(fields), string, fields)
- modname = fields[0]
- objects = []
- if len(fields) == 2:
- objects = fields[1].split(',')
- ms.append((modname, objects))
-
- _locals = {}
- for modname,objects in ms:
- if len(objects) == 0:
- _locals[modname] = __import__(modname)
- else:
- _temp = __import__(modname, globals(), locals(), objects)
- for objname in objects:
- _locals[objname] = getattr(_temp, objname)
- return _locals
-
-def parse_function(function_string):
- """
- >>> f = parse_function('A*x+B')
- >>> f(0, [1,2])
- 2.0
- >>> f(-1, [1,2])
- 1.0
- """
- def f(x, params, _locals={}):
- _locals['x'] = x
- for i,p in enumerate(params):
- _locals[_param_char(i)] = float(p)
- return eval(function_string, globals(), _locals)
- return f
-
-def read_data(datalines, xcol=0, ycol=1, wcol=-1):
- """
- >>> read_data('1\\t2\\t3\\n4\\t5\\t6\\n'.splitlines())
- ([1.0, 4.0], [2.0, 5.0], [1, 1])
- >>> read_data('1\\t2\\t3\\n4\\t5\\t6\\n'.splitlines(), wcol=2)
- ([1.0, 4.0], [2.0, 5.0], [3.0, 6.0])
- """
- xs = []
- ys = []
- ws = []
- for line in datalines:
- if len(line) == 0 or line.startswith('#'):
- continue # blank or comment line
- row = [float(x) for x in line.split('\t')]
- xs.append(row[xcol])
- ys.append(row[ycol])
- if wcol >= 0:
- ws.append(row[wcol])
- if wcol < 0:
- ws = [1]*len(xs)
- return (xs, ys, ws)
-
-def fit(data, fn, initial_ps, _locals={}, verbose=False):
- xs,ys,ws = data
- sqws = [sqrt(w) for w in ws]
- def residuals(params):
- return [sqw*(fn(x, params, _locals)-y)
- for x,y,sqw in zip(xs,ys,sqws)]
- p,cov,info,mesg,ier = leastsq(residuals, initial_ps,
- full_output=True, maxfev=10000)
- if verbose == True:
- print "Fitted params:",p
- print "Covariance mx:",cov
- #print "Info:", info # too much information :p
- print "mesg:", mesg
- if ier == 1 :
- print "Solution converged"
- else :
- print "Solution did not converge"
- return p
-
-
-def test():
- import doctest
- failures,tests = doctest.testmod()
- return failures
-
-if __name__ == '__main__':
- import optparse
- import sys
-
- p = optparse.OptionParser(usage='%prog [options] DATAFILE',
- epilog="""
-DATAFILE should consist of TAB delimited ASCII values, with each line
-specifying a point. Lines beginning with a hash character (#) are
-ignored.
-""".strip())
- p.add_option('-f', '--model-function', dest='function', default='A*x+B',
- metavar='STRING',
- help='Define the model function used for fitting. Use capital letters as parameters for the function (%default)')
- p.add_option('-p', '--initial-params', dest='params', default='1,1',
- metavar='FLOAT[,...]',
- help='Set initial model parameters for fitting (%default)')
- p.add_option('-m', '--modules', dest='modules', default=[],
- action='append', metavar='MODULE[:OBJ[,...]]',
- help='Add a module necessary for your model function.')
- p.add_option('-x', '--x-column', dest='xcol', default=0, type='int',
- help='Select the column used for x values (%default)')
- p.add_option('-y', '--y-column', dest='ycol', default=1, type='int',
- help='Select the column used for y values (%default)')
- p.add_option('-w', '--weight-column', dest='wcol', default=-1, type='int',
- help='Select the column used for weighting points. Set to -1 to use uniform weighting (%default)')
- p.add_option('-v', '--verbose', dest='verbose', default=False,
- action='store_true', help='Print technical fitting details.')
- p.add_option('-t', '--test', dest='test', default=False,
- action='store_true', help='Run internal tests and exit')
-
- options,args = p.parse_args()
-
- if options.test == True:
- n = test()
- sys.exit(min(127, n))
-
- assert len(args) == 1, len(args)
- datafile = args[0]
-
- mod_includes = parse_modules(options.modules)
- fn = parse_function(options.function)
- data = read_data(file(datafile, 'r'),
- options.xcol, options.ycol, options.wcol)
- initial_ps = [float(x) for x in options.params.split(',')]
-
- assert len(initial_ps) < 26, ('More parameters than letters! (%d)'
- % len(params))
- params = fit(data, fn, initial_ps, mod_includes, verbose=options.verbose)
- assert len(params) == len(initial_ps)
-
- for i,p in enumerate(params):
- print '%s: %g' % (_param_char(i), p)