f171e0d24b8078c04bd7cc1f99479e63e8fe6e4f
[pyrisk.git] / pyrisk / graphics.py
1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """Various PyRisk class -> graphic renderers.
18
19 Creates SVG files by hand.  See the TemplateLibrary class for a
20 description of map specification format.
21 """
22
23 import operator
24 import os
25 import os.path
26
27 from .base import ID_CmpMixin
28
29
class Template (object):
    """A named set of regions laid out for one world.

    name: label for the template (rendered into the SVG <desc>).
    regions: the Region instances making up the map.
    """
    def __init__(self, name, regions):
        self.regions = regions
        self.name = name
36
class TemplateLibrary (object):
    """Create Templates on demand from a directory of template data.

    Each world's template lives in ``<template_dir>/<name.lower()>/``.
    Files in that directory are dispatched on their extension:

    ``<region>.reg``
        Outline of the region named ``<region>``.
    ``<route>.rt``
        A route named ``<route>`` connecting regions.
    ``<route>.vrt``
        A virtual route (its boundary gets ``virtual = True``).

    Files with no extension or any other extension are ignored.

    Every data file lists one point per line as tab-separated integer
    ``x`` and ``y`` fields with an optional third ``label`` field.  The
    y value is negated as it is read.  A blank line ends one boundary
    and starts the next, so a region outline is a sequence of
    boundaries.  Boundaries shared between regions are detected by
    comparing their points, and labels are how route end points are
    matched with points on region outlines.

    TODO: document the required boundary orientation.
    """
    def __init__(self, template_dir='share/templates/'):
        # Resolved once, so later changes of working directory don't matter.
        self.template_dir = os.path.abspath(os.path.expanduser(template_dir))

    def get(self, name):
        """Return a Template for the world `name`, or None.

        None is returned when there is no readable template directory
        for `name`, so callers can fall back to another template source.
        """
        pointlists = self._get_pointlists(name)
        if pointlists is None:
            return None
        region_pointlists,route_pointlists = pointlists
        return self._generate_template(
            region_pointlists, route_pointlists, name=name)

    def _get_pointlists(self, name):
        """Read all template files for `name`.

        Returns ``(region_pointlists, route_pointlists)`` where
        ``region_pointlists`` maps region name -> pointlist and
        ``route_pointlists`` maps route name -> (pointlist, virtual).
        Returns None if the template directory cannot be listed.
        """
        dirname = os.path.join(self.template_dir, name.lower())
        try:
            files = os.listdir(dirname)
        except (IOError, OSError):  # Py2 os.listdir raises OSError
            return None
        region_pointlists = {}
        route_pointlists = {}
        # Sorted so the resulting region order is deterministic.
        for filename in sorted(files):
            if '.' not in filename:
                continue  # e.g. README: not template data
            path = os.path.join(dirname, filename)
            base,extension = filename.rsplit('.', 1)
            if extension == 'reg':
                region_pointlists[base] = self._read_file(path)
            elif extension in ['rt', 'vrt']:
                route_pointlists[base] = (self._read_file(path),
                                          extension == 'vrt')
        return (region_pointlists, route_pointlists)

    def _read_file(self, filename):
        """Parse one template file into a pointlist.

        Each entry is an ``(x, -y, label)`` tuple (``label`` is None
        unless a third tab-separated field is present), or None for a
        blank line, which marks a boundary break.
        """
        pointlist = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    pointlist.append(None)
                    continue
                fields = line.split('\t')
                x = int(fields[0])
                y = -int(fields[1])
                if len(fields) == 3:
                    label = fields[2].strip()
                else:
                    label = None
                pointlist.append((x,y,label))
        return pointlist

    def _generate_template(self, region_pointlists, route_pointlists,
                           name='template'):
        """Build a Template from parsed region and route pointlists.

        region_pointlists: {region_name: pointlist, ...}
        route_pointlists: {route_name: (pointlist, virtual), ...}
        name: name given to the resulting Template.

        Raises ValueError if a route end point has no label, since
        labels are the only way to attach routes to regions.
        """
        regions = []
        all_boundaries = []  # shared pool so neighbours reuse boundaries
        for region_name,pointlist in region_pointlists.items():
            boundaries,head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            regions.append(Region(region_name, boundaries, head_to_tail))
        for route_name,v_pointlist in route_pointlists.items():
            pointlist,virtual = v_pointlist
            boundaries,head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            assert len(boundaries) == 1, boundaries
            route = boundaries[0]
            route.virtual = virtual
            # Attach the route to every region whose outline contains a
            # point labelled like one of the route's end points.
            for terminal in [route[0], route[-1]]:
                if terminal.name is None:
                    # An unlabelled terminal would match every unlabelled
                    # outline point.
                    raise ValueError(
                        'route %s has an unlabeled terminal' % route_name)
                for r in regions:
                    for point in r.outline:
                        if getattr(point, 'name', None) == terminal.name:
                            r.routes.append(route)
                            # Identity, not value equality: a closed route
                            # has equal first and last points.
                            r.route_head_to_tail.append(
                                terminal is route[0])
                            route.regions.append(r)
                            break  # attach once per region and terminal
        for r in regions:
            r.locate_routes()
        # Each boundary may be found at most twice (match_count 0 or 1),
        # and at least one boundary must be shared.
        match_counts = [b.match_count for b in all_boundaries]
        if match_counts:
            assert min(match_counts) in [0, 1], set(match_counts)
            assert max(match_counts) == 1, set(match_counts)
        return Template(name, regions)

    def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
        """Split `pointlist` at None entries into canonical Boundaries.

        Returns ``(boundaries, head_to_tail)``.  ``head_to_tail[i]`` is
        False when boundary ``i`` was stored reversed relative to the
        order in the file.  Empty chunks (repeated, leading, or trailing
        blank lines) are skipped.
        """
        chunks = [[]]
        for point in pointlist:
            if point is None:
                chunks.append([])
            else:
                chunks[-1].append(point)
        boundaries = []
        head_to_tail = []
        for b_points in chunks:
            if not b_points:
                continue
            boundary,reverse = self._analyze(b_points)
            boundary = self._insert_boundary(all_boundaries, boundary)
            boundaries.append(boundary)
            head_to_tail.append(not reverse)
        return boundaries, head_to_tail

    def _analyze(self, boundary_points):
        """Convert raw (x, y, label) points into a normalized Boundary.

        The boundary is oriented to start at the smaller of its two end
        points (tuple ordering).  That way the same boundary read in
        either direction produces identical points for
        _insert_boundary.  Returns ``(boundary, reverse)``, where
        ``reverse`` is True when the point order was flipped.
        """
        points = [self._vbp(b) for b in boundary_points]
        start,stop = points[0], points[-1]
        reverse = start > stop
        if reverse:
            points.reverse()
            start,stop = stop, start
        boundary = Boundary([p-start for p in points])
        for bp,p in zip(boundary, points):
            bp.name = p.name  # Boundary() re-wraps points, dropping names
        boundary.name = '(%s) -> (%s)' % (start.name, stop.name)
        boundary.real_pos = start  # absolute position of boundary[0]
        return (boundary, reverse)

    def _vbp(self, boundary_point):
        """Wrap an (x, y, label) triple as a Vector named by its label."""
        v = Vector((boundary_point[0], boundary_point[1]))
        v.name = boundary_point[2]
        return v

    def _insert_boundary(self, all_boundaries, new):
        """Return the pooled Boundary matching `new`, adding `new` if unseen.

        A match has the same length, absolute start (``real_pos``), and
        relative points.  Each match bumps the pooled boundary's
        ``match_count``.
        """
        if new in all_boundaries:
            return new
        for b in all_boundaries:
            if (len(b) == len(new) and b.real_pos == new.real_pos
                    and all(p1 == p2 for p1,p2 in zip(b, new))):
                b.match_count += 1
                return b
        all_boundaries.append(new)
        new.match_count = 0
        return new
