Added template control over line colors (and presence).
[pyrisk.git] / pyrisk / graphics.py
1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """Various PyRisk class -> graphic renderers.
18
19 Creates SVG files by hand.  See the TemplateLibrary class for a
20 description of map specification format.
21 """
22
23 import operator
24 import os
25 import os.path
26
27 from .base import NameMixin, ID_CmpMixin
28
29
class Template (NameMixin):
    """Setup regions for a particular world.

    Attributes:

    * regions: the regions making up the map.
    * continent_colors: dict of fill colors.  WorldRenderer indexes it
      with the name of the continent of the territory matching each
      region.
    * line_colors: dict of stroke colors.  WorldRenderer reads the
      keys 'border', 'route' and 'virtual route'; a value of None
      suppresses that kind of line.
    """
    def __init__(self, name, regions, continent_colors=None,
                 line_colors=None):
        NameMixin.__init__(self, name)
        self.regions = regions
        # Build fresh dicts per instance; a `{}` default in the
        # signature would be shared by every Template created without
        # explicit color tables.
        if continent_colors is None:
            continent_colors = {}
        if line_colors is None:
            line_colors = {}
        self.continent_colors = continent_colors
        self.line_colors = line_colors
39
class TemplateLibrary (object):
    """Create Templates on demand from a directory of template data.

    Each template is a subdirectory of `template_dir` named after the
    lower-cased template name.  Files in it are dispatched on their
    extension:

    * ``NAME.reg``: point list outlining the region NAME.
    * ``NAME.rt``, ``NAME.vrt``: point list tracing a route; the
      ``vrt`` extension flags the route as virtual.
    * ``continent.col``, ``line.col``: color tables.

    Files with any other extension, or with no extension at all, are
    ignored.

    Point lists hold one point per line as tab-separated ``x``, ``y``
    and an optional label.  The y value is negated on loading.  A
    blank line ends the current boundary.

    Color tables hold one tab-separated ``name``, ``color`` pair per
    line; blank lines are skipped and a color of ``-`` is loaded as
    None.
    """
    def __init__(self, template_dir='share/templates/'):
        self.template_dir = os.path.abspath(os.path.expanduser(template_dir))
    def get(self, name):
        """Return the Template stored under `name`.

        Returns None when the template directory cannot be listed, so
        callers can fall back to another source of templates.
        """
        data = self._get_data(name)
        if data is None:
            return None
        region_pointlists,route_pointlists,continent_colors,line_colors = data
        regions = self._generate_regions(region_pointlists, route_pointlists)
        # NOTE(review): the Template is always named 'template', not
        # `name` -- confirm whether that is intended.
        return Template('template', regions, continent_colors, line_colors)
    def _get_data(self, name):
        """Load the raw contents of the template directory for `name`.

        Returns a tuple (region_pointlists, route_pointlists,
        continent_colors, line_colors), or None if the directory
        cannot be listed.  The color dicts are empty when the
        corresponding ``.col`` file is absent.
        """
        dirname = os.path.join(self.template_dir, name.lower())
        try:
            files = os.listdir(dirname)
        except (IOError, OSError): # os.listdir reports failure as OSError
            return None
        region_pointlists = {}
        route_pointlists = {}
        continent_colors = {}
        line_colors = {}
        for filename in files:
            if '.' not in filename:
                continue # no extension, nothing to dispatch on
            path = os.path.join(dirname, filename)
            base,extension = filename.rsplit('.', 1)
            if extension == 'reg':
                region_pointlists[base] = self._read_pointlist(path)
            elif extension in ['rt', 'vrt']:
                route_pointlists[base] = (self._read_pointlist(path),
                                          extension == 'vrt')
            elif extension == 'col':
                c = self._read_colors(path)
                if base == 'continent':
                    continent_colors = c
                else:
                    assert base == 'line', base
                    line_colors = c
        return (region_pointlists, route_pointlists,
                continent_colors, line_colors)
    def _read_pointlist(self, filename):
        """Parse a point list file.

        Returns a list of (x, -y, label) tuples, with label None when
        the line has no third field, and a None entry for every blank
        line.
        """
        pointlist = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    pointlist.append(None) # boundary separator
                    continue
                fields = line.split('\t')
                x = int(fields[0])
                y = -int(fields[1])
                if len(fields) == 3:
                    label = fields[2].strip()
                else:
                    label = None
                pointlist.append((x,y,label))
        return pointlist
    def _read_colors(self, filename):
        """Parse a color table file into a {name: color} dict.

        A color given as '-' is stored as None.
        """
        colors = {}
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    continue
                fields = line.split('\t')
                name,color = [x.strip() for x in fields]
                if color == '-':
                    color = None
                colors[name] = color
        return colors
    def _generate_regions(self, region_pointlists, route_pointlists):
        """Build Regions from point lists and attach routes to them.

        Boundaries that appear in several point lists are shared via
        `_insert_boundary`.  A route is attached to every region whose
        outline contains a point named like one of the route's end
        points.
        """
        regions = []
        all_boundaries = []
        for name,pointlist in region_pointlists.items():
            boundaries,head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            regions.append(Region(name, boundaries, head_to_tail))
        for name,v_pointlist in route_pointlists.items():
            pointlist,virtual = v_pointlist
            boundaries,head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            assert len(boundaries) == 1, boundaries
            route = boundaries[0]
            route.virtual = virtual
            for terminal in [route[0], route[-1]]:
                for r in regions:
                    for point in r.outline:
                        if hasattr(point, 'name') \
                                and point.name == terminal.name:
                            r.routes.append(route)
                            r.route_head_to_tail.append(
                                terminal == route[0])
                            route.regions.append(r)
        for r in regions:
            r.locate_routes()
        # Sanity check on how often boundaries were found to be shared.
        match_counts = [b.match_count for b in all_boundaries]
        assert min(match_counts) in [0, 1], set(match_counts)
        assert max(match_counts) == 1, set(match_counts)
        return regions
    def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
        """Split a point list at its None separators.

        Returns (boundaries, head_to_tail), where head_to_tail[i] is
        False when `_analyze` reversed the points of boundaries[i].
        Separators with no points before them are skipped.
        """
        boundaries = []
        head_to_tail = []
        b_points = []
        # The extra trailing None flushes the final boundary even when
        # the point list does not end with a separator.
        for point in list(pointlist) + [None]:
            if point is not None:
                b_points.append(point)
                continue
            if not b_points:
                continue # leading, repeated or trailing separator
            boundary,reverse = self._analyze(b_points)
            boundary = self._insert_boundary(all_boundaries, boundary)
            boundaries.append(boundary)
            head_to_tail.append(not reverse)
            b_points = []
        return boundaries, head_to_tail
    def _analyze(self, boundary_points):
        """Turn raw points into a Boundary with a canonical direction.

        The points are reversed when the first point compares greater
        than the last, so the same boundary read in either direction
        yields the same Boundary.  Returns (boundary, reverse).
        """
        points = [self._vbp(b) for b in boundary_points]
        start,stop = points[0], points[-1]
        reverse = start > stop
        if reverse:
            points.reverse()
            start,stop = (stop, start)
        boundary = Boundary([p-start for p in points])
        for bp,p in zip(boundary, points):
            bp.name = p.name # preserve point names
        boundary.name = '(%s) -> (%s)' % (start.name, stop.name)
        boundary.real_pos = start
        return (boundary, reverse)
    def _vbp(self, boundary_point):
        """Convert an (x, y, label) tuple into a Vector named `label`."""
        v = Vector((boundary_point[0], boundary_point[1]))
        v.name = boundary_point[2]
        return v
    def _insert_boundary(self, all_boundaries, new):
        """Return the boundary in `all_boundaries` equivalent to `new`.

        An existing boundary with the same real position and the same
        points is reused and its match_count incremented.  Otherwise
        `new` is appended with a match_count of 0.
        """
        if new in all_boundaries:
            return new
        for b in all_boundaries:
            if len(b) != len(new) or b.real_pos != new.real_pos:
                continue
            if all([bp == new_p for bp,new_p in zip(b, new)]):
                b.match_count += 1
                return b
        all_boundaries.append(new)
        new.match_count = 0
        return new
