c407c29547598bd92ebed27124a8c80f0f83cc13
[pyrisk.git] / pyrisk / graphics.py
1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """Various PyRisk class -> graphic renderers.
18
19 Creates SVG files by hand.  See the TemplateLibrary class for a
20 description of map specification format.
21 """
22
23 import operator
24 import os
25 import os.path
26
27 from .base import NameMixin, ID_CmpMixin
28
29
class Template (NameMixin):
    """Setup regions for a particular world.

    name: template name (handed to NameMixin).
    regions: list of Region instances making up the map.
    continent_colors: {continent name: SVG fill color}.  Defaults to a
      fresh empty dict per instance (a shared mutable default would leak
      colors between templates).
    """
    def __init__(self, name, regions, continent_colors=None):
        NameMixin.__init__(self, name)
        self.regions = regions
        if continent_colors is None:
            continent_colors = {}
        self.continent_colors = continent_colors
37
class TemplateLibrary (object):
    """Create Templates on demand from a directory of template data.

    Each template lives in its own subdirectory of template_dir, named
    after the lower-cased template name.  Files are dispatched on their
    extension (files without an extension are ignored):

    * ``<region>.reg``: region outline; the file stem is the region name.
    * ``<route>.rt``: a drawn route; ``<route>.vrt``: a virtual route,
      which links regions but is not drawn by WorldRenderer.
    * ``*.ccl``: continent colors, one ``name<TAB>color`` per line.

    Point-list files (.reg/.rt/.vrt) hold one ``x<TAB>y[<TAB>label]``
    point per line; y is negated on read (the renderer flips it back for
    SVG output).  A blank line ends one boundary and starts the next.
    Boundaries that match an already-read boundary (same points at the
    same absolute position) are shared between regions.  Point labels
    let route terminals be attached to the regions containing them.
    """
    def __init__(self, template_dir='share/templates/'):
        self.template_dir = os.path.abspath(os.path.expanduser(template_dir))

    def get(self, name):
        """Return the Template called name, or None if it has no data."""
        data = self._get_data(name)
        if data is None:
            return None
        region_pointlists, route_pointlists, continent_colors = data
        regions = self._generate_regions(region_pointlists, route_pointlists)
        return Template(name, regions, continent_colors)

    def _get_data(self, name):
        """Read the raw data files for the template called name.

        Returns (region_pointlists, route_pointlists, continent_colors),
        or None if the template directory cannot be listed.
        region_pointlists maps region name -> pointlist and
        route_pointlists maps route name -> (pointlist, is_virtual).
        """
        dirname = os.path.join(self.template_dir, name.lower())
        try:
            filenames = os.listdir(dirname)
        except (IOError, OSError):  # os.listdir raises OSError on Python 2
            return None
        region_pointlists = {}
        route_pointlists = {}
        continent_colors = {}  # templates without a .ccl file get no colors
        for filename in sorted(filenames):  # deterministic processing order
            if '.' not in filename:
                continue  # e.g. README: not template data
            stem, extension = filename.rsplit('.', 1)
            path = os.path.join(dirname, filename)
            if extension == 'reg':
                region_pointlists[stem] = self._read_pointlist(path)
            elif extension in ('rt', 'vrt'):
                route_pointlists[stem] = (self._read_pointlist(path),
                                          extension == 'vrt')
            elif extension == 'ccl':
                continent_colors = self._read_colors(path)
        return (region_pointlists, route_pointlists, continent_colors)

    def _read_pointlist(self, filename):
        """Return [(x, -y, label-or-None), ...] with None for blank lines."""
        pointlist = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    pointlist.append(None)  # boundary separator
                    continue
                fields = line.split('\t')
                x = int(fields[0])
                y = -int(fields[1])
                if len(fields) == 3:
                    label = fields[2].strip()
                else:
                    label = None
                pointlist.append((x, y, label))
        return pointlist

    def _read_colors(self, filename):
        """Return {continent name: color} from tab-separated lines."""
        colors = {}
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    continue
                name, color = line.split('\t')
                colors[name] = color
        return colors

    def _generate_regions(self, region_pointlists, route_pointlists):
        """Build Regions from pointlists and attach routes to them.

        A route is attached to every region whose outline contains a
        point named like one of the route's terminals.
        """
        regions = []
        all_boundaries = []
        for name in sorted(region_pointlists):
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, region_pointlists[name])
            regions.append(Region(name, boundaries, head_to_tail))
        for name in sorted(route_pointlists):
            pointlist, virtual = route_pointlists[name]
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            assert len(boundaries) == 1, boundaries
            route = boundaries[0]
            route.virtual = virtual
            # Flag by position, not value: a route returning to its start
            # would otherwise mark its tail as a head.
            for terminal, is_head in [(route[0], True), (route[-1], False)]:
                for r in regions:
                    for point in r.outline:
                        if hasattr(point, 'name') \
                                and point.name == terminal.name:
                            r.routes.append(route)
                            r.route_head_to_tail.append(is_head)
                            route.regions.append(r)
                            break  # attach once per terminal per region
        for r in regions:
            r.locate_routes()
        match_counts = [b.match_count for b in all_boundaries]
        if match_counts:  # min()/max() fail on an empty template
            # Every boundary should be either unshared (0) or shared by
            # exactly two pointlists (1 extra match).
            assert min(match_counts) in [0, 1], set(match_counts)
            assert max(match_counts) == 1, set(match_counts)
        return regions

    def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
        """Split pointlist on None separators into shared boundaries.

        Returns (boundaries, head_to_tail), where head_to_tail[i] is False
        when boundary i was stored reversed relative to pointlist.
        Empty groups (consecutive or trailing blank lines) are skipped.
        """
        boundaries = []
        head_to_tail = []
        group = []
        for point in list(pointlist) + [None]:  # final None flushes the tail
            if point is not None:
                group.append(point)
                continue
            if not group:
                continue
            boundary, reverse = self._analyze(group)
            boundaries.append(self._insert_boundary(all_boundaries, boundary))
            head_to_tail.append(not reverse)
            group = []
        return boundaries, head_to_tail

    def _analyze(self, boundary_points):
        """Return (Boundary, reversed) for a run of (x, y, label) points.

        Points are stored in a canonical direction (starting from the
        smaller endpoint) relative to their start, so a boundary read
        from two neighboring regions produces identical point lists.
        """
        points = [self._vbp(b) for b in boundary_points]
        reverse = points[0] > points[-1]
        if reverse:
            points.reverse()
        start, stop = points[0], points[-1]
        boundary = Boundary([p-start for p in points])
        for bp, p in zip(boundary, points):
            bp.name = p.name  # preserve point names
        boundary.name = '(%s) -> (%s)' % (start.name, stop.name)
        boundary.real_pos = start
        return (boundary, reverse)

    def _vbp(self, boundary_point):
        """Convert an (x, y, label) point into a named Vector."""
        v = Vector((boundary_point[0], boundary_point[1]))
        v.name = boundary_point[2]
        return v

    def _insert_boundary(self, all_boundaries, new):
        """Return an existing boundary equal to new, or register new.

        Matching boundaries get their match_count incremented.
        """
        if new in all_boundaries:
            return new
        for b in all_boundaries:
            if len(b) == len(new) and b.real_pos == new.real_pos \
                    and all(bp == np for bp, np in zip(b, new)):
                b.match_count += 1
                return b
        all_boundaries.append(new)
        new.match_count = 0
        return new