166
# Shared default library: 'share/templates/' resolved against the working
# directory at import time.  WorldRenderer falls back to it when no
# template_lib is given.
TEMPLATE_LIBRARY = TemplateLibrary()
168
169
class Vector (tuple):
    """Tuple with element-wise arithmetic that carries an optional name.

    Results keep the left operand's ``name`` unless it is missing or
    None, in which case the right operand's ``name`` (if any) is used.
    Comparisons are ordinary tuple comparisons.

    >>> v = Vector
    >>> a = v((0, 0))
    >>> b = v((1, 1))
    >>> c = v((2, 3))
    >>> a+b
    (1, 1)
    >>> a+b+c
    (3, 4)
    >>> b-c
    (-1, -2)
    >>> -c
    (-2, -3)
    >>> a < b
    True
    >>> c > b
    True
    >>> b.name = 'b'
    >>> (a+b).name
    'b'
    """
    def _set_name(self, new, other=None):
        """Give `new` a name taken from `self` or, failing that, `other`."""
        if not hasattr(self, 'name'):
            if hasattr(other, 'name'):
                new.name = other.name
        elif self.name is None and hasattr(other, 'name'):
            new.name = other.name
        else:
            new.name = self.name

    def _elementwise(self, op, other):
        """Combine `self` and `other` pairwise with the binary `op`."""
        if len(self) != len(other):
            raise ValueError('length missmatch %s, %s' % (self, other))
        result = self.__class__(op(x, y) for x, y in zip(self, other))
        self._set_name(result, other)
        return result

    def __neg__(self):
        result = self.__class__(-x for x in self)
        self._set_name(result)
        return result

    def __add__(self, other):
        return self._elementwise(operator.add, other)

    def __sub__(self, other):
        return self._elementwise(operator.sub, other)

    def __mul__(self, other):
        # Element-wise product, e.g. v*(1,-1) flips the y axis.
        return self._elementwise(operator.mul, other)
