38d58ca5a577a09c794e11d03b6a4446d03d5c55
[pyrisk.git] / pyrisk / graphics.py
1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """Various PyRisk class -> graphic renderers.
18
19 Creates SVG files by hand.  See the TemplateLibrary class for a
20 description of map specification format.
21 """
22
23 import operator
24 import os
25 import os.path
26
27 from .base import NameMixin, ID_CmpMixin
28
29
class Template (NameMixin):
    """Setup regions for a particular world.

    Parameters
    ----------
    name : str
        Template name, passed on to NameMixin.
    regions : list
        Region objects that make up the map.
    continent_colors : dict, optional
        Maps continent names to SVG fill colors (the renderer looks up
        ``continent_colors[territory.continent.name]``).  When omitted,
        each Template gets its own empty dict.
    """
    def __init__(self, name, regions, continent_colors=None):
        NameMixin.__init__(self, name)
        self.regions = regions
        if continent_colors is None:
            # A fresh dict per instance: a `{}` default argument would be
            # one dict shared by every Template built without colors.
            continent_colors = {}
        self.continent_colors = continent_colors
37
class TemplateLibrary (object):
    """Create Templates on demand from a directory of template data.

    Each template lives in its own subdirectory of `template_dir`, named
    after the lower-cased world name.  Files in that subdirectory are
    dispatched on their extension; the part of the filename before the
    extension names the region or route:

    ``.reg``
        A region outline.
    ``.rt`` / ``.vrt``
        A route; ``.vrt`` marks it as virtual (the renderer does not
        draw virtual routes).
    ``.ccl``
        Continent colors: one tab-separated continent name and color
        per line.

    Files with any other extension, or none, are ignored.

    Point-list files (``.reg``, ``.rt``, ``.vrt``) hold one point per
    line: tab-separated integer x and y, plus an optional point label.
    The y coordinate is negated on reading.  Blank lines split a point
    list into separate boundaries.  Boundaries listed by several regions
    are recognized by comparing their points.  Labelled points are used
    to attach routes to the regions whose outlines contain a point with
    the same label as a route end point.
    """
    def __init__(self, template_dir='share/templates/'):
        # Relative paths resolve against the current working directory
        # at construction time.
        self.template_dir = os.path.abspath(os.path.expanduser(template_dir))

    def get(self, name):
        """Return the Template for world `name`.

        Returns None if there is no readable template directory for it.
        """
        data = self._get_data(name)
        if data is None:
            return None
        region_pointlists, route_pointlists, continent_colors = data
        regions = self._generate_regions(region_pointlists, route_pointlists)
        return Template(name, regions, continent_colors)

    def _get_data(self, name):
        """Read the raw template files for world `name`.

        Returns ``(region_pointlists, route_pointlists, continent_colors)``.
        `region_pointlists` maps region names to point lists.
        `route_pointlists` maps route names to ``(point list, is_virtual)``.
        Returns None if the template directory cannot be listed.
        """
        dirname = os.path.join(self.template_dir, name.lower())
        try:
            files = os.listdir(dirname)
        except (IOError, OSError):
            # os.listdir raises OSError, which is not an IOError on Python 2.
            return None
        region_pointlists = {}
        route_pointlists = {}
        continent_colors = {}  # used when the directory has no .ccl file
        for filename in files:
            if '.' not in filename:
                continue  # not template data
            path = os.path.join(dirname, filename)
            stem, extension = filename.rsplit('.', 1)
            if extension == 'reg':
                region_pointlists[stem] = self._read_pointlist(path)
            elif extension in ('rt', 'vrt'):
                route_pointlists[stem] = (self._read_pointlist(path),
                                          extension == 'vrt')
            elif extension == 'ccl':
                continent_colors = self._read_colors(path)
        return (region_pointlists, route_pointlists, continent_colors)

    def _read_pointlist(self, filename):
        """Parse a point-list file into ``(x, -y, label)`` tuples.

        Blank lines become None entries, which act as boundary
        separators.  `label` is None when the line has only two fields.
        """
        pointlist = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    pointlist.append(None)
                    continue
                fields = line.split('\t')
                x = int(fields[0])
                y = -int(fields[1])
                if len(fields) == 3:
                    label = fields[2].strip()
                else:
                    label = None
                pointlist.append((x, y, label))
        return pointlist

    def _read_colors(self, filename):
        """Parse a ``.ccl`` file into a {continent name: color} dict."""
        colors = {}
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    continue
                name, color = line.split('\t')
                colors[name] = color
        return colors

    def _generate_regions(self, region_pointlists, route_pointlists):
        """Build Regions from point lists and attach routes to them.

        Boundaries are de-duplicated through `_insert_boundary`.  Each
        route is attached to every region whose outline contains a
        point named like one of the route's end points.
        """
        regions = []
        all_boundaries = []
        for name, pointlist in region_pointlists.items():
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            regions.append(Region(name, boundaries, head_to_tail))
        for name, (pointlist, virtual) in route_pointlists.items():
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            assert len(boundaries) == 1, boundaries  # one polyline per route
            route = boundaries[0]
            route.virtual = virtual
            for terminal in [route[0], route[-1]]:
                for r in regions:
                    for point in r.outline:
                        if hasattr(point, 'name') \
                                and point.name == terminal.name:
                            r.routes.append(route)
                            r.route_head_to_tail.append(
                                terminal == route[0])
                            route.regions.append(r)
        for r in regions:
            r.locate_routes()
        if all_boundaries:
            # Sanity check on sharing: no boundary may have matched more
            # than once, and at least one boundary must have matched.
            match_counts = [b.match_count for b in all_boundaries]
            assert min(match_counts) in [0, 1], set(match_counts)
            assert max(match_counts) == 1, set(match_counts)
        return regions

    def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
        """Split `pointlist` at None separators into Boundaries.

        Returns ``(boundaries, head_to_tail)``.  ``head_to_tail[i]`` is
        False when boundary i was stored in reversed point order.  Empty
        segments (leading, trailing or repeated separators) are skipped.
        """
        segments = [[]]
        for point in pointlist:
            if point is None:
                segments.append([])
            else:
                segments[-1].append(point)
        boundaries = []
        head_to_tail = []
        for b_points in segments:
            if not b_points:
                continue
            boundary, reverse = self._analyze(b_points)
            boundaries.append(self._insert_boundary(all_boundaries, boundary))
            head_to_tail.append(not reverse)
        return boundaries, head_to_tail

    def _analyze(self, boundary_points):
        """Turn raw ``(x, y, label)`` points into a normalized Boundary.

        The points are flipped when the last point compares smaller than
        the first (tuple order), so that the same boundary listed in
        opposite directions gives the same point sequence.  Points are
        made relative to the (possibly new) first point, whose absolute
        position is kept in `boundary.real_pos`.  Returns
        ``(boundary, reverse)``.
        """
        points = [self._vbp(b) for b in boundary_points]
        start = points[0]
        stop = points[-1]
        reverse = start > stop
        if reverse:
            points.reverse()
            start, stop = stop, start
        boundary = Boundary([p - start for p in points])
        for bp, p in zip(boundary, points):
            bp.name = p.name  # preserve point names
        boundary.name = '(%s) -> (%s)' % (start.name, stop.name)
        boundary.real_pos = start
        return (boundary, reverse)

    def _vbp(self, boundary_point):
        """Convert an ``(x, y, label)`` tuple to a Vector named `label`."""
        v = Vector((boundary_point[0], boundary_point[1]))
        v.name = boundary_point[2]
        return v

    def _insert_boundary(self, all_boundaries, new):
        """Return the stored boundary equivalent to `new`, adding `new` if none.

        Boundaries match when they have the same length, the same
        absolute start (`real_pos`) and equal relative points.  A match
        increments the stored boundary's `match_count`; a newly stored
        boundary starts at 0.
        """
        if new in all_boundaries:
            return new
        for b in all_boundaries:
            if (len(b) == len(new) and b.real_pos == new.real_pos
                    and all(bp == np for bp, np in zip(b, new))):
                b.match_count += 1
                return b
        all_boundaries.append(new)
        new.match_count = 0
        return new