180
# Shared library used by WorldRenderer when no template_lib is supplied.
# NOTE(review): the default path is relative, so it is resolved against the
# current working directory when this module is imported.
TEMPLATE_LIBRARY = TemplateLibrary(template_dir='share/templates/')
182
183
class Vector (tuple):
    """Tuple with element-wise arithmetic that carries a point label.

    A Vector may have an optional ``name`` attribute.  Arithmetic results
    inherit the left operand's name, unless that is missing or None, in
    which case the right operand's name (if any) is used.

    >>> v = Vector
    >>> a = v((0, 0))
    >>> b = v((1, 1))
    >>> c = v((2, 3))
    >>> a+b
    (1, 1)
    >>> a+b+c
    (3, 4)
    >>> b-c
    (-1, -2)
    >>> -c
    (-2, -3)
    >>> a < b
    True
    >>> c > b
    True
    """
    def _set_name(self, new, other=None):
        self_labeled = hasattr(self, 'name')
        other_labeled = hasattr(other, 'name')
        if self_labeled and not (self.name is None and other_labeled):
            new.name = self.name
        elif other_labeled:
            new.name = other.name

    def _elementwise(self, other, op):
        # Shared body for the binary operators: length check, combine
        # component by component, then propagate the label.
        if len(self) != len(other):
            raise ValueError('length missmatch %s, %s' % (self, other))
        result = self.__class__(op(mine, theirs)
                                for mine, theirs in zip(self, other))
        self._set_name(result, other)
        return result

    def __neg__(self):
        result = self.__class__(operator.neg(component) for component in self)
        self._set_name(result)
        return result

    def __add__(self, other):
        return self._elementwise(other, operator.add)

    def __sub__(self, other):
        return self._elementwise(other, operator.sub)

    def __mul__(self, other):
        return self._elementwise(other, operator.mul)
