Add preliminary graphics library with world generation
[pyrisk.git] / pyrisk / graphics.py
1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software; you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation; either version 2 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License along
14 # with this program; if not, write to the Free Software Foundation, Inc.,
15 # 51 Franklin Street, Fifth Floor, Boston, MA 02110-1301 USA.
16
17 """Various PyRisk class -> graphic renderers.
18
19 Creates SVG files by hand.  See the TemplateLibrary class for a
20 description of map specification format.
21 """
22
23 import operator
24 import os
25 import os.path
26
27 from .base import ID_CmpMixin
28
29
class Template (object):
    """A named set of Regions describing one world's map.

    This is a plain container: ``name`` is shown in the SVG description
    written by ``WorldRenderer.render_template`` and ``regions`` is the
    list of regions that gets laid out and drawn.
    """
    def __init__(self, name, regions):
        self.name, self.regions = name, regions
36
class TemplateLibrary (object):
    """Create Templates on demand from a directory of template data.

    Each world has its own subdirectory of ``template_dir``, named after
    the lower-cased world name.  Inside it:

    * every ``<region>.reg`` file describes the outline of one region,
    * every ``<route>.rt`` file describes one route between regions.

    Files with no extension, or any other extension, are ignored.

    Each non-blank line of a data file is a tab-separated point
    ``x<TAB>y[<TAB>label]`` with integer coordinates (y is negated on
    input).  A blank line ends one boundary and starts the next, so a
    region file is a sequence of boundaries listed in outline order;
    the chained boundaries must close back on the starting point.

    Each boundary is stored with its lexicographically smaller end
    point first (being walked backwards in regions that list it the
    other way round), and identical boundaries from different files
    are shared.  A boundary may appear in at most two files, and at
    least one boundary must be shared.  A route file must contain a
    single boundary; the route is attached to every region whose
    outline has a point with the same label as one of the route's end
    points.
    """
    def __init__(self, template_dir='/usr/share/pyrisk/templates/'):
        self.template_dir = os.path.abspath(os.path.expanduser(template_dir))

    def get(self, name):
        """Return the Template for world ``name``, or None if there is none.

        None is returned when ``template_dir`` has no readable
        subdirectory for ``name``, so callers can fall back to some
        other layout.
        """
        pointlists = self._get_pointlists(name)
        if pointlists is None:
            return None
        region_pointlists, route_pointlists = pointlists
        return self._generate_template(
            region_pointlists, route_pointlists, name=name)

    def _get_pointlists(self, name):
        """Read the raw point lists for world ``name``.

        Returns a ``(region_pointlists, route_pointlists)`` pair of
        ``{basename: pointlist}`` dicts, or None if the world's
        directory cannot be listed.
        """
        dirname = os.path.join(self.template_dir, name.lower())
        try:
            files = os.listdir(dirname)
        except (IOError, OSError):
            # os.listdir raises OSError (not IOError) on Python 2.
            return None
        region_pointlists = {}
        route_pointlists = {}
        for filename in files:
            if '.' not in filename:
                continue  # no extension, so not template data
            basename, extension = filename.rsplit('.', 1)
            path = os.path.join(dirname, filename)
            if extension == 'reg':
                region_pointlists[basename] = self._read_file(path)
            elif extension == 'rt':
                route_pointlists[basename] = self._read_file(path)
        return (region_pointlists, route_pointlists)

    def _read_file(self, filename):
        """Parse one template data file into a list of points.

        Each point is an ``(x, y, label)`` tuple, with ``label`` None
        when the line has no third field.  Each blank line becomes a
        None separator.
        """
        pointlist = []
        with open(filename, 'r') as f:
            for line in f:
                line = line.strip()
                if len(line) == 0:
                    pointlist.append(None)
                    continue
                fields = line.split('\t')
                x = int(fields[0])
                # Negated on input; WorldRenderer multiplies y by -1
                # again when emitting SVG coordinates.
                y = -int(fields[1])
                if len(fields) == 3:
                    label = fields[2].strip()
                else:
                    label = None
                pointlist.append((x, y, label))
        return pointlist

    def _generate_template(self, region_pointlists, route_pointlists,
                           name='template'):
        """Build a Template called ``name`` from parsed point lists.

        Regions are created first.  Each route is then attached to every
        region whose outline contains a point labelled like one of the
        route's end points.  Unlabelled ends are ignored, because they
        would otherwise match every unlabelled outline point.
        """
        regions = []
        all_boundaries = []
        for region_name, pointlist in region_pointlists.items():
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            regions.append(Region(region_name, boundaries, head_to_tail))
        for pointlist in route_pointlists.values():
            boundaries, head_to_tail = self._pointlist_to_array_of_boundaries(
                all_boundaries, pointlist)
            assert len(boundaries) == 1, boundaries
            route = boundaries[0]
            for terminal, htt in ((route[0], True), (route[-1], False)):
                if terminal.name is None:
                    continue  # unlabelled end: nothing to match against
                for r in regions:
                    for point in r.outline:
                        if getattr(point, 'name', None) == terminal.name:
                            r.routes.append(route)
                            r.route_head_to_tail.append(htt)
                            route.regions.append(r)
                            break  # attach each end at most once per region
        for r in regions:
            r.locate_routes()
        # match_count is 0 for a boundary seen once and grows with each
        # identical copy; every boundary may be shared by at most two
        # files, and at least one must be shared.
        match_counts = [b.match_count for b in all_boundaries]
        assert min(match_counts) in [0, 1], set(match_counts)
        assert max(match_counts) == 1, set(match_counts)
        return Template(name, regions)

    def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
        """Split ``pointlist`` at None separators into shared Boundaries.

        Returns ``(boundaries, head_to_tail)``, where ``head_to_tail[i]``
        is False when boundary ``i`` was stored in the reverse of the
        order given.  Empty runs (leading, trailing or repeated blank
        lines) are skipped.
        """
        boundaries = []
        head_to_tail = []
        b_points = []
        # The appended None flushes the final run of points.
        for point in list(pointlist) + [None]:
            if point is not None:
                b_points.append(point)
                continue
            if b_points:
                boundary, reverse = self._analyze(b_points)
                boundaries.append(
                    self._insert_boundary(all_boundaries, boundary))
                head_to_tail.append(not reverse)
                b_points = []
        return boundaries, head_to_tail

    def _analyze(self, boundary_points):
        """Convert raw points into a Boundary relative to its first point.

        The boundary is stored with its lexicographically smaller end
        point first.  Returns ``(boundary, reverse)``, where ``reverse``
        says whether the input order had to be flipped.
        """
        points = [self._vbp(b) for b in boundary_points]
        start, stop = points[0], points[-1]
        reverse = start > stop
        if reverse:
            points.reverse()
            start, stop = stop, start
        boundary = Boundary([p - start for p in points])
        for bp, p in zip(boundary, points):
            bp.name = p.name  # Boundary() copies the points, dropping names
        boundary.name = '(%s) -> (%s)' % (start.name, stop.name)
        boundary.real_pos = start
        return (boundary, reverse)

    def _vbp(self, boundary_point):
        """Convert an ``(x, y, label)`` tuple to a Vector named ``label``."""
        v = Vector((boundary_point[0], boundary_point[1]))
        v.name = boundary_point[2]
        return v

    def _insert_boundary(self, all_boundaries, new):
        """Return the shared instance of boundary ``new``.

        If an identical boundary (same length, position and points) is
        already registered, its ``match_count`` is incremented and it is
        returned.  Otherwise ``new`` is registered with ``match_count``
        0 and returned.
        """
        if new in all_boundaries:
            return new
        for b in all_boundaries:
            if (len(b) == len(new) and b.real_pos == new.real_pos
                    and all(bp == np for bp, np in zip(b, new))):
                b.match_count += 1
                return b
        all_boundaries.append(new)
        new.match_count = 0
        return new