# Library used by WorldRenderer when no template_lib is supplied.  The
# default template_dir is relative, so it is resolved against the working
# directory in effect when this module is imported.
TEMPLATE_LIBRARY = TemplateLibrary()
182
183
class Vector (tuple):
    """Tuple with element-wise arithmetic and an optional `name` tag.

    Arithmetic results get a `name` chosen from the operands (see
    `_set_name`).  Ordering and equality are plain tuple comparisons.

    >>> v = Vector
    >>> a = v((0, 0))
    >>> b = v((1, 1))
    >>> c = v((2, 3))
    >>> a+b
    (1, 1)
    >>> a+b+c
    (3, 4)
    >>> b-c
    (-1, -2)
    >>> -c
    (-2, -3)
    >>> a < b
    True
    >>> c > b
    True
    """
    def _set_name(self, new, other=None):
        """Give `new` a name taken from self or `other`.

        Our own name wins, unless it is None and `other` has a name.
        If we have no name at all, `other`'s name (if any) is used.
        """
        if not hasattr(self, 'name'):
            if hasattr(other, 'name'):
                new.name = other.name
            return
        if self.name is None and hasattr(other, 'name'):
            new.name = other.name
        else:
            new.name = self.name

    def _elementwise(self, other, op):
        """Combine our items pairwise with `other`'s using `op`."""
        if len(self) != len(other):
            raise ValueError('length missmatch %s, %s' % (self, other))
        result = self.__class__([op(x, y) for x, y in zip(self, other)])
        self._set_name(result, other)
        return result

    def __neg__(self):
        result = self.__class__([-x for x in self])
        self._set_name(result)
        return result

    def __add__(self, other):
        return self._elementwise(other, operator.add)

    def __sub__(self, other):
        return self._elementwise(other, operator.sub)

    def __mul__(self, other):
        return self._elementwise(other, operator.mul)