221
def nameless(vector):
    """Copy `vector` into a fresh Vector that carries no ``name``.

    Vector arithmetic prefers the left operand's name, so stripping it
    lets a sum or difference inherit the *other* operand's name.
    """
    return Vector(tuple(vector))
229
class Boundary (ID_CmpMixin, list):
    """An ordered run of points along a region edge or route.

    Points are stored as Vectors measured from the first point, so
    ``self[0]`` must be the origin ``(0,0)``.  The bounding box of the
    points is cached in ``x_min``/``x_max``/``y_min``/``y_max``.
    ``regions`` starts empty and collects the objects sharing this
    boundary.
    """
    def __init__(self, points):
        list.__init__(self)
        ID_CmpMixin.__init__(self)
        for point in points:
            self.append(Vector(point))
        self.regions = []
        assert self[0] == (0,0), self
        xs = [point[0] for point in self]
        ys = [point[1] for point in self]
        self.x_min = min(xs)
        self.x_max = max(xs)
        self.y_min = min(ys)
        self.y_max = max(ys)
247
class Region (ID_CmpMixin, list):
    """Contains a list of boundaries and a label.

    Regions can be Territories, sections of ocean, etc.

    >>> r = Region('Earth',
    ...            [Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)]),
    ...             Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)])],
    ...            [True, True, False, False],
    ...            label_offset=(0.5, 0.5))
    >>> r.outline
    [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    """
    def __init__(self, name, boundaries, head_to_tail, routes=None,
                 route_head_to_tail=None, label_offset=(0,0)):
        """Create a region from its ordered boundaries.

        name: region label.
        boundaries: Boundary instances tracing the region, in order.
        head_to_tail: one bool per boundary; true to walk that boundary
          from its first point to its last, false for the reverse.
        routes, route_head_to_tail: optional parallel lists of route
          Boundaries and whether each is anchored at its first point.
          Either both are None or both are given with equal length.
        label_offset: (x, y) offset, stored as a Vector.

        Raises ValueError if `routes` is given without a matching
        `route_head_to_tail`.
        """
        list.__init__(self, boundaries)
        ID_CmpMixin.__init__(self)
        for boundary in self:
            boundary.regions.append(self)  # back-reference for traversal
        self.head_to_tail = head_to_tail
        self.name = name
        if routes is None:
            assert route_head_to_tail is None, route_head_to_tail
            routes = []
            route_head_to_tail = []
        elif (route_head_to_tail is None
              or len(route_head_to_tail) != len(routes)):
            raise ValueError(
                '%s: need one route_head_to_tail entry per route' % name)
        self.routes = routes
        self.route_head_to_tail = route_head_to_tail
        self.route_starts = [] # set by .locate_routes
        self.label_offset = Vector(label_offset)
        self.generate_outline() # sets .outline, .starts
        self.x_min = min([b.x_min+s[0] for b,s in zip(self, self.starts)])
        self.x_max = max([b.x_max+s[0] for b,s in zip(self, self.starts)])
        self.y_min = min([b.y_min+s[1] for b,s in zip(self, self.starts)])
        self.y_max = max([b.y_max+s[1] for b,s in zip(self, self.starts)])

    def generate_outline(self):
        """Compute the region's closed outline and each boundary's offset.

        Walks the boundaries in order, following each one forwards when
        its head_to_tail entry is true and backwards otherwise, and
        chains them end to end from the origin.  Sets
        ``self.outline`` (points relative to the region origin; the last
        point repeats the first) and ``self.starts`` (the region-relative
        position of each boundary's first point).  Returns None.

        The user-supplied head_to_tail list is what determines the
        outline orientation (intended to be CCW).
        """
        self.starts = []
        points = [Vector((0,0))]
        for boundary,htt in zip(self, self.head_to_tail):
            pos = points[-1]
            if htt:
                assert boundary[0] == (0,0), boundary
                new = boundary[1:]
            else:
                # Walking backwards: the boundary's last point sits at
                # pos, so its first point is at pos - boundary[-1].
                pos -= nameless(boundary[-1])
                new = reversed(boundary[:-1])
            pos = nameless(pos)
            self.starts.append(pos)
            for p in new:
                points.append(pos+p)
        assert points[-1] == points[0], '%s: %s' % (self.name, points)
        self.outline = points

    def locate_routes(self):
        """Set self.route_starts, one offset per entry in self.routes.

        Each start is the region-relative position of the route's first
        point.  It is found by matching the route's anchor end point
        (chosen by route_head_to_tail) to the outline point with the
        same name.

        Raises ValueError if an anchor has no matching outline point.
        Skipping the route instead would leave route_starts shorter than
        routes, silently misaligning code that zips the two lists.
        """
        self.route_starts = []
        for route,htt in zip(self.routes, self.route_head_to_tail):
            if htt:
                anchor = route[0]
            else:
                anchor = route[-1]
            for point in self.outline:
                if hasattr(point, 'name') and point.name == anchor.name:
                    self.route_starts.append(point-nameless(anchor))
                    break
            else:
                raise ValueError(
                    '%s: no outline point named %r for route %s'
                    % (self.name, getattr(anchor, 'name', None),
                       getattr(route, 'name', route)))