235
def nameless(vector):
    """Return a copy of vector that carries no ``name`` attribute.

    Vector arithmetic keeps the left operand's name unless it is missing
    or None, so stripping the name from the left operand lets a sum or
    difference take the name of the *other* vector.
    """
    unlabeled = Vector(tuple(vector))
    return unlabeled
243
class Boundary (ID_CmpMixin, list):
    """Ordered run of points shared between neighboring regions.

    Points are stored as Vectors relative to the first point, so the
    first entry must be (0,0).  The per-axis extents are cached as
    x_min/x_max/y_min/y_max, and ``regions`` collects the objects
    (regions) that reference this boundary.
    """
    def __init__(self, points):
        list.__init__(self)
        ID_CmpMixin.__init__(self)
        self.extend(Vector(p) for p in points)
        self.regions = []
        assert self[0] == (0,0), self
        xs = [p[0] for p in self]
        ys = [p[1] for p in self]
        self.x_min, self.x_max = min(xs), max(xs)
        self.y_min, self.y_max = min(ys), max(ys)
261
class Route (ID_CmpMixin):
    """Connect non-adjacent Regions.

    Holds the boundary (point path) that the connection follows.
    """
    def __init__(self, boundary):
        ID_CmpMixin.__init__(self)
        self.boundary = boundary  # path drawn between the connected regions
268
class Region (NameMixin, ID_CmpMixin, list):
    """Contains a list of boundaries and a label.

    Regions can be Territories, sections of ocean, etc.

    The boundaries are walked in order to build a closed outline; the
    parallel list head_to_tail says whether each boundary is traversed
    in its stored direction (True) or reversed (False).

    >>> r = Region('Earth',
    ...            [Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)]),
    ...             Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)])],
    ...            [True, True, False, False],
    ...            label_offset=(0.5, 0.5))
    >>> r.outline
    [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    >>> r.label_offset
    (0.5, 0.5)
    """
    def __init__(self, name, boundaries, head_to_tail, routes=None,
                 route_head_to_tail=None, label_offset=(0,0)):
        NameMixin.__init__(self, name)
        list.__init__(self, boundaries)
        ID_CmpMixin.__init__(self)
        for boundary in self:
            boundary.regions.append(self)  # back-link used for map layout
        self.head_to_tail = head_to_tail
        if routes is None:
            assert route_head_to_tail is None
            routes = []
            route_head_to_tail = []
        self.routes = routes
        self.route_head_to_tail = route_head_to_tail
        self.route_starts = [] # set by .locate_routes
        self.label_offset = Vector(label_offset)
        self.generate_outline() # sets .outline, .starts
        self.x_min = min([b.x_min+s[0] for b,s in zip(self, self.starts)])
        self.x_max = max([b.x_max+s[0] for b,s in zip(self, self.starts)])
        self.y_min = min([b.y_min+s[1] for b,s in zip(self, self.starts)])
        self.y_max = max([b.y_max+s[1] for b,s in zip(self, self.starts)])

    def generate_outline(self):
        """Set .outline and .starts from the boundaries.

        .outline is the list of points (relative to the region's first
        point) surrounding the region; its last point equals its first.
        .starts[i] is the position of boundary i's first stored point.

        The main issue here is determining the proper border
        orientation for a CCW outline, which we do via a user-supplied
        list head_to_tail.
        """
        self.starts = []
        points = [Vector((0,0))]
        for boundary, htt in zip(self, self.head_to_tail):
            pos = points[-1]
            if htt:
                assert boundary[0] == (0,0), boundary
                new = boundary[1:]
            else:
                # Walking backwards: the boundary's stored start lies at
                # (current position - its last relative point).
                pos -= nameless(boundary[-1])
                new = reversed(boundary[:-1])
            pos = nameless(pos)
            self.starts.append(pos)
            for p in new:
                points.append(pos+p)
        assert points[-1] == points[0], '%s: %s' % (self, points)
        self.outline = points

    def locate_routes(self):
        """Set .route_starts, one offset per entry in .routes.

        Each route is anchored at the outline point named like the route
        terminal touching this region.  Raises ValueError if no such
        point exists (silently skipping it would misalign route_starts
        with routes).
        """
        self.route_starts = []
        for route, htt in zip(self.routes, self.route_head_to_tail):
            anchor = route[0] if htt else route[-1]
            for point in self.outline:
                if hasattr(point, 'name') and point.name == anchor.name:
                    self.route_starts.append(point-nameless(anchor))
                    break
            else:
                raise ValueError('route anchor %r not on outline of %s'
                                 % (anchor.name, self))
