Distinguish between line (boder) and route color
[pyrisk.git] / pyrisk / graphics.py
index 4a758cb89c131df2ee6fc229bef76c83d6331221..c407c29547598bd92ebed27124a8c80f0f83cc13 100644 (file)
@@ -24,28 +24,30 @@ import operator
 import os
 import os.path
 
-from .base import ID_CmpMixin
+from .base import NameMixin, ID_CmpMixin
 
 
-class Template (object):
+class Template (NameMixin):
     """Setup regions for a particular world.
     """
-    def __init__(self, name, regions):
-        self.name = name
+    def __init__(self, name, regions, continent_colors={}):
+        NameMixin.__init__(self, name)
         self.regions = regions
+        self.continent_colors = continent_colors
 
 class TemplateLibrary (object):
     """Create Templates on demand from a directory of template data.
 
     TODO: explain template data format.
     """
-    def __init__(self, template_dir='/usr/share/pyrisk/templates/'):
+    def __init__(self, template_dir='share/templates/'):
         self.template_dir = os.path.abspath(os.path.expanduser(template_dir))
     def get(self, name):
-        region_pointlists,route_pointlists = self._get_pointlists(name)
-        template = self._generate_template(region_pointlists, route_pointlists)
-        return template
-    def _get_pointlists(self, name):
+        region_pointlists,route_pointlists,continent_colors = \
+            self._get_data(name)
+        regions = self._generate_regions(region_pointlists, route_pointlists)
+        return Template('template', regions, continent_colors)
+    def _get_data(self, name):
         dirname = os.path.join(self.template_dir, name.lower())
         try:
             files = os.listdir(dirname)
@@ -57,11 +59,14 @@ class TemplateLibrary (object):
             path = os.path.join(dirname, filename)
             name,extension = filename.rsplit('.', 1)
             if extension == 'reg':
-                region_pointlists[name] = self._read_file(path)
-            elif extension == 'rt':
-                route_pointlists[name] = self._read_file(path)
-        return (region_pointlists, route_pointlists)
-    def _read_file(self, filename):
+                region_pointlists[name] = self._read_pointlist(path)
+            elif extension in ['rt', 'vrt']:
+                route_pointlists[name] = (self._read_pointlist(path),
+                                          extension == 'vrt')
+            elif extension == 'ccl':
+                continent_colors = self._read_colors(path)
+        return (region_pointlists, route_pointlists, continent_colors)
+    def _read_pointlist(self, filename):
         pointlist = []
         for line in open(filename, 'r'):
             line = line.strip()
@@ -77,7 +82,17 @@ class TemplateLibrary (object):
                 label = None
             pointlist.append((x,y,label))
         return pointlist
-    def _generate_template(self, region_pointlists, route_pointlists):
+    def _read_colors(self, filename):
+        colors = {}
+        for line in open(filename, 'r'):
+            line = line.strip()
+            if len(line) == 0:
+                continue
+            fields = line.split('\t')
+            name,color = fields
+            colors[name] = color
+        return colors
+    def _generate_regions(self, region_pointlists, route_pointlists):
         regions = []
         all_boundaries = []
         for name,pointlist in region_pointlists.items():
@@ -85,11 +100,13 @@ class TemplateLibrary (object):
                 all_boundaries, pointlist)
             regions.append(Region(name, boundaries, head_to_tail))
             r = regions[-1]
-        for name,pointlist in route_pointlists.items():
+        for name,v_pointlist in route_pointlists.items():
+            pointlist,virtual = v_pointlist
             boundaries,head_to_tail = self._pointlist_to_array_of_boundaries(
                 all_boundaries, pointlist)
             assert len(boundaries) == 1, boundaries
             route = boundaries[0]
+            route.virtual = virtual
             for terminal in [route[0], route[-1]]:
                 for r in regions:
                     for point in r.outline:
@@ -104,7 +121,7 @@ class TemplateLibrary (object):
         match_counts = [b.match_count for b in all_boundaries]
         assert min(match_counts) in [0, 1], set(match_counts)
         assert max(match_counts) == 1, set(match_counts)
-        return Template('template', regions)
+        return regions
     def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
         boundaries = []
         head_to_tail = []
@@ -242,7 +259,14 @@ class Boundary (ID_CmpMixin, list):
         self.y_min = min([p[1] for p in self])
         self.y_max = max([p[1] for p in self])
 
