dd772a0476c21d4d6cd8786666101956d27c1a46
[parallel_computing.git] / src / plot_image / plot_image.py
1 #!/usr/bin/env python
2
3 """Generate an image from ASCII data.
4
5 Step 1: make this file executable::
6
7     chmod +x plot_image.py
8
9 Step 2: pipe data into python script::
10
11     ./gen_data | ./plot_image.py -s nx,ny -c nc -t 'image title'
12
13 Data should be one ASCII float per line, in the following order::
14
15     z[x1,y1]
16     z[x2,y1]
17     ...
18     z[xN,y1]
19     z[x1,y2]
20     ...
21     z[xN,y2]
22     ...
23     z[xN,yN]
24
25 where `x` increases from `x1` to `xN` and `y` increases from `y1`
26 through `yN`.
27
28 You can use the `--xyz` option to read an alternative data format,
29 which is::
30
31     x1 y1 z[x1,y1]
32     x2 y2 z[x1,y1]
33     ...
34
35 If you use this method, the ordering of lines in the data file is
36 irrelevant, and `Nx` and `Ny` are extracted from the data file itself.
37 However, you still need to use a rectangular grid (i.e. for every
38 `xi`, you need to have entries for every `yi`).
39
40 For usage info, run::
41
42     ./plot_image.py --help
43
44 When run with interactive output (i.e. no `-o ...` option), the
45 interactive figure is displayed using `pylab.show`, which means that
46 you'll have to kill `plot_image.py` using ^C or similar [1]_.
47
48 For other ideas, see the Matplotlib website [2]_.
49
50 .. [1] http://matplotlib.sourceforge.net/faq/howto_faq.html#use-show
51 .. [2] http://matplotlib.sourceforge.net/
52 """
53
54 import optparse
55 import sys
56
57 import matplotlib
58 import matplotlib.image
59 import numpy
60
61 # Depending on your Matplotlib configuration, you may need to adjust
62 # your backend.  Do this before importing pylab or matplotlib.backends.
63 #matplotlib.use('Agg')     # select backend that doesn't require X Windows
64 #matplotlib.use('GTKAgg')  # select backend that supports pylab.show()
65
66 import pylab
67
68
69 _DOC = __doc__
70
71
72 def read_data_1d(stream, nx, ny):
73     """Read in data, one entry per line.
74
75     >>> from StringIO import StringIO
76     >>> s = StringIO('\\n'.join(map(str, range(10)))+'\\n')
77     >>> X,Y,Z = read_data_1d(s, 5, 2)
78     >>> X
79     array([[0, 1, 2, 3, 4, 5],
80            [0, 1, 2, 3, 4, 5],
81            [0, 1, 2, 3, 4, 5]])
82     >>> Y
83     array([[0, 0, 0, 0, 0, 0],
84            [1, 1, 1, 1, 1, 1],
85            [2, 2, 2, 2, 2, 2]])
86     >>> Z
87     array([[ 0.,  1.,  2.,  3.,  4.],
88            [ 5.,  6.,  7.,  8.,  9.]])
89     """
90     X,Y = pylab.meshgrid(range(nx+1), range(ny+1))
91     Z = numpy.loadtxt(stream)
92     assert Z.size == nx*ny, 'Z.size = %d != %d = %dx%d' % (
93         Z.size, nx*ny, nx, ny)
94     Z = Z.reshape([x-1 for x in X.shape])    
95     return (X,Y,Z)
96
97 def read_data_3d(stream):
98     """Read in data, one `(x, y, z)` tuple per line.
99
100     >>> from StringIO import StringIO
101     >>> lines = []
102     >>> for x in range(5):
103     ...     for y in range(2):
104     ...         lines.append('\t'.join(map(str, [x, y, x+y*5])))
105     >>> s = StringIO('\\n'.join(lines)+'\\n')
106     >>> X,Y,Z = read_data_3d(s)
107     >>> X
108     array([[ 0.,  1.,  2.,  3.,  4.,  5.],
109            [ 0.,  1.,  2.,  3.,  4.,  5.],
110            [ 0.,  1.,  2.,  3.,  4.,  5.]])
111     >>> Y
112     array([[ 0.,  0.,  0.,  0.,  0.,  0.],
113            [ 1.,  1.,  1.,  1.,  1.,  1.],
114            [ 2.,  2.,  2.,  2.,  2.,  2.]])
115     >>> Z
116     array([[ 0.,  1.,  2.,  3.,  4.],
117            [ 5.,  6.,  7.,  8.,  9.]])
118     """
119     XYZ = numpy.loadtxt(stream)
120     assert len(XYZ.shape) == 2 and XYZ.shape[1] == 3, XYZ.shape
121     Xs = numpy.array(sorted(set(XYZ[:,0])))
122     Ys = numpy.array(sorted(set(XYZ[:,1])))
123     Z = numpy.ndarray((len(Ys), len(Xs)), dtype=float)
124     xyz = {}  # dict of z values keyed by (x,y) tuples
125     for i in range(XYZ.shape[0]):
126         xyz[(XYZ[i,0], XYZ[i,1])] = XYZ[i,2]
127     for i,x in enumerate(Xs):
128         for j,y in enumerate(Ys):
129             Z[j,i] = xyz[x,y]
130     # add dummy row/column for pcolor
131     dx = Xs[-1] - Xs[-2]
132     dy = Ys[-1] - Ys[-2]
133     Xs = numpy.append(Xs, Xs[-1] + dx)
134     Ys = numpy.append(Ys, Ys[-1] + dy)
135     X,Y = pylab.meshgrid(Xs, Ys)
136     return (X,Y,Z)
137
138 def plot(X, Y, Z, full=False, title=None, contours=None, interpolation=None,
139          cmap=None):
140     """Plot Z over the mesh X, Y.
141
142     >>> X, Y = pylab.meshgrid(range(6), range(2))
143     >>> Z = X[:-1,:-1]**2 + Y[:-1,:-1]
144     >>> plot(X, Y, Z)  # doctest: +ELLIPSIS
145     <matplotlib.figure.Figure object at 0x...>
146     """
147     X_min = X[0,0]
148     X_max = X[-1,-1]
149     Y_min = Y[0,0]
150     Y_max = Y[-1,-1]
151
152     fig = pylab.figure()
153     if full:
154         axes = fig.add_axes([0, 0, 1, 1])
155     else:
156         axes = fig.add_subplot(1, 1, 1)
157         if title:
158             axes.set_title(title)
159     axes.set_axis_off()
160
161     if contours:
162         cset = axes.contour(X[:-1,:-1], Y[:-1,:-1], Z, contours, cmap=cmap)
163         # [:-1,:-1] to strip dummy last row & column from X&Y.
164         axes.clabel(cset, inline=1, fmt='%1.1f', fontsize=10)
165     else:
166         # pcolor() is much slower than imshow.
167         #plot = axes.pcolor(X, Y, Z, cmap=cmap, edgecolors='none')
168         #axes.autoscale_view(tight=True)
169         plot = axes.imshow(Z, aspect='auto', interpolation=interpolation,
170                            origin='lower', cmap=cmap,
171                            extent=(X_min, X_max, Y_min, Y_max))
172         if not full:
173             fig.colorbar(plot)
174     return fig
175
176
177 def get_possible_interpolations():
178     try:  # Matplotlib v1.0.1
179         return sorted(matplotlib.image.AxesImage._interpd.keys())
180     except AttributeError:
181         try:  # Matplotlib v0.91.2
182             return sorted(matplotlib.image.AxesImage(None)._interpd.keys())
183         except AttributeError:
184             # give up ;)
185             pass
186     return ['nearest']
187
188 def test():
189     import doctest
190     results = doctest.testmod()
191     return results.failed
192
193
194 def main(argv=None):
195     """Read in data and plot it.
196
197     >>> from tempfile import NamedTemporaryFile
198     >>> i = NamedTemporaryFile(prefix='tmp-input', suffix='.dat')
199     >>> i.write('\\n'.join([str(x) for x in range(10)])+'\\n')
200     >>> i.flush()
201     >>> o = NamedTemporaryFile(prefix='tmp-output', suffix='.png')
202     >>> main(['-i', i.name, '-s', '5,2', '-o', o.name, '-m', 'binary'])
203     Plot_image
204     Title:             Some like it hot
205     Image size:        5 2
206     False color
207     X range:           0 4
208     X range:           0 1
209     Z range:           0.0 9.0
210     >>> img = o.read()
211     >>> img.startswith('\\x89PNG')
212     True
213     >>> i.close()
214     >>> o.close()
215     """
216     if argv == None:
217         argv = sys.argv[1:]
218
219     usage = '%prog [options]'
220     epilog = _DOC
221     p = optparse.OptionParser(usage=usage, epilog=epilog)
222     p.format_epilog = lambda formatter: epilog+'\n'
223
224     p.add_option('-i', '--input', dest='input', metavar='FILE',
225                  help='If set, read data from FILE rather than stdin.')
226     p.add_option('-o', '--output', dest='output', metavar='FILE',
227                  help=('If set, save the figure to FILE rather than '
228                        'displaying it immediately'))
229     p.add_option('-s', '--size', dest='size', default='%d,%d' % (16, 16),
230                  help='Data size (columns,rows; default: %default)')
231     p.add_option('-3', '--xyz', dest='xyz', default=False, action='store_true',
232                  help=('If set, read (x,y,z) tuples from the input data rather'
233                        'then reading `z` and calculating `x` and `y` from '
234                        '`--size`.'))
235     p.add_option('-c', '--contours', dest='contours', type='int',
236                  help=('Number of contour lines (if not set, draw false color '
237                        'instead of contour lines; default: %default)'))
238     p.add_option('-f', '--full-figure', dest='full', action='store_true',
239                  help=('Set axes to fill the figure (i.e. no title or color '
240                        'bar'))
241     p.add_option('-t', '--title', dest='title', default='Some like it hot',
242                  help='Title (%default)')
243     p.add_option('--test', dest='test', action='store_true',
244                  help='Run internal tests and exit.')
245     interpolations = get_possible_interpolations()
246     p.add_option('--interpolation', dest='interpolation', default='nearest',
247                  help=('Interpolation scheme (for false color images) from %s '
248                        '(%%default)') % ', '.join(interpolations))
249     maps=[m for m in pylab.cm.datad if not m.endswith("_r")]
250     maps.sort()
251     p.add_option('-m', '--color-map', dest='cmap', default='jet',
252                  help='Select color map from %s (%%default)' % ', '.join(maps))
253
254     options,args = p.parse_args(argv)
255
256     if options.test:
257         sys.exit(test())
258
259     nx,ny = [int(x) for x in options.size.split(',')]
260     try:
261         cmap = getattr(pylab.cm, options.cmap)
262     except AttributeError:
263         raise Exception('no color map named %s in %s'
264                         % (options.cmap, ', '.join(maps)))
265
266     print 'Plot_image'
267     print 'Title:            ', options.title
268     if not options.xyz:
269         print 'Image size:       ', nx, ny
270     if options.contours:
271         print '# countour lines: ', options.contours
272     else:
273         print 'False color'
274     if options.input:
275         fin = open(options.input, 'r')
276     else:
277         fin = sys.stdin
278
279     if options.xyz:
280         X,Y,Z = read_data_3d(fin)
281     else:
282         X,Y,Z = read_data_1d(fin, nx, ny)
283
284     if options.input:
285         fin.close()
286
287     Z_min = numpy.min(Z.flat)
288     Z_max = numpy.max(Z.flat)
289     print 'X range:          ', X[0,0], X[0,-2]
290     print 'X range:          ', Y[0,0], Y[-2,0]
291     print 'Z range:          ', Z_min, Z_max
292
293     fig = plot(X, Y, Z, full=options.full, title=options.title,
294                contours=options.contours, interpolation=options.interpolation,
295                cmap=cmap)
296
297     if options.output:
298         fig.savefig(options.output)
299     else:
300         pylab.show()
301
302
303 if __name__ == '__main__':
304     main()