190
# Shared default library, created at import time with the default
# template_dir.  WorldRenderer falls back to it when it is constructed
# without an explicit template_lib.
TEMPLATE_LIBRARY = TemplateLibrary()
192
193
class Vector (tuple):
    """Tuple with element-wise arithmetic and an optional name.

    Addition, subtraction and multiplication work element by element
    and require operands of equal length.  Comparison is plain tuple
    comparison.

    >>> v = Vector
    >>> a = v((0, 0))
    >>> b = v((1, 1))
    >>> c = v((2, 3))
    >>> a+b
    (1, 1)
    >>> a+b+c
    (3, 4)
    >>> b-c
    (-1, -2)
    >>> -c
    (-2, -3)
    >>> a < b
    True
    >>> c > b
    True

    A result inherits the `name` attribute of the left operand, unless
    that name is missing or None and the right operand carries one.
    """
    def _set_name(self, new, other=None):
        """Copy the appropriate `name`, if any, onto the result `new`."""
        own = hasattr(self, 'name')
        theirs = hasattr(other, 'name')
        if own and not (self.name is None and theirs):
            new.name = self.name
        elif theirs:
            new.name = other.name
    def _combine(self, other, op):
        """Apply the binary `op` element-wise to self and `other`."""
        if len(self) != len(other):
            raise ValueError('length missmatch %s, %s' % (self, other))
        result = self.__class__([op(x, y) for x,y in zip(self, other)])
        self._set_name(result, other)
        return result
    def __neg__(self):
        result = self.__class__([operator.neg(x) for x in self])
        self._set_name(result)
        return result
    def __add__(self, other):
        return self._combine(other, operator.add)
    def __sub__(self, other):
        return self._combine(other, operator.sub)
    def __mul__(self, other):
        return self._combine(other, operator.mul)
