2dade8bd2624ac70ac8986d347442e6a83ce867c
[sawsim.git] / pysawsim / parameter_scan.py
1 # Copyright (C) 2009-2010  W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation, either version 3 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License
14 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 #
16 # The author may be contacted at <wking@drexel.edu> on the Internet, or
17 # write to Trevor King, Drexel University, Physics Dept., 3141 Chestnut St.,
18 # Philadelphia PA 19104, USA.
19
20 """Experiment vs. simulation comparison and scanning.
21 """
22
23 from os import getpid  # for rss()
24 import os.path
25 import pickle
26 from StringIO import StringIO
27
28 import matplotlib
29 matplotlib.use('Agg')  # select backend that doesn't require X Windows
30 import numpy
31 import pylab
32
33 from . import log
34 from . import PYSAWSIM_LOG_LEVEL_MSG as _PYSAWSIM_LOG_LEVEL_MSG
35 from .histogram import Histogram
36 from .sawsim_histogram import sawsim_histogram
37 from .sawsim import SawsimRunner
38
39
_multiprocess_can_split_ = True
"""Allow nosetests to split tests between processes.
"""

# Created once at import time; `HistogramMatcher.plot()` and
# `HistogramMatcher._plot_residual_comparison()` clear and reuse it.
FIGURE = pylab.figure()  # avoid memory problems.
"""`pylab` keeps internal references to all created figures, so share
a single instance.
"""
# Sample input in the format read by
# `HistogramMatcher._read_force_histograms()`: three velocity histograms,
# each introduced by a `#HISTOGRAM: <sawsim params>` line.  Used by the
# doctests and embedded in the `main()` --help epilog.
EXAMPLE_HISTOGRAM_FILE_CONTENTS = """# Velocity histograms
# Other general comments...

#HISTOGRAM: -v 6e-7
#Force (N)\tUnfolding events
1.4e-10\t1
1.5e-10\t0
1.6e-10\t4
1.7e-10\t6
1.8e-10\t8
1.9e-10\t20
2e-10\t28
2.1e-10\t38
2.2e-10\t72
2.3e-10\t110
2.4e-10\t155
2.5e-10\t247
2.6e-10\t395
2.7e-10\t451
2.8e-10\t430
2.9e-10\t300
3e-10\t116
3.1e-10\t18
3.2e-10\t1

#HISTOGRAM: -v 8e-7
#Force (N)\tUnfolding events
8e-11\t1
9e-11\t0
1e-10\t0
1.1e-10\t1
1.2e-10\t0
1.3e-10\t0
1.4e-10\t0
1.5e-10\t3
1.6e-10\t3
1.7e-10\t4
1.8e-10\t4
1.9e-10\t13
2e-10\t29
2.1e-10\t39
2.2e-10\t60
2.3e-10\t102
2.4e-10\t154
2.5e-10\t262
2.6e-10\t402
2.7e-10\t497
2.8e-10\t541
2.9e-10\t555
3e-10\t325
3.1e-10\t142
3.2e-10\t50
3.3e-10\t13

#HISTOGRAM: -v 1e-6
#Force (N)\tUnfolding events
1.5e-10\t2
1.6e-10\t3
1.7e-10\t7
1.8e-10\t8
1.9e-10\t7
2e-10\t25
2.1e-10\t30
2.2e-10\t58
2.3e-10\t76
2.4e-10\t159
2.5e-10\t216
2.6e-10\t313
2.7e-10\t451
2.8e-10\t568
2.9e-10\t533
3e-10\t416
3.1e-10\t222
3.2e-10\t80
3.3e-10\t24
3.4e-10\t2
"""


# When true, `HistogramMatcher.plot()` logs this process's resident set
# size (via `rss()`) after evaluating every grid point.
MEM_DEBUG = False
128
129
130
def rss():
    """Return this process's resident set size in kilobytes.

    For debugging memory usage.  The resident set size is the
    non-swapped physical memory the process is using, as reported by
    `ps -o rss= -p <pid>`.

    Raises
    ------
    OSError
        If the `ps` executable cannot be run.
    RuntimeError
        If `ps` exits with a non-zero status.
    """
    import subprocess

    # Pass an argument list (no shell) so the pid needs no quoting.
    process = subprocess.Popen(
        ['ps', '-o', 'rss=', '-p', str(getpid())],
        stdout=subprocess.PIPE, stderr=subprocess.PIPE)
    stdout, stderr = process.communicate()
    if process.returncode != 0:
        raise RuntimeError('ps failed with status %d: %r'
                           % (process.returncode, stderr))
    # `rss=` suppresses the header, leaving just the (padded) number.
    return int(stdout.strip())