235
def nameless(vector):
    """Copy `vector` into a fresh Vector that has no `name` attribute.

    Vector arithmetic prefers the left operand's name, so stripping it
    lets a sum, difference or product pick up the *other* operand's name.
    """
    components = tuple(vector)
    return Vector(components)
243
class Boundary (ID_CmpMixin, list):
    """Ordered Vector points along a region edge or route.

    Every position is relative to the first point, which must therefore
    be (0,0).  `regions` starts empty; Region.__init__ appends each
    Region that uses this boundary.  x_min/x_max/y_min/y_max give the
    bounding box in the same relative frame.
    """
    def __init__(self, points):
        list.__init__(self)
        ID_CmpMixin.__init__(self)
        for point in points:
            self.append(Vector(point))
        self.regions = []
        assert self[0] == (0,0), self
        xs = [point[0] for point in self]
        ys = [point[1] for point in self]
        self.x_min, self.x_max = min(xs), max(xs)
        self.y_min, self.y_max = min(ys), max(ys)
261
class Route (ID_CmpMixin):
    """Connect non-adjacent Regions.

    Simply wraps the `boundary` describing the connecting path.
    NOTE(review): TemplateLibrary in this module represents routes
    directly as Boundary objects (with `virtual` and `regions`
    attributes) and never builds a Route; confirm whether anything
    else still uses this class.
    """
    def __init__(self, boundary):
        ID_CmpMixin.__init__(self)
        self.boundary = boundary  # path object for this route
268
class Region (NameMixin, ID_CmpMixin, list):
    """Contains a list of boundaries and a label.

    Regions can be Territories, sections of ocean, etc.

    ``head_to_tail[i]`` says whether boundary i is walked in its stored
    order (True) or backwards (False) when tracing the outline.
    ``route_head_to_tail[i]`` says whether route i touches this region
    with its first point (True) or its last point (False).

    >>> r = Region('Earth',
    ...            [Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)]),
    ...             Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)])],
    ...            [True, True, False, False],
    ...            label_offset=(0.5, 0.5))
    >>> r.outline
    [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    >>> r.label_offset
    (0.5, 0.5)
    >>> r.routes
    []
    """
    def __init__(self, name, boundaries, head_to_tail, routes=None,
                 route_head_to_tail=None, label_offset=(0,0)):
        NameMixin.__init__(self, name)
        list.__init__(self, boundaries)
        ID_CmpMixin.__init__(self)
        for boundary in self:
            # Back-link so neighbours sharing a boundary can be found.
            boundary.regions.append(self)
        self.head_to_tail = head_to_tail
        if routes is None:
            assert route_head_to_tail is None, route_head_to_tail
            routes = []
            route_head_to_tail = []
        self.routes = routes
        self.route_head_to_tail = route_head_to_tail
        self.route_starts = [] # set by .locate_routes
        self.label_offset = Vector(label_offset)
        self.generate_outline() # sets .outline, .starts
        # Bounding box in the outline's frame: each boundary's own box
        # shifted by that boundary's start position.
        placed = list(zip(self, self.starts))
        self.x_min = min([b.x_min + s[0] for b, s in placed])
        self.x_max = max([b.x_max + s[0] for b, s in placed])
        self.y_min = min([b.y_min + s[1] for b, s in placed])
        self.y_max = max([b.y_max + s[1] for b, s in placed])

    def generate_outline(self):
        """Trace the region's outline from its boundaries.

        Sets `self.outline` to the closed list of outline points.  It
        starts at (0,0), and its last point must equal its first.  Sets
        `self.starts` to the position of each boundary's first point in
        that same frame.  Returns None.

        Producing a CCW outline relies on the user-supplied
        head_to_tail list choosing each boundary's direction correctly.
        """
        self.starts = []
        points = [Vector((0,0))]
        for boundary, htt in zip(self, self.head_to_tail):
            pos = points[-1]
            if htt:
                assert boundary[0] == (0,0), boundary
                new = boundary[1:]
            else:
                # Walking backwards: the boundary's first point sits one
                # full boundary length behind the current position.
                pos -= nameless(boundary[-1])
                new = reversed(boundary[:-1])
            # A nameless offset lets each outline point inherit the name
            # of the boundary point it came from.
            pos = nameless(pos)
            self.starts.append(pos)
            for p in new:
                points.append(pos + p)
        assert points[-1] == points[0], '%s: %s' % (self, points)
        self.outline = points

    def locate_routes(self):
        """Fill `self.route_starts`, one entry per route in `self.routes`.

        For each route, find the outline point named like the route end
        that touches this region.  Record where the route's first point
        lies in the outline's frame.

        Raises ValueError if no such outline point exists.  Skipping the
        entry instead would misalign `route_starts` with `routes`, which
        are consumed pairwise with zip().
        """
        self.route_starts = []
        for route, htt in zip(self.routes, self.route_head_to_tail):
            if htt:
                anchor = route[0]
            else:
                anchor = route[-1]
            for point in self.outline:
                if hasattr(point, 'name') and point.name == anchor.name:
                    self.route_starts.append(point - nameless(anchor))
                    break
            else:
                raise ValueError(
                    '%s: no outline point named %r for a route anchor'
                    % (self, anchor.name))