245
def nameless(vector):
    """Copy `vector` into a fresh Vector that carries no name.

    Handy as an operand when the result of a sum / etc. should take
    its name from the *other* vector.
    """
    anonymous = Vector(vector)
    return anonymous
253
class Boundary (ID_CmpMixin, list):
    """List of Vectors tracing one boundary line.

    Every position is relative to the first point, which therefore
    has to be (0,0).  The bounding box of the points is kept in
    x_min, x_max, y_min and y_max, and `regions` starts out empty.
    """
    def __init__(self, points):
        list.__init__(self)
        ID_CmpMixin.__init__(self)
        self.extend([Vector(p) for p in points])
        self.regions = []
        assert self[0] == (0,0), self
        xs = [p[0] for p in self]
        ys = [p[1] for p in self]
        self.x_min = min(xs)
        self.x_max = max(xs)
        self.y_min = min(ys)
        self.y_max = max(ys)
271
class Route (ID_CmpMixin):
    """Connect non-adjacent Regions.

    Thin wrapper holding the boundary that traces the route in its
    `boundary` attribute.

    NOTE(review): nothing visible in this module instantiates Route;
    TemplateLibrary._generate_regions uses Boundary instances as
    routes directly.  Confirm whether this class is still needed.
    """
    def __init__(self, boundary):
        ID_CmpMixin.__init__(self)
        self.boundary = boundary
