Added player colors
[pyrisk.git] / pyrisk / graphics.py
index f171e0d24b8078c04bd7cc1f99479e63e8fe6e4f..32cc46e867c90b89286d0b388158b509544e8de0 100644 (file)
@@ -20,19 +20,24 @@ Creates SVG files by hand.  See the TemplateLibrary class for a
 description of map specification format.
 """
 
+import math
 import operator
 import os
 import os.path
 
-from .base import ID_CmpMixin
+from .base import NameMixin, ID_CmpMixin
 
 
-class Template (object):
+class Template (NameMixin):
     """Setup regions for a particular world.
     """
-    def __init__(self, name, regions):
-        self.name = name
+    def __init__(self, name, regions, continent_colors={},
+                 line_colors={}, player_colors=[]):
+        NameMixin.__init__(self, name)
         self.regions = regions
+        self.continent_colors = continent_colors
+        self.line_colors = line_colors
+        self.player_colors = player_colors
 
 class TemplateLibrary (object):
     """Create Templates on demand from a directory of template data.
@@ -42,10 +47,13 @@ class TemplateLibrary (object):
     def __init__(self, template_dir='share/templates/'):
         self.template_dir = os.path.abspath(os.path.expanduser(template_dir))
     def get(self, name):
-        region_pointlists,route_pointlists = self._get_pointlists(name)
-        template = self._generate_template(region_pointlists, route_pointlists)
-        return template
-    def _get_pointlists(self, name):
+        region_pointlists,route_pointlists,continent_colors,line_colors, \
+            player_colors = \
+            self._get_data(name)
+        regions = self._generate_regions(region_pointlists, route_pointlists)
+        return Template(name, regions, continent_colors, line_colors, 
+                        player_colors)
+    def _get_data(self, name):
         dirname = os.path.join(self.template_dir, name.lower())
         try:
             files = os.listdir(dirname)
@@ -57,12 +65,24 @@ class TemplateLibrary (object):
             path = os.path.join(dirname, filename)
             name,extension = filename.rsplit('.', 1)
             if extension == 'reg':
-                region_pointlists[name] = self._read_file(path)
+                region_pointlists[name] = self._read_pointlist(path)
             elif extension in ['rt', 'vrt']:
-                route_pointlists[name] = (self._read_file(path),
+                route_pointlists[name] = (self._read_pointlist(path),
                                           extension == 'vrt')
-        return (region_pointlists, route_pointlists)
-    def _read_file(self, filename):
+            elif extension == 'col':
+                c = self._read_colors(path)
+                if name == 'continent':
+                    continent_colors = c
+                elif name == 'line':
+                    line_colors = c
+                else:
+                    assert name == 'player', name
+                    player_colors = []
+                    for k,v in sorted(c.items()):
+                        player_colors.append(v)
+        return (region_pointlists, route_pointlists,
+                continent_colors, line_colors, player_colors)
+    def _read_pointlist(self, filename):
         pointlist = []
         for line in open(filename, 'r'):
             line = line.strip()
@@ -78,7 +98,19 @@ class TemplateLibrary (object):
                 label = None
             pointlist.append((x,y,label))
         return pointlist
-    def _generate_template(self, region_pointlists, route_pointlists):
+    def _read_colors(self, filename):
+        colors = {}
+        for line in open(filename, 'r'):
+            line = line.strip()
+            if len(line) == 0:
+                continue
+            fields = line.split('\t')
+            name,color = [x.strip() for x in fields]
+            if color == '-':
+                color = None
+            colors[name] = color
+        return colors
+    def _generate_regions(self, region_pointlists, route_pointlists):
         regions = []
         all_boundaries = []
         for name,pointlist in region_pointlists.items():
@@ -107,7 +139,7 @@ class TemplateLibrary (object):
         match_counts = [b.match_count for b in all_boundaries]
         assert min(match_counts) in [0, 1], set(match_counts)
         assert max(match_counts) == 1, set(match_counts)
-        return Template('template', regions)
+        return regions
     def _pointlist_to_array_of_boundaries(self, all_boundaries, pointlist):
         boundaries = []
         head_to_tail = []
@@ -245,7 +277,14 @@ class Boundary (ID_CmpMixin, list):
         self.y_min = min([p[1] for p in self])
         self.y_max = max([p[1] for p in self])
 
-class Region (ID_CmpMixin, list):
+class Route (ID_CmpMixin):
+    """Connect non-adjacent Regions.
+    """
+    def __init__(self, boundary):
+        ID_CmpMixin.__init__(self)
+        self.boundary = boundary
+
+class Region (NameMixin, ID_CmpMixin, list):
     """Contains a list of boundaries and a label.
 
     Regions can be Territories, sections of ocean, etc.
@@ -262,12 +301,12 @@ class Region (ID_CmpMixin, list):
     """
     def __init__(self, name, boundaries, head_to_tail, routes=None,
                  route_head_to_tail=None, label_offset=(0,0)):
+        NameMixin.__init__(self, name)
         list.__init__(self, boundaries)
         ID_CmpMixin.__init__(self)
         for boundary in self:
             boundary.regions.append(self)
         self.head_to_tail = head_to_tail
-        self.name = name
         self.routes = routes
         self.route_head_to_tail = route_head_to_tail
         if routes == None:
@@ -302,7 +341,7 @@ class Region (ID_CmpMixin, list):
             self.starts.append(pos)
             for p in new:
                 points.append(pos+p)
-        assert points[-1] == points[0], '%s: %s' % (self.name, points)
+        assert points[-1] == points[0], '%s: %s' % (self, points)
         self.outline = points
     def locate_routes(self):
         self.route_starts = []
@@ -316,14 +355,6 @@ class Region (ID_CmpMixin, list):
                     self.route_starts.append(point-nameless(anchor))
                     break
 
-class Route (ID_CmpMixin):
-    """Connect non-adjacent Regions.
-    """
-    def __init__(self, boundary):
-        ID_CmpMixin.__init__(self)
-        self.boundary = boundary
-
-
 class WorldRenderer (object):
     def __init__(self, template_lib=None, line_width=2, buf=10, dpcm=60):
         self.template_lib = template_lib
@@ -331,14 +362,14 @@ class WorldRenderer (object):
             self.template_lib = TEMPLATE_LIBRARY
         self.buf = buf
         self.line_width = line_width
-        self.line_color = 'black'
         self.dpcm = dpcm
-    def render(self, world):
+        self.army_scale = 3
+    def render(self, world, players):
         template = self.template_lib.get(world.name)
         if template == None:
             template = self._auto_template(world)
-        return self.render_template(world, template)
-    def render_template(self, world, template):
+        return self.render_template(world, players, template)
+    def render_template(self, world, players, template):
         region_pos,width,height = self._locate(template)
         lines = [
             '<?xml version="1.0" standalone="no"?>',
@@ -348,14 +379,27 @@ class WorldRenderer (object):
             % (float(width)/self.dpcm, float(height)/self.dpcm,
                width, height),
             '     xmlns="http://www.w3.org/2000/svg" version="1.1">',
-            '<desc>PyRisk world: %s</desc>' % template.name,
-            ]
+            '<title>%s</title>' % template,
+            '<desc>A PyRisk world snapshot</desc>',
+           ]
+        terr_regions = {}
         drawn_rts = {}
         for r in template.regions:
+            t = self._matching_territory(world, r)
+            if t.name in terr_regions:
+                terr_regions[t.name].append(r)
+            else:
+                terr_regions[t.name] = [r]
+            c_col = template.continent_colors[t.continent.name]
+            if template.line_colors['border'] == None:
+                b_col_attr = ''
+            else:
+                b_col_attr = 'stroke="%s" stroke-width="%d"' \
+                    % (template.line_colors['border'], self.line_width)
             lines.extend([
-                    '<!-- %s -->' % r.name,
-                    '<polygon fill="red" stroke="%s" stroke-width="%d"'
-                    % (self.line_color, self.line_width),
+                    '<polygon title="%s / %s / %s"'
+                    % (t, t.player, t.armies),                    
+                    '         fill="%s" %s' % (c_col, b_col_attr),
                     '         points="%s" />'
                     % ' '.join(['%d,%d' % ((region_pos[id(r)]+p)
                                            *(1,-1) # svg y value increases down
@@ -363,23 +407,37 @@ class WorldRenderer (object):
                                 for p in r.outline[:-1]])
                     ])
             for rt,rt_start in zip(r.routes, r.route_starts):
-                if id(rt) in drawn_rts or rt.virtual == True:
+                if id(rt) in drawn_rts:
                     continue
                 drawn_rts[id(rt)] = rt
+                if rt.virtual == True:
+                    color = template.line_colors['virtual route']
+                else:
+                    color = template.line_colors['route']
+                if color == None:
+                    continue
                 lines.extend([
                         '<polyline stroke="%s" stroke-width="%d"'
-                        % (self.line_color, self.line_width),
+                        % (color, self.line_width),
                         '         points="%s" />'
                         % ' '.join(['%d,%d' % ((region_pos[id(r)]+rt_start+p)
                                            *(1,-1) # svg y value increases down
                                            +(0,height)) # shift back into bbox
                                     for p in rt])
                         ])
-        lines.extend([
-                '<circle fill="black" cx="0" cy="0" r="20" />',
-                '<circle fill="green" cx="%d" cy="%d" r="20" />'
-                 % (width, height)
-                ])
+        for t in world.territories():
+            regions = terr_regions[t.name]
+            center = self._territory_center(region_pos, regions)
+            radius = self.army_scale*math.sqrt(t.armies)
+            color = template.player_colors[players.index(t.player)]
+            if color == None:
+                continue
+            lines.extend([
+                    '<circle title="%s / %s / %s" fill="%s"'
+                    % (t, t.player, t.armies, color),
+                    '         cx="%d" cy="%d" r="%.1f" />'
+                    % (center[0], center[1]*-1+height, radius),
+                    ])
         lines.extend(['</svg>', ''])
         return '\n'.join(lines)
     def _locate(self, template):
@@ -430,6 +488,34 @@ class WorldRenderer (object):
         for key,value in region_pos.items():
             region_pos[key] = value - Vector((x_min, y_min))
         return (region_pos, x_max-x_min, y_max-y_min)
+    def _matching_territory(self, world, region):
+        t = None
+        try:
+            t = world.territory_by_name(region.name)
+        except KeyError:
+            for rt in region.routes:
+                if not rt.virtual:
+                    continue
+                for r in rt.regions:
+                    try:
+                        t = world.territory_by_name(r.name)
+                    except KeyError:
+                        pass
+        assert t != None, 'No territory in %s associated with region %s' \
+            % (world, region)
+        return t
+    def _territory_center(self, region_pos, regions):
+        """Return the center of mass of a territory composed of regions.
+
+        Note: currently not CM, just averages outline points.
+        """
+        points = []
+        for r in regions:
+            for p in r.outline:
+                points.append(p + region_pos[id(r)])
+        average = Vector((int(sum([p[0] for p in points])/len(points)),
+                          int(sum([p[1] for p in points])/len(points))))
+        return average
     def _auto_template(self, world):
         raise NotImplementedError
 
@@ -439,9 +525,14 @@ def test():
     return failures
 
 def render_earth():
-    from .base import generate_earth
+    from .base import generate_earth,Player,Engine
+    players = [Player('Alice'), Player('Bob'), Player('Charlie'),
+               Player('Eve'), Player('Mallory'), Player('Zoe')]
+    world = generate_earth()
+    e = Engine(world, players)
+    e.setup()
     r = WorldRenderer()
-    print r.render(generate_earth())
+    print r.render(e.world, players)
     #f = open('world.svg', 'w')
     #f.write(r.render(generate_earth()))
     #f.close()