32cc46e867c90b89286d0b388158b509544e8de0
[pyrisk.git] / pyrisk / graphics.py
1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """Various PyRisk class -> graphic renderers.
18
19 Creates SVG files by hand.  See the TemplateLibrary class for a
20 description of map specification format.
21 """
22
23 import math
24 import operator
25 import os
26 import os.path
27
28 from .base import NameMixin, ID_CmpMixin
29
30
class Template (NameMixin):
    """Setup regions for a particular world.

    Parameters
    ----------
    name : str
        Template name (stored via NameMixin).
    regions : list of Region
        The regions making up the map.
    continent_colors : dict, optional
        Maps continent names to polygon fill colors.
    line_colors : dict, optional
        Maps line types ('border', 'route', 'virtual route') to stroke
        colors; a value may be None.
    player_colors : list, optional
        Army colors, indexed by a player's position in the player list.

    Every instance gets its own color containers when the optional
    arguments are omitted (no shared mutable defaults).
    """
    def __init__(self, name, regions, continent_colors=None,
                 line_colors=None, player_colors=None):
        NameMixin.__init__(self, name)
        self.regions = regions
        # Build fresh containers per instance; `{}` / `[]` defaults in the
        # signature would be shared by every Template created without them.
        if continent_colors is None:
            continent_colors = {}
        if line_colors is None:
            line_colors = {}
        if player_colors is None:
            player_colors = []
        self.continent_colors = continent_colors
        self.line_colors = line_colors
        self.player_colors = player_colors