# Shared default library; WorldRenderer uses it when no template_lib is given.
TEMPLATE_LIBRARY = TemplateLibrary()
165
166
class Vector (tuple):
    """Tuple with element-wise arithmetic that carries an optional name.

    ``+``, ``-`` and ``*`` work element by element and need operands of
    equal length; comparisons are ordinary lexicographic tuple
    comparisons.  Results are Vectors and inherit a ``name`` attribute
    from the operands (see ``_set_name``).

    >>> v = Vector
    >>> a = v((0, 0))
    >>> b = v((1, 1))
    >>> c = v((2, 3))
    >>> a+b
    (1, 1)
    >>> a+b+c
    (3, 4)
    >>> b-c
    (-1, -2)
    >>> -c
    (-2, -3)
    >>> a < b
    True
    >>> c > b
    True
    >>> c*(1, -1)
    (2, -3)
    """
    def _set_name(self, new, other=None):
        """Give ``new`` this vector's name, falling back to ``other``'s.

        ``other``'s name is used when this vector has no ``name``
        attribute, or when its name is None and ``other`` has one.
        """
        mine = hasattr(self, 'name')
        theirs = hasattr(other, 'name')
        if mine and not (self.name == None and theirs):
            new.name = self.name
        elif theirs:
            new.name = other.name

    def _elementwise(self, op, other):
        # Shared implementation of the binary operators below.
        if len(self) != len(other):
            raise ValueError('length missmatch %s, %s' % (self, other))
        result = self.__class__(op(x, y) for x, y in zip(self, other))
        self._set_name(result, other)
        return result

    def __neg__(self):
        result = self.__class__(-x for x in self)
        self._set_name(result)
        return result

    def __add__(self, other):
        return self._elementwise(operator.add, other)

    def __sub__(self, other):
        return self._elementwise(operator.sub, other)

    def __mul__(self, other):
        return self._elementwise(operator.mul, other)