141
142
class HistogramMatcher (object):
    """Compare experimental histograms to simulated data.

    The main entry points are `residual()`, which scores a single
    parameter set against every experimental histogram, and `plot()`,
    which scans a 2D parameter grid and plots the resulting fit quality.

    The input `histogram_stream` should contain a series of
    experimental histograms with `#HISTOGRAM: <params>` lines starting
    each histogram.  `<params>` lists the `sawsim` parameters that are
    unique to that experiment.

    >>> from .manager.thread import ThreadManager
    >>> histogram_stream = StringIO(EXAMPLE_HISTOGRAM_FILE_CONTENTS)
    >>> param_format_string = (
    ...     '-s cantilever,hooke,0.05 -N1 '
    ...     '-s folded,null -N8 '
    ...     '-s "unfolded,wlc,{0.39e-9,28e-9}" '
    ...     '-k "folded,unfolded,bell,{%g,%g}" -q folded')
    >>> m = ThreadManager()
    >>> sr = SawsimRunner(manager=m)
    >>> hm = HistogramMatcher(histogram_stream, param_format_string, sr, N=3)
    >>> hm.plot([[1e-5,1e-3,3],[0.1e-9,1e-9,3]], logx=True, logy=False)
    >>> m.teardown()
    """
    def __init__(self, histogram_stream, param_format_string,
                 sawsim_runner, N=400, residual_type='jensen-shannon',
                 plot=False):
        """
        Parameters
        ----------
        histogram_stream : file-like
            Experimental histograms, read completely here (see
            `_read_force_histograms()` for the format).
        param_format_string : str
            `%`-format string; `param_string()` substitutes the scanned
            parameters into it.
        sawsim_runner : SawsimRunner
            Forwarded to `sawsim_histogram()`.
        N : int
            Forwarded to `sawsim_histogram()` (presumably the number of
            simulations per histogram -- confirm in `sawsim_histogram`).
        residual_type : str
            Passed as `type` to `Histogram.residual()`.
        plot : bool
            If true, `residual()` saves a comparison PNG for every
            histogram it evaluates.
        """
        self.experiment_histograms = self._read_force_histograms(
            histogram_stream)
        self.param_format_string = param_format_string
        self.sawsim_runner = sawsim_runner
        self.N = N
        self.residual_type = residual_type
        self._plot = plot

    def _read_force_histograms(self, stream):
        """Parse `stream` into a `{params: Histogram}` dict.

        File format:

          # comment and blank lines ignored
          #HISTOGRAM: <histogram-specific params>
          <pysawsim.histogram.Histogram-compatible histogram>
          #HISTOGRAM: <other histogram-specific params>
          <another pysawsim.histogram.Histogram-compatible histogram>
          ...

        Lines before the first `#HISTOGRAM:` line are discarded.

        Raises
        ------
        ValueError
            If two histograms share the same `<params>` string.

        >>> import sys
        >>> stream = StringIO(EXAMPLE_HISTOGRAM_FILE_CONTENTS)
        >>> hm = HistogramMatcher(StringIO(), None, None, None)
        >>> histograms = hm._read_force_histograms(stream)
        >>> sorted(histograms)
        ['-v 1e-6', '-v 6e-7', '-v 8e-7']
        >>> histograms['-v 1e-6'].to_stream(sys.stdout)
        ... # doctest: +NORMALIZE_WHITESPACE, +REPORT_UDIFF
        #Force (N)\tUnfolding events
        1.5e-10\t2
        1.6e-10\t3
        1.7e-10\t7
        1.8e-10\t8
        1.9e-10\t7
        2e-10\t25
        2.1e-10\t30
        2.2e-10\t58
        2.3e-10\t76
        2.4e-10\t159
        2.5e-10\t216
        2.6e-10\t313
        2.7e-10\t451
        2.8e-10\t568
        2.9e-10\t533
        3e-10\t416
        3.1e-10\t222
        3.2e-10\t80
        3.3e-10\t24
        3.4e-10\t2
        """
        token = '#HISTOGRAM:'
        # The `None` bucket collects the preamble before the first token.
        hist_blocks = {None: []}
        params = None
        for line in stream:
            line = line.strip()
            if line.startswith(token):
                params = line[len(token):].strip()
                if params in hist_blocks:
                    raise ValueError(
                        'duplicate histogram parameters: %r' % params)
                hist_blocks[params] = []
            else:
                hist_blocks[params].append(line)

        histograms = {}
        for params, block in hist_blocks.items():
            if params is None:
                continue  # preamble, not a histogram
            h = Histogram()
            h.from_stream(StringIO('\n'.join(block)))
            histograms[params] = h
        return histograms

    def param_string(self, params, hist_params):
        """Generate a string of options to pass to `sawsim`.

        `params` fill the `%` slots of `param_format_string`;
        `hist_params` (the histogram-specific options) are appended.
        """
        return '%s %s' % (
            self.param_format_string % tuple(params), hist_params)

    def residual(self, params):
        """Return the summed residual of `params` over all histograms.

        For each experimental histogram, simulate a matching histogram
        with `sawsim_histogram()` (reusing the experimental bin edges)
        and add `experiment_hist.residual(...)` to the total.  When
        plotting is enabled, also save a per-histogram comparison PNG
        whose filename encodes the parameters and that histogram's
        residual.
        """
        total = 0
        for hist_params, experiment_hist in self.experiment_histograms.items():
            sawsim_hist = sawsim_histogram(
                sawsim_runner=self.sawsim_runner,
                param_string=self.param_string(params, hist_params),
                N=self.N, bin_edges=experiment_hist.bin_edges)
            r = experiment_hist.residual(sawsim_hist, type=self.residual_type)
            total += r
            if self._plot:
                title = ", ".join(["%g" % p for p in params]+[hist_params])
                filename = "residual-%s-%g.png" % (
                    title.replace(', ', '_').replace(' ', '_'), r)
                self._plot_residual_comparison(
                    experiment_hist, sawsim_hist, residual=r,
                    title=title, filename=filename)
        log().debug('residual %s: %g' % (params, total))
        return total

    @staticmethod
    def _axis_values(param_range, log_scale=False):
        """Return the grid points for one axis of `plot()`.

        `param_range` is `[min, max, steps]`.  `steps` may be a float (as
        produced by `parse_param_ranges_string()`); it is cast to `int`
        because `numpy.linspace` requires an integer sample count.  With
        `log_scale`, the points are evenly spaced in log space.
        """
        low, high, steps = param_range
        steps = int(steps)
        if log_scale:
            return numpy.exp(
                numpy.linspace(numpy.log(low), numpy.log(high), steps))
        return numpy.linspace(low, high, steps)

    def plot(self, param_ranges, logx=False, logy=False, contour=False,
             csv=None):
        """Scan a 2D parameter grid and plot the fit quality.

        Parameters
        ----------
        param_ranges : sequence
            `[[xmin, xmax, xsteps], [ymin, ymax, ysteps]]`.
        logx, logy : bool
            Space that axis's points logarithmically and use a log axis.
        contour : bool
            Draw a contour plot instead of the default pseudocolor plot.
        csv : file-like, optional
            If given, receives a header line and one
            `param 1,param 2,fit quality` line per grid point.

        Writes `histogram_matcher-XYC.pkl` (the `[X, Y, C]` arrays) and
        `figure.png` to the current directory.
        """
        if csv:
            csv.write(','.join(('param 1', 'param 2', 'fit quality')) + '\n')
        x = self._axis_values(param_ranges[0], logx)
        y = self._axis_values(param_ranges[1], logy)
        X, Y = pylab.meshgrid(x, y)
        # Residuals are evaluated at the lower-left corner of each grid
        # cell, so C is one row and one column smaller than X and Y (the
        # corner layout `pcolor` expects).
        nx, ny = len(x) - 1, len(y) - 1
        C = numpy.zeros((ny, nx))
        for i, xi in enumerate(x[:-1]):
            for j, yj in enumerate(y[:-1]):
                log().info('point %d %d (%d of %d)'
                           % (i, j, i*ny + j, nx*ny))
                params = (xi, yj)
                r = self.residual(params)
                if csv:
                    csv.write(','.join([str(v) for v in (xi, yj, r)]) + '\n')
                C[j, i] = numpy.log(r)  # better resolution in valleys
                if MEM_DEBUG:
                    log().debug('RSS: %d KB' % rss())
        # NaN -> 0; infinities (e.g. log(0)) become large finite values.
        C = numpy.nan_to_num(C)
        with open("histogram_matcher-XYC.pkl", "wb") as fid:
            pickle.dump([X, Y, C], fid)
        # read in with
        # import pickle
        # [X,Y,C] = pickle.load(open("histogram_matcher-XYC.pkl", "rb"))
        # ...
        FIGURE.clear()
        axes = FIGURE.add_subplot(111)
        if logx:
            axes.set_xscale('log')
        if logy:
            axes.set_yscale('log')
        if contour:
            # [:-1,:-1] to strip dummy last row & column from X&Y.
            p = axes.contour(X[:-1, :-1], Y[:-1, :-1], C)
        else:  # pseudocolor plot
            p = axes.pcolor(X, Y, C)
            axes.autoscale_view(tight=True)
        FIGURE.colorbar(p)
        FIGURE.savefig("figure.png")

    def _plot_residual_comparison(self, experiment_hist, theory_hist,
                                  residual, title, filename):
        """Save experimental (red) vs. simulated (blue) probabilities.

        Draws on the shared `FIGURE` explicitly, so the output does not
        depend on which pylab figure is current.  `residual` is accepted
        for interface compatibility but not drawn; `residual()` encodes
        it in `filename` instead.
        """
        FIGURE.clear()
        axes = FIGURE.add_subplot(111)
        axes.plot(experiment_hist.bin_edges[:-1],
                  experiment_hist.probabilities, 'r-',
                  theory_hist.bin_edges[:-1],
                  theory_hist.probabilities, 'b-')
        axes.set_title(title)
        FIGURE.savefig(filename)