339
class WorldRenderer (object):
    """Render a PyRisk world as an SVG document string.

    Each template region is drawn as a polygon filled with its
    continent's color.  Non-virtual routes are drawn as polylines.

    Parameters
    ----------
    template_lib : object, optional
        Anything with a ``get(name)`` method returning a Template or
        None.  Defaults to the module-level TEMPLATE_LIBRARY.
    line_width : int
        SVG stroke width for outlines and routes.
    buf : int
        Padding added on every side of the map's bounding box.
    dpcm : int
        Drawing units per centimeter; sets the SVG's physical size.
    """
    def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
        self.template_lib = template_lib
        if self.template_lib is None:
            self.template_lib = TEMPLATE_LIBRARY
        self.buf = buf
        self.line_width = line_width
        self.line_color = 'black'
        self.dpcm = dpcm

    def render(self, world):
        """Return an SVG string for `world`.

        Falls back to `_auto_template` (not implemented yet) when the
        template library has nothing for ``world.name``.
        """
        template = self.template_lib.get(world.name)
        if template is None:
            template = self._auto_template(world)
        return self.render_template(world, template)

    def render_template(self, world, template):
        """Return the SVG document for `world` laid out by `template`."""
        region_pos, width, height = self._locate(template)
        lines = [
            '<?xml version="1.0" standalone="no"?>',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"',
            '  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">',
            '<svg width="%.1fcm" height="%.1fcm" viewBox="0 0 %d %d"'
            % (float(width)/self.dpcm, float(height)/self.dpcm,
               width, height),
            '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
            '<desc>PyRisk world: %s</desc>' % template,
            ]
        drawn_rts = {}
        for r in template.regions:
            t = self._matching_territory(world, r)
            c_col = template.continent_colors[t.continent.name]
            r_pos = region_pos[id(r)]
            lines.extend([
                    '<!-- %s -->' % r,
                    '<polygon fill="%s" stroke="%s" stroke-width="%d"'
                    % (c_col, self.line_color, self.line_width),
                    '         points="%s" />'
                    % self._svg_points(r_pos, r.outline[:-1], height),
                    ])
            for rt, rt_start in zip(r.routes, r.route_starts):
                # Routes are shared by the regions they join; draw each
                # once, and never draw virtual routes.
                if id(rt) in drawn_rts or rt.virtual:
                    continue
                drawn_rts[id(rt)] = rt
                lines.extend([
                        '<polyline stroke="%s" stroke-width="%d"'
                        % (self.line_color, self.line_width),
                        '         points="%s" />'
                        % self._svg_points(r_pos + rt_start, rt, height),
                        ])
        # NOTE(review): these markers sit on opposite corners of the
        # viewBox and look like leftover debugging aids.
        lines.extend([
                '<circle fill="black" cx="0" cy="0" r="20" />',
                '<circle fill="green" cx="%d" cy="%d" r="20" />'
                 % (width, height)
                ])
        lines.extend(['</svg>', ''])
        return '\n'.join(lines)

    def _svg_points(self, offset, points, height):
        """Format `points`, shifted by `offset`, as an SVG points string.

        SVG y values increase downward, so y is negated and then shifted
        by `height` back into the viewBox.
        """
        return ' '.join(['%d,%d' % ((offset + p) * (1, -1) + (0, height))
                         for p in points])

    def _locate(self, template):
        """Place every region of `template` in absolute coordinates.

        Starting from the first boundary of the first region, walk to
        regions connected through shared boundaries or routes.  Each
        region's origin is derived from an already-placed boundary or
        route.  Returns ``(region_pos, width, height)``, where
        `region_pos` maps ``id(region)`` to its origin, shifted so the
        padded bounding box starts at (0, 0).

        Raises KeyError (with the region's name) for regions that are
        not connected to the first one.
        """
        region_pos = {} # {id: absolute position, ...}
        boundary_pos = {} # {id: absolute position, ...}
        route_pos = {} # {id: absolute position, ...}
        b1 = template.regions[0][0]
        boundary_pos[id(b1)] = Vector((0,0)) # fix the first boundary point
        stack = list(b1.regions)
        while stack:
            r = stack.pop()
            if id(r) in region_pos:
                continue # skip duplicate entries
            r_start = None
            for b, rel_b_start in zip(r, r.starts):
                if id(b) in boundary_pos:
                    r_start = boundary_pos[id(b)] - rel_b_start
                    break # found an anchor
            if r_start is None:
                for rt, rel_rt_start in zip(r.routes, r.route_starts):
                    if id(rt) in route_pos:
                        r_start = route_pos[id(rt)] - rel_rt_start
                        break # found an anchor
            region_pos[id(r)] = r_start
            # Place this region's unplaced boundaries and routes, and
            # queue every region they touch.
            for b, rel_b_start in zip(r, r.starts):
                if id(b) not in boundary_pos:
                    boundary_pos[id(b)] = r_start + rel_b_start
                    stack.extend(b.regions)
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) not in route_pos:
                    route_pos[id(rt)] = r_start + rt_start
                    stack.extend(rt.regions)
        for r in template.regions:
            if id(r) not in region_pos:
                raise KeyError(r.name)
        x_min = min([r.x_min + region_pos[id(r)][0]
                     for r in template.regions]) - self.buf
        x_max = max([r.x_max + region_pos[id(r)][0]
                     for r in template.regions]) + self.buf
        y_min = min([r.y_min + region_pos[id(r)][1]
                     for r in template.regions]) - self.buf
        y_max = max([r.y_max + region_pos[id(r)][1]
                     for r in template.regions]) + self.buf
        origin = Vector((x_min, y_min))
        for key, value in list(region_pos.items()):
            region_pos[key] = value - origin
        return (region_pos, x_max-x_min, y_max-y_min)

    def _matching_territory(self, world, region):
        """Return the world territory that `region` should be colored as.

        A region with no territory of its own borrows the territory of a
        region reached through one of its virtual routes.  If several
        qualify, the last one found wins.  Raises AssertionError when
        nothing matches.
        """
        t = None
        try:
            t = world.territory_by_name(region.name)
        except KeyError:
            for rt in region.routes:
                if not rt.virtual:
                    continue
                for r in rt.regions:
                    try:
                        t = world.territory_by_name(r.name)
                    except KeyError:
                        pass
        assert t is not None, 'No territory in %s associated with region %s' \
            % (world, region)
        return t

    def _auto_template(self, world):
        """Build a Template for a world without template data (unimplemented)."""
        raise NotImplementedError
466
def test():
    """Run the doctests embedded in this module.

    Returns the number of failed doctest examples (0 when all pass).
    """
    import doctest
    import sys
    this_module = sys.modules[__name__]
    failure_count = doctest.testmod(this_module)[0]
    return failure_count
471
def render_earth(filename=None):
    """Render the world returned by ``base.generate_earth()`` as SVG.

    Parameters
    ----------
    filename : str, optional
        When given, write the SVG document to this file.  Otherwise
        print it to stdout (the previous behavior).
    """
    from .base import generate_earth
    r = WorldRenderer()
    svg = r.render(generate_earth())
    if filename is None:
        # A parenthesized single argument prints identically on
        # Python 2 and Python 3.
        print(svg)
    else:
        with open(filename, 'w') as f:
            f.write(svg)