41
class TemplateLibrary (object):
    """Create Templates on demand from a directory of template data.

    Each template lives in its own subdirectory of ``template_dir``,
    named after the lower-cased template name.  Files inside it are
    dispatched on their extension:

    ``<region>.reg``
        Region outline.  One point per line, tab-separated as
        ``x<TAB>y[<TAB>label]`` with integer coordinates (y is negated
        on read).  Blank lines separate the boundaries that make up the
        region.  Labels name points so routes can attach to them.
    ``<route>.rt`` / ``<route>.vrt``
        A route (``.rt``) or virtual route (``.vrt``) in the same point
        format; it must consist of exactly one boundary.  Its first and
        last point labels select the regions it connects.
    ``continent.col``, ``line.col``, ``player.col``
        Tab-separated ``name<TAB>color`` pairs; a color of ``-`` means
        None.  Player colors are ordered by sorting their names as
        strings.  Any other ``*.col`` name trips an assertion.

    Files without an extension, or with an unrecognized one, are ignored.
    """
    def __init__(self, template_dir='share/templates/'):
        # Relative paths are resolved against the current working directory.
        self.template_dir = os.path.abspath(os.path.expanduser(template_dir))

    def get(self, name):
        """Return the Template called `name`, or None if it has no data.

        Returning None (rather than raising) lets callers such as
        WorldRenderer.render fall back to another template source.
        """
        data = self._get_data(name)
        if data is None:
            return None
        (region_pointlists, route_pointlists, continent_colors, line_colors,
         player_colors) = data
        regions = self._generate_regions(region_pointlists, route_pointlists)
        return Template(name, regions, continent_colors, line_colors,
                        player_colors)

    def _get_data(self, name):
        """Read the raw template files for `name`.

        Returns ``(region_pointlists, route_pointlists, continent_colors,
        line_colors, player_colors)``, or None when the template
        directory does not exist or cannot be listed.  Missing ``.col``
        files yield empty color containers.
        """
        dirname = os.path.join(self.template_dir, name.lower())
        try:
            files = os.listdir(dirname)
        except (IOError, OSError):
            # Python 2's os.listdir raises OSError, not IOError.
            return None
        region_pointlists = {}
        route_pointlists = {}
        continent_colors = {}
        line_colors = {}
        player_colors = []
        for filename in files:
            if '.' not in filename:
                continue  # e.g. README -- not template data
            base, extension = filename.rsplit('.', 1)
            path = os.path.join(dirname, filename)
            if extension == 'reg':
                region_pointlists[base] = self._read_pointlist(path)
            elif extension in ('rt', 'vrt'):
                route_pointlists[base] = (self._read_pointlist(path),
                                          extension == 'vrt')
            elif extension == 'col':
                colors = self._read_colors(path)
                if base == 'continent':
                    continent_colors = colors
                elif base == 'line':
                    line_colors = colors
                else:
                    assert base == 'player', base
                    player_colors = [colors[k] for k in sorted(colors)]
        return (region_pointlists, route_pointlists,
                continent_colors, line_colors, player_colors)

    def _read_pointlist(self, filename):
        """Parse a point file into a list of ``(x, y, label)`` tuples.

        Blank lines become None (boundary separators).  ``y`` is negated
        on read; ``label`` is None when the optional third field is
        absent.
        """
        pointlist = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    pointlist.append(None)
                    continue
                fields = line.split('\t')
                x = int(fields[0])
                y = -int(fields[1])
                label = fields[2].strip() if len(fields) == 3 else None
                pointlist.append((x, y, label))
        return pointlist

    def _read_colors(self, filename):
        """Parse a ``name<TAB>color`` file into a dict.

        Blank lines are skipped; a color of ``-`` is stored as None.
        """
        colors = {}
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if not line:
                    continue
                name, color = [x.strip() for x in line.split('\t')]
                if color == '-':
                    color = None
                colors[name] = color
        return colors

    def _generate_regions(self, region_pointlists, route_pointlists):
        """Build Regions from point lists and attach routes to them.

        Boundaries are deduplicated through `_insert_boundary`, so
        neighbouring regions share Boundary instances.  A route is
        attached to every region whose outline has a point named like
        one of the route's terminal points.  The closing assertions
        check that every stored boundary was re-used at most once and
        that at least one was shared.
        """
        regions = []
        all_boundaries = []
        for name, pointlist in region_pointlists.items():
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            regions.append(Region(name, boundaries, head_to_tail))
        for _name, (pointlist, virtual) in route_pointlists.items():
            boundaries, _htt = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            assert len(boundaries) == 1, boundaries
            route = boundaries[0]
            route.virtual = virtual
            for terminal in [route[0], route[-1]]:
                for r in regions:
                    for point in r.outline:
                        if hasattr(point, 'name') \
                                and point.name == terminal.name:
                            r.routes.append(route)
                            # True when the region touches the route's head.
                            r.route_head_to_tail.append(
                                terminal == route[0])
                            route.regions.append(r)
        for r in regions:
            r.locate_routes()
        match_counts = [b.match_count for b in all_boundaries]
        assert min(match_counts) in [0, 1], set(match_counts)
        assert max(match_counts) == 1, set(match_counts)
        return regions

    def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
        """Split `pointlist` at None separators into (shared) Boundaries.

        Returns ``(boundaries, head_to_tail)`` where ``head_to_tail[i]``
        is False when the i-th boundary's points were stored in the
        reverse of their file order.  Empty segments (from leading,
        trailing or consecutive blank lines) are skipped.
        """
        boundaries = []
        head_to_tail = []
        segment = []
        # The appended None flushes the final segment.
        for point in list(pointlist) + [None]:
            if point is not None:
                segment.append(point)
                continue
            if not segment:
                continue
            boundary, reverse = self._analyze(segment)
            boundaries.append(self._insert_boundary(all_boundaries, boundary))
            head_to_tail.append(not reverse)
            segment = []
        return boundaries, head_to_tail

    def _analyze(self, boundary_points):
        """Convert raw ``(x, y, label)`` points into a canonical Boundary.

        The points are reversed if needed so the boundary starts at its
        lexicographically smaller end; that way the same edge read from
        two region files can be matched by `_insert_boundary`.  Returns
        ``(boundary, reverse)``.  The boundary's points are relative to
        its start, which is kept in ``boundary.real_pos``.
        """
        points = [self._vbp(b) for b in boundary_points]
        start, stop = points[0], points[-1]
        reverse = start > stop
        if reverse:
            points.reverse()
            start, stop = stop, start
        boundary = Boundary([p-start for p in points])
        for bp, p in zip(boundary, points):
            bp.name = p.name  # preserve point names
        boundary.name = '(%s) -> (%s)' % (start.name, stop.name)
        boundary.real_pos = start
        return (boundary, reverse)

    def _vbp(self, boundary_point):
        """Return an ``(x, y, label)`` tuple as a Vector named `label`."""
        v = Vector((boundary_point[0], boundary_point[1]))
        v.name = boundary_point[2]
        return v

    def _insert_boundary(self, all_boundaries, new):
        """Return the stored Boundary matching `new`, registering `new` if none.

        Boundaries match when they have the same absolute start
        (``real_pos``) and equal relative points.  Each re-use increments
        the stored boundary's ``match_count``.
        """
        if new in all_boundaries:
            return new
        for b in all_boundaries:
            if (len(b) == len(new) and b.real_pos == new.real_pos
                    and all(bp == np for bp, np in zip(b, new))):
                b.match_count += 1
                return b
        all_boundaries.append(new)
        new.match_count = 0
        return new