323         pylab.title(title)
324         FIGURE.savefig(filename)
325
326
def parse_param_ranges_string(string):
    """Parse parameter range strings.

    '[Amin,Amax,Asteps],[Bmin,Bmax,Bsteps],...'
      ->
    [[Amin,Amax,Asteps],[Bmin,Bmax,Bsteps],...]

    Whitespace around numbers, brackets, and the separating commas is
    ignored, and the outer brackets are optional for a single range.

    >>> parse_param_ranges_string('[1,2,3],[4,5,6]')
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    >>> parse_param_ranges_string('[1,2,3]')
    [[1.0, 2.0, 3.0]]
    >>> parse_param_ranges_string(' [1, 2, 3], [4,5,6] ')
    [[1.0, 2.0, 3.0], [4.0, 5.0, 6.0]]
    """
    import re

    ranges = []
    # Split on '],[' allowing whitespace around the comma.
    for range_string in re.split(r'\]\s*,\s*\[', string.strip()):
        range_number_strings = range_string.strip().strip("[]").split(",")
        # float() itself tolerates surrounding whitespace.
        ranges.append([float(x) for x in range_number_strings])
    return ranges
344
345
def main(argv=None):
    """Command-line entry point: scan parameters and plot fit quality.

    Parameters
    ----------
    argv : list of str, optional
        Command-line arguments without the program name; defaults to
        `sys.argv[1:]`.

    `HistogramMatcher.plot()` writes `figure.png` and
    `histogram_matcher-XYC.pkl` to the current directory; `--csv`
    additionally writes the per-point fit qualities to a CSV file.

    >>> import tempfile
    >>> f = tempfile.NamedTemporaryFile()
    >>> f.write(EXAMPLE_HISTOGRAM_FILE_CONTENTS)
    >>> f.flush()
    >>> main(['-r', '[1e-5,1e-3,3],[0.1e-9,1e-9,3]',
    ...       '-N', '2',
    ...       f.name])
    >>> f.close()
    """
    from optparse import OptionParser
    import sys

    if argv is None:
        argv = sys.argv[1:]

    sr = SawsimRunner()

    usage = '%prog [options] histogram_file'
    epilog = '\n'.join([
            'Compare simulated results against experimental values over a',
            'range of parameters.  Generates a plot of fit quality over',
            'the parameter space.  The histogram file should look something',
            'like:',
            '',
            EXAMPLE_HISTOGRAM_FILE_CONTENTS,
            '',
            '`#HISTOGRAM: <params>` lines start each histogram.  `params`',
            'lists the `sawsim` parameters that are unique to that',
            'experiment.',
            '',
            'Each histogram line is of the format:',
            '',
            '<bin_edge><whitespace><count>',
            '',
            '`<bin_edge>` should mark the left-hand side of the bin, and',
            'all bins should be of equal width (so we know where the last',
            'one ends).',
            _PYSAWSIM_LOG_LEVEL_MSG,
            ])
    parser = OptionParser(usage, epilog=epilog)
    # Print the epilog verbatim instead of letting optparse re-wrap it,
    # so the example histogram keeps its line structure.
    parser.format_epilog = lambda formatter: epilog+'\n'
    for option in sr.optparse_options:
        if option.dest == 'param_string':
            continue  # replaced by --param-format and the scanned params
        parser.add_option(option)
    parser.add_option('-f','--param-format', dest='param_format',
                      metavar='FORMAT',
                      help='Convert params to sawsim options (%default).',
                      default=('-s cantilever,hooke,0.05 -N1 '
                               '-s folded,null -N8 '
                               '-s "unfolded,wlc,{0.39e-9,28e-9}" '
                               '-k "folded,unfolded,bell,{%g,%g}" -q folded'))
    parser.add_option('-p','--initial-params', dest='initial_params',
                      metavar='PARAMS',
                      help='Initial params for fitting (%default).',
                      default='3.3e-4,0.25e-9')
    parser.add_option('-r','--param-range', dest='param_range',
                      metavar='PARAMS',
                      help='Param range for plotting (%default).',
                      default='[1e-5,1e-3,20],[0.1e-9,1e-9,20]')
    parser.add_option('--logx', dest='logx',
                      help='Use a log scale for the x range.',
                      default=False, action='store_true')
    parser.add_option('--logy', dest='logy',
                      help='Use a log scale for the y range.',
                      default=False, action='store_true')
    parser.add_option('-R','--residual', dest='residual',
                      metavar='STRING',
                      help='Residual type (from %s; default: %%default).'
                      % ', '.join(Histogram().types()),
                      default='jensen-shannon')
    parser.add_option('-P','--plot-residuals', dest='plot_residuals',
                      help='Generate residual difference plots for each point in the plot range.',
                      default=False, action='store_true')
    parser.add_option('-c','--contour-plot', dest='contour_plot',
                      help='Select contour plot (vs. the default pseudocolor plot).',
                      default=False, action='store_true')
    parser.add_option('--csv', dest='csv', metavar='FILE',
                      help='Save fit qualities to a comma-separated value file FILE.')

    options,args = parser.parse_args(argv)
    if not args:
        parser.error('missing histogram_file argument')

    # Parsed (and thereby validated) for the not-yet-enabled fit below.
    initial_params = [float(p) for p in options.initial_params.split(",")]
    param_ranges = parse_param_ranges_string(options.param_range)
    histogram_file = args[0]
    csv = None
    sr_call_params = sr.initialize_from_options(options)

    try:
        # HistogramMatcher reads the whole stream in its constructor, so
        # the file can be closed immediately afterwards.
        with open(histogram_file, 'r') as histogram_stream:
            hm = HistogramMatcher(
                histogram_stream,
                param_format_string=options.param_format,
                sawsim_runner=sr, residual_type=options.residual,
                plot=options.plot_residuals, **sr_call_params)
        #hm.fit(initial_params)
        if options.csv:
            csv = open(options.csv, 'w')
        hm.plot(param_ranges, logx=options.logx, logy=options.logy,
                contour=options.contour_plot, csv=csv)
    finally:
        sr.teardown()
        if csv:
            csv.close()
446         if csv:
447             csv.close()