339
class WorldRenderer (object):
    """Render a world to an SVG document using a map Template.

    template_lib: object with a get(name) method returning a Template or
      None; defaults to the module-level TEMPLATE_LIBRARY.
    line_width: stroke width (SVG user units) for outlines and routes.
    buf: padding (user units) added around the regions' bounding box.
    dpcm: user units per centimeter, used for the physical SVG size.
    """
    def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
        if template_lib is None:
            template_lib = TEMPLATE_LIBRARY
        self.template_lib = template_lib
        self.buf = buf
        self.line_width = line_width
        self.line_color = 'grey'
        self.route_color = 'black'
        self.dpcm = dpcm

    def render(self, world):
        """Return an SVG string for world.

        The template is looked up by world.name; if the library returns
        None, fall back to _auto_template (not implemented yet).
        """
        template = self.template_lib.get(world.name)
        if template is None:
            template = self._auto_template(world)
        return self.render_template(world, template)

    def render_template(self, world, template):
        """Return an SVG string drawing template's regions for world.

        Each region is filled with its territory's continent color, and
        every non-virtual route is drawn once.
        """
        region_pos, width, height = self._locate(template)
        lines = [
            '<?xml version="1.0" standalone="no"?>',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"',
            '  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">',
            '<svg width="%.1fcm" height="%.1fcm" viewBox="0 0 %d %d"'
            % (float(width)/self.dpcm, float(height)/self.dpcm,
               width, height),
            '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
            '<desc>PyRisk world: %s</desc>' % template,
            ]
        drawn_rts = set()  # ids of routes already emitted (shared by regions)
        for r in template.regions:
            t = self._matching_territory(world, r)
            c_col = template.continent_colors[t.continent.name]
            r_pos = region_pos[id(r)]
            lines.extend([
                    '<!-- %s -->' % r,
                    '<polygon fill="%s" stroke="%s" stroke-width="%d"'
                    % (c_col, self.line_color, self.line_width),
                    '         points="%s" />'
                    % self._svg_points(r_pos, r.outline[:-1], height),
                    ])
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) in drawn_rts or rt.virtual:
                    continue
                drawn_rts.add(id(rt))
                lines.extend([
                        '<polyline stroke="%s" stroke-width="%d"'
                        % (self.route_color, self.line_width),
                        '         points="%s" />'
                        % self._svg_points(r_pos+rt_start, rt, height),
                        ])
        # NOTE(review): these look like debugging markers for the bounding
        # box corners -- confirm they are wanted in the final output.
        lines.extend([
                '<circle fill="black" cx="0" cy="0" r="20" />',
                '<circle fill="green" cx="%d" cy="%d" r="20" />'
                 % (width, height)
                ])
        lines.extend(['</svg>', ''])
        return '\n'.join(lines)

    def _svg_points(self, offset, points, height):
        """Format offset+p for each p as an SVG "x,y x,y ..." string.

        SVG y values increase downward, so y is mirrored and then
        shifted back into the [0, height] bounding box.
        """
        return ' '.join(['%d,%d' % ((offset+p)*(1,-1) + (0,height))
                         for p in points])

    def _locate(self, template):
        """Place every region of template in absolute coordinates.

        Positions propagate outward from the first boundary of the first
        region: a region sharing an already-placed boundary (or route)
        is positioned relative to it.  Returns (region_pos, width,
        height) where region_pos maps id(region) to a Vector offset,
        shifted so the padded bounding box starts at (0, 0).

        Raises KeyError(region name) for regions that could not be
        reached from the first one.
        """
        region_pos = {}    # {id(region): absolute position, ...}
        boundary_pos = {}  # {id(boundary): absolute position, ...}
        route_pos = {}     # {id(route): absolute position, ...}
        b1 = template.regions[0][0]
        boundary_pos[id(b1)] = Vector((0,0)) # fix the first boundary point
        stack = list(b1.regions)
        while stack:
            r = stack.pop()
            if id(r) in region_pos:
                continue # skip duplicate entries
            r_start = None
            for b, rel_b_start in zip(r, r.starts):
                if id(b) in boundary_pos:
                    r_start = boundary_pos[id(b)] - rel_b_start
                    break # found an anchor
            if r_start is None:
                for rt, rel_rt_start in zip(r.routes, r.route_starts):
                    if id(rt) in route_pos:
                        r_start = route_pos[id(rt)] - rel_rt_start
                        break # found an anchor
            region_pos[id(r)] = r_start
            for b, rel_b_start in zip(r, r.starts):
                if id(b) not in boundary_pos:
                    boundary_pos[id(b)] = r_start + rel_b_start
                    stack.extend(b.regions)
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) not in route_pos:
                    route_pos[id(rt)] = r_start + rt_start
                    stack.extend(rt.regions)
        for r in template.regions:
            if id(r) not in region_pos:
                raise KeyError(r.name)
        x_min = min([r.x_min + region_pos[id(r)][0]
                     for r in template.regions]) - self.buf
        x_max = max([r.x_max + region_pos[id(r)][0]
                     for r in template.regions]) + self.buf
        y_min = min([r.y_min + region_pos[id(r)][1]
                     for r in template.regions]) - self.buf
        y_max = max([r.y_max + region_pos[id(r)][1]
                     for r in template.regions]) + self.buf
        origin = Vector((x_min, y_min))
        for key, value in list(region_pos.items()):
            region_pos[key] = value - origin
        return (region_pos, x_max-x_min, y_max-y_min)

    def _matching_territory(self, world, region):
        """Return the world territory drawn by region.

        Tries region.name first, then the names of regions linked to it
        by virtual routes, returning the first successful lookup.
        """
        candidates = [region]
        for rt in region.routes:
            if rt.virtual:
                candidates.extend(rt.regions)
        t = None
        for candidate in candidates:
            try:
                t = world.territory_by_name(candidate.name)
            except KeyError:
                continue
            break
        assert t is not None, 'No territory in %s associated with region %s' \
            % (world, region)
        return t

    def _auto_template(self, world):
        raise NotImplementedError
467
def test():
    """Run this module's doctests; return the number of failures."""
    import doctest
    import sys
    results = doctest.testmod(sys.modules[__name__])
    return results[0]
472
def render_earth(filename=None):
    """Render the stock Earth world as SVG.

    With filename None (the default) the SVG is printed to stdout;
    otherwise it is written to filename.
    """
    from .base import generate_earth
    svg = WorldRenderer().render(generate_earth())
    if filename is None:
        print(svg)  # single-argument print() behaves the same on Python 2 and 3
    else:
        with open(filename, 'w') as f:
            f.write(svg)