216         self._set_name(new, other)
217         return new
218
def nameless(vector):
    """Return a copy of ``vector`` without a ``name`` attribute.

    Vector arithmetic keeps the left operand's name when it has one,
    so stripping it first (``nameless(a) + b``) lets the result take
    ``b``'s name instead.
    """
    plain = Vector(tuple(vector))
    return plain
226
class Boundary (ID_CmpMixin, list):
    """Ordered list of Vector points along one boundary.

    Points are stored relative to the first point, which must be (0,0).
    ``regions`` starts empty and is filled in by the code that attaches
    this boundary to regions.  ``x_min``/``x_max``/``y_min``/``y_max``
    give the bounding box in the same relative coordinates.
    """
    def __init__(self, points):
        list.__init__(self)
        ID_CmpMixin.__init__(self)
        self.extend(Vector(point) for point in points)
        self.regions = []
        assert self[0] == (0,0), self
        xs = [point[0] for point in self]
        ys = [point[1] for point in self]
        self.x_min, self.x_max = min(xs), max(xs)
        self.y_min, self.y_max = min(ys), max(ys)
242         self.y_min = min([p[1] for p in self])
243         self.y_max = max([p[1] for p in self])
244
class Region (ID_CmpMixin, list):
    """Contains a list of boundaries and a label.

    Regions can be Territories, sections of ocean, etc.

    Each boundary is walked forwards when the matching ``head_to_tail``
    entry is true and backwards otherwise.  ``routes`` lists the route
    boundaries attached to the region, and ``route_head_to_tail[i]`` is
    true when the first point of ``routes[i]`` is the one on this
    region's outline.  Both default to empty lists; when ``routes`` is
    omitted, ``route_head_to_tail`` must be omitted too.

    >>> r = Region('Earth',
    ...            [Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)]),
    ...             Boundary([(0,0), (0,1)]),
    ...             Boundary([(0,0), (1,0)])],
    ...            [True, True, False, False],
    ...            label_offset=(0.5, 0.5))
    >>> r.outline
    [(0, 0), (0, 1), (1, 1), (1, 0), (0, 0)]
    >>> (r.x_min, r.x_max, r.y_min, r.y_max)
    (0, 1, 0, 1)
    >>> r.routes
    []
    """
    def __init__(self, name, boundaries, head_to_tail, routes=None,
                 route_head_to_tail=None, label_offset=(0,0)):
        list.__init__(self, boundaries)
        ID_CmpMixin.__init__(self)
        for boundary in self:
            boundary.regions.append(self)
        self.head_to_tail = head_to_tail
        self.name = name
        if routes is None:
            assert route_head_to_tail is None
            routes = []
            route_head_to_tail = []
        self.routes = routes
        self.route_head_to_tail = route_head_to_tail
        self.route_starts = []  # set by .locate_routes
        self.label_offset = Vector(label_offset)
        self.generate_outline()  # sets .outline, .starts
        placed = list(zip(self, self.starts))
        self.x_min = min([b.x_min + s[0] for b, s in placed])
        self.x_max = max([b.x_max + s[0] for b, s in placed])
        self.y_min = min([b.y_min + s[1] for b, s in placed])
        self.y_max = max([b.y_max + s[1] for b, s in placed])

    def generate_outline(self):
        """Compute the closed outline of the region; returns None.

        The boundaries are chained end to end, starting from (0,0).
        Each boundary is traversed forwards when its ``head_to_tail``
        entry is true and backwards otherwise.  The orientation comes
        from that user-supplied list rather than being inferred.

        Sets ``self.outline``, the list of outline points (the last
        point must equal the first), and ``self.starts``, the position
        of each boundary's first point relative to the region.
        """
        self.starts = []
        points = [Vector((0,0))]
        for boundary, htt in zip(self, self.head_to_tail):
            pos = points[-1]
            if htt:
                assert boundary[0] == (0,0), boundary
                new = boundary[1:]
            else:
                # Walking backwards: the boundary's last point sits at pos.
                pos -= nameless(boundary[-1])
                new = reversed(boundary[:-1])
            pos = nameless(pos)
            self.starts.append(pos)
            for p in new:
                points.append(pos + p)
        assert points[-1] == points[0], points
        self.outline = points

    def locate_routes(self):
        """Set ``self.route_starts`` for the attached routes.

        Each route is anchored at its first point when its
        ``route_head_to_tail`` entry is true, and at its last point
        otherwise.  The anchor is found on the outline by name, and
        ``route_starts[i]`` becomes the region-relative position of
        ``routes[i][0]``.

        Raises ValueError if an anchor's name is not on the outline,
        because skipping it would misalign ``route_starts`` with
        ``routes``.
        """
        self.route_starts = []
        for route, htt in zip(self.routes, self.route_head_to_tail):
            anchor = route[0] if htt else route[-1]
            for point in self.outline:
                if hasattr(point, 'name') and point.name == anchor.name:
                    self.route_starts.append(point - nameless(anchor))
                    break
            else:
                raise ValueError(
                    'route anchor %r not found on the outline of %s'
                    % (anchor.name, self.name))