198
# Shared default library, used by WorldRenderer when no template_lib is
# given.  TemplateLibrary.__init__ passes its relative default directory
# ('share/templates/') through os.path.abspath, so it is resolved against
# the current working directory at the time this module is imported.
TEMPLATE_LIBRARY = TemplateLibrary()
200
201
class Vector (tuple):
    """Tuple with element-wise arithmetic that carries a point name.

    ``+``, ``-`` and ``*`` need an operand of the same length and combine
    elements pairwise; ``-v`` negates every element.  Each result is a
    new instance of the same class.  Results inherit a ``name``: the left
    operand's, unless it is missing or None, in which case the right
    operand's (when it has one).  If the left name is None and the right
    operand has no name, the result's name is None; if neither operand
    has a name, the result has none either.

    >>> v = Vector
    >>> a = v((0, 0))
    >>> b = v((1, 1))
    >>> c = v((2, 3))
    >>> a+b
    (1, 1)
    >>> a+b+c
    (3, 4)
    >>> b-c
    (-1, -2)
    >>> -c
    (-2, -3)
    >>> b*c
    (2, 3)
    >>> a < b
    True
    >>> c > b
    True
    """
    def _set_name(self, new, other=None):
        """Give `new` a name, preferring ours unless it is None."""
        if not hasattr(self, 'name'):
            if hasattr(other, 'name'):
                new.name = other.name
            return
        if self.name is None and hasattr(other, 'name'):
            new.name = other.name
        else:
            new.name = self.name

    def _pairwise(self, other, op):
        """Combine self and `other` element by element with `op`."""
        if len(self) != len(other):
            raise ValueError('length missmatch %s, %s' % (self, other))
        result = self.__class__(op(x, y) for x, y in zip(self, other))
        self._set_name(result, other)
        return result

    def __neg__(self):
        result = self.__class__(-x for x in self)
        self._set_name(result)
        return result

    def __add__(self, other):
        return self._pairwise(other, operator.add)

    def __sub__(self, other):
        return self._pairwise(other, operator.sub)

    def __mul__(self, other):
        return self._pairwise(other, operator.mul)
253
def nameless(vector):
    """Return a copy of `vector` that has no ``name`` attribute.

    Vector arithmetic keeps the left operand's name unless it is missing
    or None.  Stripping the name from one operand therefore lets a
    sum / difference / etc. pick up the *other* vector's name.
    """
    stripped = Vector(tuple(vector))
    return stripped
261
class Boundary (ID_CmpMixin, list):
    """Ordered list of Vector points along one boundary segment.

    Points are stored relative to the first one, so ``self[0]`` must be
    ``(0, 0)``.  The points' bounding box is cached in ``x_min``,
    ``x_max``, ``y_min`` and ``y_max``.  ``regions`` starts empty and is
    filled by the Regions (and route set-up) that use this boundary.
    """
    def __init__(self, points):
        list.__init__(self)
        ID_CmpMixin.__init__(self)
        self.extend(Vector(p) for p in points)
        self.regions = []
        assert self[0] == (0,0), self
        xs = [p[0] for p in self]
        ys = [p[1] for p in self]
        self.x_min, self.x_max = min(xs), max(xs)
        self.y_min, self.y_max = min(ys), max(ys)
279
class Route (ID_CmpMixin):
    """Connect non-adjacent Regions.

    Stores `boundary` (presumably the Boundary tracing the route's path)
    as ``self.boundary``.

    NOTE(review): nothing in this file instantiates Route;
    TemplateLibrary._generate_regions represents routes directly as
    Boundary objects carrying a ``virtual`` flag.  Confirm whether other
    modules still use this class.
    """
    def __init__(self, boundary):
        ID_CmpMixin.__init__(self)
        self.boundary = boundary