-class Region (ID_CmpMixin, list):
+class Route (ID_CmpMixin):
+    """Connect non-adjacent Regions.
+    """
+    def __init__(self, boundary):
+        ID_CmpMixin.__init__(self)
+        self.boundary = boundary
+
+class Region (NameMixin, ID_CmpMixin, list):
     """Contains a list of boundaries and a label.
 
     Regions can be Territories, sections of ocean, etc.
@@ -259,12 +283,12 @@ class Region (ID_CmpMixin, list):
     """
     def __init__(self, name, boundaries, head_to_tail, routes=None,
                  route_head_to_tail=None, label_offset=(0,0)):
+        NameMixin.__init__(self, name)
         list.__init__(self, boundaries)
         ID_CmpMixin.__init__(self)
         for boundary in self:
             boundary.regions.append(self)
         self.head_to_tail = head_to_tail
-        self.name = name
         self.routes = routes
         self.route_head_to_tail = route_head_to_tail
         if routes == None:
@@ -299,7 +323,7 @@ class Region (ID_CmpMixin, list):
             self.starts.append(pos)
             for p in new:
                 points.append(pos+p)
-        assert points[-1] == points[0], points
+        assert points[-1] == points[0], '%s: %s' % (self, points)
         self.outline = points
     def locate_routes(self):
         self.route_starts = []
@@ -313,14 +337,6 @@ class Region (ID_CmpMixin, list):
                     self.route_starts.append(point-nameless(anchor))
                     break
 
-class Route (ID_CmpMixin):
-    """Connect non-adjacent Regions.
-    """
-    def __init__(self, boundary):
-        ID_CmpMixin.__init__(self)
-        self.boundary = boundary
-
-
 class WorldRenderer (object):
     def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
         self.template_lib = template_lib
@@ -328,13 +344,15 @@ class WorldRenderer (object):
             self.template_lib = TEMPLATE_LIBRARY
         self.buf = buf
         self.line_width = line_width
+        self.line_color = 'grey'
+        self.route_color = 'black'
         self.dpcm = dpcm
     def render(self, world):
         template = self.template_lib.get(world.name)
         if template == None:
             template = self._auto_template(world)
-        return self.render_template(template)
-    def render_template(self, template):
+        return self.render_template(world, template)
+    def render_template(self, world, template):
         region_pos,width,height = self._locate(template)
         lines = [
             '<?xml version="1.0" standalone="no"?>',
@@ -344,14 +362,16 @@ class WorldRenderer (object):
             % (float(width)/self.dpcm, float(height)/self.dpcm,
                width, height),
             '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
-            '<desc>PyRisk world: %s</desc>' % template.name,
+            '<desc>PyRisk world: %s</desc>' % template,
             ]
         drawn_rts = {}
         for r in template.regions:
+            t = self._matching_territory(world, r)
+            c_col = template.continent_colors[t.continent.name]
             lines.extend([
-                    '<!-- %s -->' % r.name,
-                    '<polygon fill="red" stroke="blue" stroke-width="%d"'
-                    % self.line_width,
+                    '<!-- %s -->' % r,
+                    '<polygon fill="%s" stroke="%s" stroke-width="%d"'
+                    % (c_col, self.line_color, self.line_width),
                     '         points="%s" />'
                     % ' '.join(['%d,%d' % ((region_pos[id(r)]+p)
                                            *(1,-1) # svg y value increases down
@@ -359,12 +379,12 @@ class WorldRenderer (object):
                                 for p in r.outline[:-1]])
                     ])
             for rt,rt_start in zip(r.routes, r.route_starts):
-                if id(rt) in drawn_rts:
+                if id(rt) in drawn_rts or rt.virtual == True:
                     continue
                 drawn_rts[id(rt)] = rt
                 lines.extend([
-                        '<polyline stroke="black" stroke-width="%d"'
-                        % self.line_width,
+                        '<polyline stroke="%s" stroke-width="%d"'
+                        % (self.route_color, self.line_width),
                         '         points="%s" />'
                         % ' '.join(['%d,%d' % ((region_pos[id(r)]+rt_start+p)
                                            *(1,-1) # svg y value increases down
@@ -426,6 +446,22 @@ class WorldRenderer (object):
         for key,value in region_pos.items():
             region_pos[key] = value - Vector((x_min, y_min))
         return (region_pos, x_max-x_min, y_max-y_min)
+    def _matching_territory(self, world, region):
+        t = None
+        try:
+            t = world.territory_by_name(region.name)
+        except KeyError:
+            for rt in region.routes:
+                if not rt.virtual:
+                    continue
+                for r in rt.regions:
+                    try:
+                        t = world.territory_by_name(r.name)
+                    except KeyError:
+                        pass
+        assert t != None, 'No territory in %s associated with region %s' \
+            % (world, region)
+        return t
     def _auto_template(self, world):
         raise NotImplementedError