313                     self.route_starts.append(point-nameless(anchor))
314                     break
315
class Route (ID_CmpMixin):
    """Connect non-adjacent Regions.

    Wraps the ``boundary`` the connection follows.  NOTE(review): the
    template loader in this module represents routes directly as
    Boundary objects, and this class is not used elsewhere here.
    """
    def __init__(self, boundary):
        super(Route, self).__init__()
        self.boundary = boundary
322
323
class WorldRenderer (object):
    """Render worlds (via their Templates) as SVG documents.

    ``line_width`` is the stroke width in SVG user units, ``buf`` is the
    margin added around the map's bounding box, and ``dpcm`` is the
    number of user units per centimetre used for the document's
    physical size.  ``template_lib`` defaults to the module-wide
    TEMPLATE_LIBRARY.
    """
    def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
        if template_lib is None:
            template_lib = TEMPLATE_LIBRARY
        self.template_lib = template_lib
        self.buf = buf
        self.line_width = line_width
        self.dpcm = dpcm

    def render(self, world):
        """Return an SVG document (a string) for ``world``.

        The template is looked up by ``world.name``.  If the library has
        none, ``_auto_template`` is used instead.
        """
        template = self.template_lib.get(world.name)
        if template is None:
            template = self._auto_template(world)
        return self.render_template(template)

    def render_template(self, template):
        """Return an SVG document (a string) drawing ``template``.

        Each region becomes a filled polygon.  Each route becomes an
        unfilled polyline, drawn once even if several regions share it.
        """
        region_pos, width, height = self._locate(template)
        lines = [
            '<?xml version="1.0" standalone="no"?>',
            '<!DOCTYPE svg PUBLIC "-//W3C//DTD SVG 1.1//EN"',
            '  "http://www.w3.org/Graphics/SVG/1.1/DTD/svg11.dtd">',
            '<svg width="%.1fcm" height="%.1fcm" viewBox="0 0 %d %d"'
            % (float(width)/self.dpcm, float(height)/self.dpcm,
               width, height),
            '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
            '<desc>PyRisk world: %s</desc>' % template.name,
            ]
        drawn_rts = set()  # ids of routes already emitted
        for r in template.regions:
            r_pos = region_pos[id(r)]
            lines.extend([
                    '<!-- %s -->' % r.name,
                    '<polygon fill="red" stroke="blue" stroke-width="%d"'
                    % self.line_width,
                    '         points="%s" />'
                    % self._svg_points(r_pos, r.outline[:-1], height),
                    ])
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) in drawn_rts:
                    continue
                drawn_rts.add(id(rt))
                # fill="none": SVG fills open polylines black by default.
                lines.extend([
                        '<polyline fill="none" stroke="black" '
                        'stroke-width="%d"' % self.line_width,
                        '         points="%s" />'
                        % self._svg_points(r_pos + rt_start, rt, height),
                        ])
        # NOTE(review): these corner markers look like debugging aids for
        # checking the bounding box; confirm before removing them.
        lines.extend([
                '<circle fill="black" cx="0" cy="0" r="20" />',
                '<circle fill="green" cx="%d" cy="%d" r="20" />'
                % (width, height),
                ])
        lines.extend(['</svg>', ''])
        return '\n'.join(lines)

    def _svg_points(self, offset, points, height):
        """Format ``offset + p`` for each p as an SVG ``points`` value.

        The y axis is flipped (SVG y increases downwards) and shifted
        by ``height`` to land back inside the viewBox.
        """
        return ' '.join(['%d,%d' % ((offset + p) * (1, -1) + (0, height))
                         for p in points])

    def _locate(self, template):
        """Lay out every region; return ``(region_pos, width, height)``.

        ``region_pos`` maps ``id(region)`` to the region's position.
        The first boundary of the first region is fixed at (0,0), and
        positions spread through shared boundaries and routes.  Finally
        everything is shifted so that the bounding box, grown by
        ``buf`` on each side, starts at the origin.  Raises KeyError
        (with the region name) for any region that cannot be reached
        this way.
        """
        region_pos = {}  # {id: absolute position, ...}
        boundary_pos = {}  # {id: absolute position, ...}
        route_pos = {}  # {id: absolute position, ...}
        b1 = template.regions[0][0]
        boundary_pos[id(b1)] = Vector((0,0))  # fix the first boundary point
        stack = list(b1.regions)
        while stack:
            r = stack.pop()
            if id(r) in region_pos:
                continue  # skip duplicate entries
            r_start = None
            for b, rel_b_start in zip(r, r.starts):
                if id(b) in boundary_pos:
                    r_start = boundary_pos[id(b)] - rel_b_start
                    break  # found an anchor
            if r_start is None:
                for rt, rel_rt_start in zip(r.routes, r.route_starts):
                    if id(rt) in route_pos:
                        r_start = route_pos[id(rt)] - rel_rt_start
                        break  # found an anchor
            region_pos[id(r)] = r_start
            for b, rel_b_start in zip(r, r.starts):
                if id(b) not in boundary_pos:
                    boundary_pos[id(b)] = r_start + rel_b_start
                    stack.extend(b.regions)
            for rt, rt_start in zip(r.routes, r.route_starts):
                if id(rt) not in route_pos:
                    route_pos[id(rt)] = r_start + rt_start
                    stack.extend(rt.regions)
        regions = template.regions
        for r in regions:
            if id(r) not in region_pos:
                raise KeyError(r.name)
        x_min = min([r.x_min + region_pos[id(r)][0]
                     for r in regions]) - self.buf
        x_max = max([r.x_max + region_pos[id(r)][0]
                     for r in regions]) + self.buf
        y_min = min([r.y_min + region_pos[id(r)][1]
                     for r in regions]) - self.buf
        y_max = max([r.y_max + region_pos[id(r)][1]
                     for r in regions]) + self.buf
        origin = Vector((x_min, y_min))
        for key, value in list(region_pos.items()):
            region_pos[key] = value - origin
        return (region_pos, x_max-x_min, y_max-y_min)

    def _auto_template(self, world):
        """Generate a Template for worlds without template data (TODO)."""
        raise NotImplementedError
431
def test():
    """Run this module's doctests and return how many failed."""
    import doctest
    import sys
    module = sys.modules[__name__]
    # testmod returns (failed, attempted); only the failure count matters.
    return doctest.testmod(module)[0]
436
def render_earth(filename=None):
    """Render the Earth world as SVG.

    The SVG text is printed to stdout, or written to ``filename`` when
    one is given.
    """
    from .base import generate_earth
    svg = WorldRenderer().render(generate_earth())
    if filename is None:
        # print(...) with one argument behaves the same on Python 2 and 3.
        print(svg)
    else:
        with open(filename, 'w') as f:
            f.write(svg)