286
class Region (NameMixin, ID_CmpMixin, list):
    """Contains a list of boundaries and a label.

    Regions can be Territories, sections of ocean, etc.

    Parameters
    ----------
    name : str
        Region name (stored via NameMixin).
    boundaries : list of Boundary
        Boundaries walked in order to form the outline.  This region is
        appended to each boundary's ``regions`` list.
    head_to_tail : list of bool
        For each boundary, True to walk it from its first point to its
        last, False to walk it backwards.
    routes, route_head_to_tail : list, optional
        Routes touching this region and, for each, whether the region
        sits at the route's first point (True) or last point (False).
        If `routes` is omitted, `route_head_to_tail` must be omitted too;
        both then start out empty.
    label_offset : pair of numbers, optional
        Stored as a Vector in ``self.label_offset``.

    >>> r = Region('Earth',
    ...            [Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)]),
    ...             Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)])],
    ...            [True, True, False, False],
    ...            label_offset=(0.5, 0.5))
    >>> r.outline
    [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    >>> r.label_offset
    (0.5, 0.5)
    """
    def __init__(self, name, boundaries, head_to_tail, routes=None,
                 route_head_to_tail=None, label_offset=(0,0)):
        NameMixin.__init__(self, name)
        list.__init__(self, boundaries)
        ID_CmpMixin.__init__(self)
        for boundary in self:
            boundary.regions.append(self)
        self.head_to_tail = head_to_tail
        if routes is None:
            assert route_head_to_tail is None, route_head_to_tail
            routes = []
            route_head_to_tail = []
        self.routes = routes
        self.route_head_to_tail = route_head_to_tail
        self.route_starts = []  # set by .locate_routes
        self.label_offset = Vector(label_offset)
        self.generate_outline()  # sets .outline, .starts
        # Bounding box in region coordinates: each boundary's own box
        # shifted by where that boundary starts.
        placed = list(zip(self, self.starts))
        self.x_min = min([b.x_min+s[0] for b, s in placed])
        self.x_max = max([b.x_max+s[0] for b, s in placed])
        self.y_min = min([b.y_min+s[1] for b, s in placed])
        self.y_max = max([b.y_max+s[1] for b, s in placed])

    def generate_outline(self):
        """Compute the region's closed outline and boundary offsets.

        Sets ``self.outline`` to the points visited when walking the
        boundaries in order (forwards or backwards per
        ``self.head_to_tail``), starting at (0, 0); the walk must end on
        a point equal to the first one.  Sets ``self.starts`` to the
        position of each boundary's first point (``boundary[0]``) in
        region coordinates.  Returns None.

        The main issue here is determining the proper border
        orientation for a CCW outline, which we do via a user-supplied
        list head_to_tail.
        """
        self.starts = []
        points = [Vector((0,0))]
        for boundary, htt in zip(self, self.head_to_tail):
            pos = points[-1]
            if htt:
                assert boundary[0] == (0,0), boundary
                new = boundary[1:]
            else:
                # Walking backwards: shift so boundary[-1] lands on pos.
                pos -= nameless(boundary[-1])
                new = reversed(boundary[:-1])
            pos = nameless(pos)
            self.starts.append(pos)
            for p in new:
                points.append(pos+p)
        assert points[-1] == points[0], '%s: %s' % (self, points)
        self.outline = points

    def locate_routes(self):
        """Compute ``self.route_starts`` for this region's routes.

        For each route, finds the outline point whose ``name`` matches
        the route's anchor (its first point when the matching
        ``route_head_to_tail`` entry is True, else its last point) and
        records where the route's first point sits in region coordinates.

        NOTE: a route whose anchor name is not found on the outline is
        skipped, leaving ``route_starts`` shorter than ``routes``.
        """
        self.route_starts = []
        for route, htt in zip(self.routes, self.route_head_to_tail):
            if htt:
                anchor = route[0]
            else:
                anchor = route[-1]
            for point in self.outline:
                if hasattr(point, 'name') and point.name == anchor.name:
                    self.route_starts.append(point-nameless(anchor))
                    break