318
class Route (ID_CmpMixin):
    """A link between two Regions that do not share a border.

    boundary: the Boundary tracing the route's path, stored unchanged.
    """
    def __init__(self, boundary):
        ID_CmpMixin.__init__(self)
        self.boundary = boundary
325
326
class WorldRenderer (object):
    """Render a world's Template as a standalone SVG document.

    template_lib: object whose ``get(name)`` returns a Template or None
      (defaults to the module-level TEMPLATE_LIBRARY).
    line_width: stroke width for region outlines and routes.
    buf: padding added around the map's bounding box.
    dpcm: drawing units per centimetre, used for the SVG's physical size.
    """
    def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
        self.template_lib = template_lib
        if self.template_lib is None:
            self.template_lib = TEMPLATE_LIBRARY
        self.buf = buf
        self.line_width = line_width
        self.line_color = 'black'
        self.dpcm = dpcm

    def render(self, world):
        """Return SVG text for `world` (an object with a ``name``)."""
        template = self.template_lib.get(world.name)
        if template is None:
            template = self._auto_template(world)
        return self.render_template(world, template)

    def render_template(self, world, template):
        """Return SVG text drawing `template`'s regions and routes.

        `world` is not used here.  Virtual routes are not drawn, and a
        route shared by several regions is drawn only once.
        """
        from xml.sax.saxutils import escape
        region_pos,width,height = self._locate(template)
        lines = [
            '<?xml version="1.0" standalone="no"?>',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"',
            '  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">',
            '<svg width="%.1fcm" height="%.1fcm" viewBox="0 0 %d %d"'
            % (float(width)/self.dpcm, float(height)/self.dpcm,
               width, height),
            '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
            # Names come from template data; escape &, <, > for valid XML.
            '<desc>PyRisk world: %s</desc>' % escape('%s' % (template.name,)),
            ]
        drawn_rts = set()  # ids of routes already emitted
        for r in template.regions:
            r_pos = region_pos[id(r)]
            # outline[-1] repeats outline[0]; SVG polygons close themselves.
            lines.extend([
                    '<!-- %s -->' % r.name,
                    '<polygon fill="red" stroke="%s" stroke-width="%d"'
                    % (self.line_color, self.line_width),
                    '         points="%s" />'
                    % self._svg_points(r_pos, r.outline[:-1], height),
                    ])
            for rt,rt_start in zip(r.routes, r.route_starts):
                if id(rt) in drawn_rts or rt.virtual:
                    continue
                drawn_rts.add(id(rt))
                lines.extend([
                        '<polyline stroke="%s" stroke-width="%d"'
                        % (self.line_color, self.line_width),
                        '         points="%s" />'
                        % self._svg_points(r_pos+rt_start, rt, height),
                        ])
        # NOTE(review): corner markers at (0,0) and (width,height) look
        # like debugging aids -- confirm they belong in the output.
        lines.extend([
                '<circle fill="black" cx="0" cy="0" r="20" />',
                '<circle fill="green" cx="%d" cy="%d" r="20" />'
                 % (width, height)
                ])
        lines.extend(['</svg>', ''])
        return '\n'.join(lines)

    def _svg_points(self, offset, points, height):
        """Format `points`, shifted by `offset`, as an SVG points string.

        Map y values increase upwards while SVG y values increase
        downwards, so y is negated and then shifted back into the
        viewBox by `height`.
        """
        return ' '.join(['%d,%d' % ((offset+p)*(1,-1)+(0,height))
                         for p in points])

    def _locate(self, template):
        """Place every region of `template` in one coordinate frame.

        Starting from the first boundary of the first region, positions
        propagate to neighbouring regions through shared boundaries and
        through routes.  Returns ``(region_pos, width, height)``.
        ``region_pos`` maps ``id(region)`` to the region origin, shifted
        so the bounding box (padded by ``self.buf``) starts at (0, 0).

        Raises KeyError(region.name) for a region that cannot be
        reached from the first one.
        """
        region_pos = {} # {id: absolute position, ...}
        boundary_pos = {} # {id: absolute position, ...}
        route_pos = {} # {id: absolute position, ...}
        b1 = template.regions[0][0]
        boundary_pos[id(b1)] = Vector((0,0)) # fix the first boundary point
        stack = list(b1.regions)
        while stack:
            r = stack.pop()
            if id(r) in region_pos:
                continue # skip duplicate entries
            r_start = None
            for b,rel_b_start in zip(r, r.starts):
                if id(b) in boundary_pos:
                    r_start = boundary_pos[id(b)] - rel_b_start
                    break # found an anchor
            if r_start is None:
                for rt,rel_rt_start in zip(r.routes, r.route_starts):
                    if id(rt) in route_pos:
                        r_start = route_pos[id(rt)] - rel_rt_start
                        break # found an anchor
            if r_start is None:
                # Regions are only queued via an already-placed boundary
                # or route, so this indicates inconsistent template data.
                raise ValueError('%s: no placed boundary or route' % r.name)
            region_pos[id(r)] = r_start
            for b,rel_b_start in zip(r, r.starts):
                if id(b) not in boundary_pos:
                    boundary_pos[id(b)] = r_start + rel_b_start
                    stack.extend(b.regions)
            for rt,rt_start in zip(r.routes, r.route_starts):
                if id(rt) not in route_pos:
                    route_pos[id(rt)] = r_start + rt_start
                    stack.extend(rt.regions)
        for r in template.regions:
            if id(r) not in region_pos:
                raise KeyError(r.name)
        x_min = min([r.x_min + region_pos[id(r)][0]
                     for r in template.regions]) - self.buf
        x_max = max([r.x_max + region_pos[id(r)][0]
                     for r in template.regions]) + self.buf
        y_min = min([r.y_min + region_pos[id(r)][1]
                     for r in template.regions]) - self.buf
        y_max = max([r.y_max + region_pos[id(r)][1]
                     for r in template.regions]) + self.buf
        for key,value in region_pos.items():
            region_pos[key] = value - Vector((x_min, y_min))
        return (region_pos, x_max-x_min, y_max-y_min)

    def _auto_template(self, world):
        """Generate a template for worlds without template data (TODO)."""
        raise NotImplementedError
435
def test():
    """Run this module's doctests and return the number of failures."""
    import doctest
    import sys
    module = sys.modules[__name__]
    results = doctest.testmod(module)
    return results[0]
440
def render_earth(filename=None):
    """Render the default Earth world as SVG.

    Prints the SVG to stdout, or writes it to `filename` when given.
    """
    from .base import generate_earth
    r = WorldRenderer()
    svg = r.render(generate_earth())
    if filename is None:
        # Parenthesized single-argument print works on Python 2 and 3.
        print(svg)
    else:
        with open(filename, 'w') as f:
            f.write(svg)