278
class Region (NameMixin, ID_CmpMixin, list):
    """Contains a list of boundaries and a label.

    Regions can be Territories, sections of ocean, etc.

    `head_to_tail` holds one flag per boundary: True when the boundary
    is walked from its first point to its last while tracing the
    outline, False when it is walked backwards.

    >>> r = Region('Earth',
    ...            [Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)]),
    ...             Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)])],
    ...            [True, True, False, False],
    ...            label_offset=(0.5, 0.5))
    >>> r.outline
    [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    """
    def __init__(self, name, boundaries, head_to_tail, routes=None,
                 route_head_to_tail=None, label_offset=(0,0)):
        NameMixin.__init__(self, name)
        list.__init__(self, boundaries)
        ID_CmpMixin.__init__(self)
        for boundary in self:
            boundary.regions.append(self) # back-reference for layout
        self.head_to_tail = head_to_tail
        if routes is None:
            assert route_head_to_tail is None
            routes = []
            route_head_to_tail = []
        self.routes = routes
        self.route_head_to_tail = route_head_to_tail
        self.route_starts = [] # set by .locate_routes
        self.label_offset = Vector(label_offset)
        self.generate_outline() # sets .outline, .starts
        # Bounding box relative to the first outline point.
        placed = list(zip(self, self.starts))
        self.x_min = min([b.x_min+s[0] for b,s in placed])
        self.x_max = max([b.x_max+s[0] for b,s in placed])
        self.y_min = min([b.y_min+s[1] for b,s in placed])
        self.y_max = max([b.y_max+s[1] for b,s in placed])
    def generate_outline(self):
        """Trace the boundaries into a closed outline.

        Sets `.outline`, the list of points surrounding the region
        starting and ending at (0,0), and `.starts`, the position of
        each boundary's first point relative to the outline start.
        Returns nothing.

        The main issue here is determining the proper border
        orientation for a CCW outline, which we do via the
        user-supplied list head_to_tail.
        """
        self.starts = []
        points = [Vector((0,0))]
        for boundary,htt in zip(self, self.head_to_tail):
            pos = points[-1]
            if htt:
                assert boundary[0] == (0,0), boundary
                new = boundary[1:]
            else:
                # Walking backwards: the outline is currently at the
                # boundary's last point, so back up to its first.
                pos = pos - nameless(boundary[-1])
                new = reversed(boundary[:-1])
            pos = nameless(pos)
            self.starts.append(pos)
            for p in new:
                points.append(pos+p)
        assert points[-1] == points[0], '%s: %s' % (self, points)
        self.outline = points
    def locate_routes(self):
        """Set `.route_starts` from the named points on the outline.

        Each route is anchored at its first point when the matching
        route_head_to_tail flag is true and at its last point
        otherwise.  The start is taken from the first outline point
        whose name equals the anchor's name; a route without such a
        point adds no entry.
        """
        self.route_starts = []
        for route,htt in zip(self.routes, self.route_head_to_tail):
            if htt:
                anchor = route[0]
            else:
                anchor = route[-1]
            for point in self.outline:
                if hasattr(point, 'name') and point.name == anchor.name:
                    self.route_starts.append(point-nameless(anchor))
                    break