357
class WorldRenderer (object):
    """Render a world's state as an SVG document.

    Parameters
    ----------
    template_lib : object, optional
        Anything with a ``get(name)`` method returning a Template or
        None.  Defaults to the module-level ``TEMPLATE_LIBRARY``.
    line_width : int
        Stroke width (viewBox units) for borders and routes.
    buf : int
        Padding added around the map's bounding box.
    dpcm : number
        viewBox units per centimetre of the SVG's physical size.
    army_scale : number
        Army circles get radius ``army_scale * sqrt(armies)``.
    """
    def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60,
                 army_scale=3):
        if template_lib is None:
            template_lib = TEMPLATE_LIBRARY
        self.template_lib = template_lib
        self.buf = buf
        self.line_width = line_width
        self.dpcm = dpcm
        self.army_scale = army_scale

    def render(self, world, players):
        """Return an SVG string for `world`, using its named template.

        Falls back to `_auto_template` (currently unimplemented) when the
        template library has no template for ``world.name``.
        """
        template = self.template_lib.get(world.name)
        if template is None:
            template = self._auto_template(world)
        return self.render_template(world, players, template)

    def render_template(self, world, players, template):
        """Return an SVG document (as a string) drawing `world` on `template`.

        Draws one polygon per template region, filled with its
        continent's color.  Each route is drawn once.  Each territory
        gets one circle whose radius grows with its army count and whose
        color is its owner's.  Line types whose color is None, or
        missing from ``template.line_colors``, are not drawn (borders
        are then left unstroked).  Text is XML-escaped.
        """
        region_pos, width, height = self._locate(template)
        lines = [
            '<?xml version="1.0" standalone="no"?>',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"',
            '  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">',
            '<svg width="%.1fcm" height="%.1fcm" viewBox="0 0 %d %d"'
            % (float(width)/self.dpcm, float(height)/self.dpcm,
               width, height),
            '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
            '<title>%s</title>' % self._xml(template),
            '<desc>A PyRisk world snapshot</desc>',
            ]
        line_colors = template.line_colors
        border = line_colors.get('border')
        if border is None:
            b_col_attr = ''
        else:
            b_col_attr = 'stroke="%s" stroke-width="%d"' \
                % (border, self.line_width)
        terr_regions = {}  # {territory name: [Region, ...], ...}
        drawn_rts = set()  # ids of routes already drawn (routes are shared)
        for r in template.regions:
            t = self._matching_territory(world, r)
            terr_regions.setdefault(t.name, []).append(r)
            c_col = template.continent_colors[t.continent.name]
            lines.extend([
                '<polygon title="%s"' % self._territory_title(t),
                '         fill="%s" %s' % (c_col, b_col_attr),
                # outline[:-1]: the closing point repeats the first one.
                '         points="%s" />'
                % self._svg_points(region_pos[id(r)], r.outline[:-1], height),
                ])
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) in drawn_rts:
                    continue
                drawn_rts.add(id(rt))
                if rt.virtual:
                    color = line_colors.get('virtual route')
                else:
                    color = line_colors.get('route')
                if color is None:
                    continue
                lines.extend([
                    # fill="none": SVG fills polylines black by default.
                    '<polyline stroke="%s" stroke-width="%d" fill="none"'
                    % (color, self.line_width),
                    '         points="%s" />'
                    % self._svg_points(region_pos[id(r)]+rt_start, rt, height),
                    ])
        for t in world.territories():
            center = self._territory_center(region_pos, terr_regions[t.name])
            radius = self.army_scale*math.sqrt(t.armies)
            color = template.player_colors[players.index(t.player)]
            if color is None:
                continue
            lines.extend([
                '<circle title="%s" fill="%s"'
                % (self._territory_title(t), color),
                '         cx="%d" cy="%d" r="%.1f" />'
                % (center[0], height - center[1], radius),
                ])
        lines.extend(['</svg>', ''])
        return '\n'.join(lines)

    def _svg_points(self, origin, points, height):
        """Format `points` (relative to `origin`) as an SVG points list.

        SVG's y axis increases downwards, so y is negated and then shifted
        by `height` to land back inside the viewBox.
        """
        return ' '.join(['%d,%d' % ((origin+p)*(1, -1) + (0, height))
                         for p in points])

    def _territory_title(self, territory):
        """Return the escaped ``territory / player / armies`` tooltip."""
        return self._xml('%s / %s / %s'
                         % (territory, territory.player, territory.armies))

    def _xml(self, obj):
        """Return ``'%s' % obj`` escaped for XML text and attribute values."""
        from xml.sax.saxutils import escape
        return escape('%s' % (obj,), {'"': '&quot;'})

    def _locate(self, template):
        """Assign absolute positions to every region of `template`.

        The first boundary of the first region is fixed at the origin.
        Other regions are placed by walking shared boundaries and routes
        outward from it.  All positions are then shifted so the
        `buf`-padded bounding box starts at (0, 0).

        Returns ``(region_pos, width, height)``, where ``region_pos``
        maps ``id(region)`` to that region's position.  Raises KeyError
        (with the region name) for a region that cannot be reached.
        """
        region_pos = {}  # {id: absolute position, ...}
        boundary_pos = {}  # {id: absolute position, ...}
        route_pos = {}  # {id: absolute position, ...}
        b1 = template.regions[0][0]
        boundary_pos[id(b1)] = Vector((0,0))  # fix the first boundary point
        stack = list(b1.regions)
        while stack:
            r = stack.pop()
            if id(r) in region_pos:
                continue  # skip duplicate entries
            r_start = None
            for b, rel_b_start in zip(r, r.starts):
                if id(b) in boundary_pos:
                    r_start = boundary_pos[id(b)] - rel_b_start
                    break  # found an anchor
            if r_start is None:
                for rt, rel_rt_start in zip(r.routes, r.route_starts):
                    if id(rt) in route_pos:
                        r_start = route_pos[id(rt)] - rel_rt_start
                        break  # found an anchor
            region_pos[id(r)] = r_start
            for b, rel_b_start in zip(r, r.starts):
                if id(b) not in boundary_pos:
                    boundary_pos[id(b)] = r_start + rel_b_start
                    stack.extend(b.regions)
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) not in route_pos:
                    route_pos[id(rt)] = r_start + rt_start
                    stack.extend(rt.regions)
        for r in template.regions:
            if id(r) not in region_pos:
                raise KeyError(r.name)
        x_min = min([r.x_min + region_pos[id(r)][0]
                     for r in template.regions]) - self.buf
        x_max = max([r.x_max + region_pos[id(r)][0]
                     for r in template.regions]) + self.buf
        y_min = min([r.y_min + region_pos[id(r)][1]
                     for r in template.regions]) - self.buf
        y_max = max([r.y_max + region_pos[id(r)][1]
                     for r in template.regions]) + self.buf
        shift = Vector((x_min, y_min))
        for key, value in region_pos.items():
            region_pos[key] = value - shift
        return (region_pos, x_max-x_min, y_max-y_min)

    def _matching_territory(self, world, region):
        """Return the world territory drawn by `region`.

        Tries a territory named like the region.  Failing that, it tries
        the regions connected to it by virtual routes; the last match
        wins.  Asserts that some territory was found.
        """
        t = None
        try:
            t = world.territory_by_name(region.name)
        except KeyError:
            for rt in region.routes:
                if not rt.virtual:
                    continue
                for r in rt.regions:
                    try:
                        t = world.territory_by_name(r.name)
                    except KeyError:
                        pass  # e.g. `region` itself; keep looking
        assert t is not None, 'No territory in %s associated with region %s' \
            % (world, region)
        return t

    def _territory_center(self, region_pos, regions):
        """Return the mean position of a territory's outline points.

        Note: not a true center of mass.  Each outline repeats its first
        point at the end; that closing duplicate is skipped so it does
        not bias the average.
        """
        points = []
        for r in regions:
            outline = r.outline[:-1] or r.outline
            origin = region_pos[id(r)]
            points.extend([p + origin for p in outline])
        count = float(len(points))
        return Vector((int(sum([p[0] for p in points])/count),
                       int(sum([p[1] for p in points])/count)))

    def _auto_template(self, world):
        raise NotImplementedError
521
def test():
    """Run the doctests in this module and return how many failed."""
    import doctest
    import sys
    module = sys.modules[__name__]
    results = doctest.testmod(module)
    return results[0]
526
def render_earth(filename=None):
    """Render a freshly set-up six-player Earth game as SVG.

    The SVG is printed to stdout, or written to `filename` when given.
    """
    from .base import generate_earth, Player, Engine
    players = [Player(name) for name in
               ('Alice', 'Bob', 'Charlie', 'Eve', 'Mallory', 'Zoe')]
    world = generate_earth()
    engine = Engine(world, players)
    engine.setup()
    svg = WorldRenderer().render(engine.world, players)
    if filename is None:
        # print(x) with one argument behaves the same on Python 2 and 3.
        print(svg)
    else:
        with open(filename, 'w') as f:
            f.write(svg)