349
class WorldRenderer (object):
    """Draw a world as an SVG document laid out by a template.

    * template_lib: object whose get(name) returns a template or
      None; defaults to the module-level TEMPLATE_LIBRARY.
    * line_width: stroke width for borders and routes.
    * buf: margin added on every side of the regions' bounding box.
    * dpcm: drawing units per centimeter of the stated SVG size.
    """
    def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
        if template_lib is None:
            template_lib = TEMPLATE_LIBRARY
        self.template_lib = template_lib
        self.line_width = line_width
        self.buf = buf
        self.dpcm = dpcm
    def render(self, world):
        """Return SVG text for `world`, using the template named after it."""
        template = self.template_lib.get(world.name)
        if template is None:
            template = self._auto_template(world)
        return self.render_template(world, template)
    def render_template(self, world, template):
        """Return SVG text drawing `template`, colored to match `world`.

        Each region becomes a polygon filled with the color of the
        continent of its matching territory.  Each route is drawn
        once as a polyline, unless its line color is None.
        """
        region_pos,width,height = self._locate(template)
        svg = [
            '<?xml version="1.0" standalone="no"?>',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"',
            '  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">',
            '<svg width="%.1fcm" height="%.1fcm" viewBox="0 0 %d %d"'
            % (float(width)/self.dpcm, float(height)/self.dpcm,
               width, height),
            '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
            '<desc>PyRisk world: %s</desc>' % template,
            ]
        seen_routes = {} # {id: route}, routes may belong to several regions
        for region in template.regions:
            territory = self._matching_territory(world, region)
            fill = template.continent_colors[territory.continent.name]
            border = template.line_colors['border']
            if border is None:
                stroke = ''
            else:
                stroke = 'stroke="%s" stroke-width="%d"' \
                    % (border, self.line_width)
            origin = region_pos[id(region)]
            svg.append('<!-- %s -->' % region)
            svg.append('<polygon fill="%s" %s' % (fill, stroke))
            # The last outline point repeats the first, so drop it.
            svg.append('         points="%s" />'
                       % self._svg_points(origin, region.outline[:-1],
                                          height))
            for route,route_start in zip(region.routes, region.route_starts):
                if id(route) in seen_routes:
                    continue
                seen_routes[id(route)] = route
                if route.virtual == True:
                    color = template.line_colors['virtual route']
                else:
                    color = template.line_colors['route']
                if color is None:
                    continue
                svg.append('<polyline stroke="%s" stroke-width="%d"'
                           % (color, self.line_width))
                svg.append('         points="%s" />'
                           % self._svg_points(origin+route_start, route,
                                              height))
        svg.append('<circle fill="black" cx="0" cy="0" r="20" />')
        svg.append('<circle fill="green" cx="%d" cy="%d" r="20" />'
                   % (width, height))
        svg.append('</svg>')
        svg.append('')
        return '\n'.join(svg)
    def _svg_points(self, origin, points, height):
        """Format `points`, offset by `origin`, as an SVG points string."""
        coords = []
        for p in points:
            flipped = (origin+p)*(1,-1) # svg y value increases down
            coords.append('%d,%d' % (flipped+(0,height))) # shift back into bbox
        return ' '.join(coords)
    def _locate(self, template):
        """Work out where every region of `template` sits.

        Starting with the first boundary of the first region pinned
        at (0,0), regions are placed one after another through a
        boundary or route they share with something already placed.
        Returns (region_pos, width, height), with region_pos mapping
        id(region) to its position inside the padded bounding box.
        Raises KeyError with the region name for a region that was
        never reached.
        """
        region_pos = {} # {id: absolute position, ...}
        boundary_pos = {} # {id: absolute position, ...}
        route_pos = {} # {id: absolute position, ...}
        first_boundary = template.regions[0][0]
        boundary_pos[id(first_boundary)] = Vector((0,0)) # fixed reference
        pending = list(first_boundary.regions)
        while pending:
            region = pending.pop()
            if id(region) in region_pos:
                continue # already placed
            anchor = None
            for b,rel_start in zip(region, region.starts):
                if id(b) in boundary_pos:
                    anchor = boundary_pos[id(b)] - rel_start
                    break
            if anchor is None:
                for rt,rel_start in zip(region.routes, region.route_starts):
                    if id(rt) in route_pos:
                        anchor = route_pos[id(rt)] - rel_start
                        break
            region_pos[id(region)] = anchor
            # Pin whatever this region introduces and queue the
            # regions reachable through it.
            for b,rel_start in zip(region, region.starts):
                if id(b) not in boundary_pos:
                    boundary_pos[id(b)] = anchor + rel_start
                    pending.extend(b.regions)
            for rt,rel_start in zip(region.routes, region.route_starts):
                if id(rt) not in route_pos:
                    route_pos[id(rt)] = anchor + rel_start
                    pending.extend(rt.regions)
        for region in template.regions:
            if id(region) not in region_pos:
                raise KeyError(region.name)
        placed = [(r, region_pos[id(r)]) for r in template.regions]
        x_min = min([r.x_min + pos[0] for r,pos in placed]) - self.buf
        x_max = max([r.x_max + pos[0] for r,pos in placed]) + self.buf
        y_min = min([r.y_min + pos[1] for r,pos in placed]) - self.buf
        y_max = max([r.y_max + pos[1] for r,pos in placed]) + self.buf
        corner = Vector((x_min, y_min))
        for key in list(region_pos.keys()):
            region_pos[key] = region_pos[key] - corner
        return (region_pos, x_max-x_min, y_max-y_min)
    def _matching_territory(self, world, region):
        """Return the territory of `world` that `region` belongs to.

        The region's own name is looked up first.  If that raises
        KeyError, the regions on its virtual routes are tried, and
        the last successful lookup wins.  Fails an assertion when no
        territory is found.
        """
        match = None
        try:
            match = world.territory_by_name(region.name)
        except KeyError:
            for rt in region.routes:
                if rt.virtual:
                    for linked in rt.regions:
                        try:
                            match = world.territory_by_name(linked.name)
                        except KeyError:
                            continue
        assert match is not None, \
            'No territory in %s associated with region %s' % (world, region)
        return match
    def _auto_template(self, world):
        """Generate a template for a world without one (not implemented)."""
        raise NotImplementedError
485
def test():
    """Run this module's doctests and return the number of failures."""
    import doctest
    import sys
    this_module = sys.modules[__name__]
    results = doctest.testmod(this_module)
    return results[0]
490
def render_earth():
    """Render the world built by generate_earth() and print the SVG.

    Redirect standard output to a file to save the drawing.
    """
    from .base import generate_earth
    r = WorldRenderer()
    # Parenthesized single-argument form is valid as both the Python 2
    # print statement and the Python 3 print function.
    print(r.render(generate_earth()))