Fixed bug in active_player calculation + other cleanups.
[pyrisk.git] / risk / base.py
index b63e6c2540e14a0b8a3bb4b034adea2b0a73bfed..e1ea6224a8b23529cb2a8b57d18f6c4eb1c68d72 100644 (file)
@@ -11,11 +11,21 @@ VERSION='0.1'
 class PlayerError (Exception):
     pass
 
-class ID_CmpMixin (object):
+class NameMixin (object):
+    """Simple mixin for pretty-printing named objects.
+    """
+    def __init__(self, name):
+        self.name = name
     def __str__(self):
         return self.name
     def __repr__(self):
         return self.__str__()
+
+
+class ID_CmpMixin (object):
+    """Simple mixin to ensure the fancier comparisons are all based on
+    __cmp__().
+    """
     def __cmp__(self, other):
         return cmp(id(self), id(other))
     def __eq__(self, other):
@@ -23,90 +33,111 @@ class ID_CmpMixin (object):
     def __ne__(self, other):
         return self.__cmp__(other) != 0
 
-class Territory (list, ID_CmpMixin):
+class Territory (NameMixin, ID_CmpMixin, list):
+    """An occupiable territory.
+
+    Contains a list of neighboring territories.
+    """
     def __init__(self, name, short_name=None, type=-1,
                  link_names=[], continent=None, player=None):
-        list.__init__(self)
+        NameMixin.__init__(self, name)
         ID_CmpMixin.__init__(self)
-        self.name = name
+        list.__init__(self)
         self.short_name = short_name
         if short_name == None:
             self.short_name = name
-        self._card_type = type
-        self._link_names = list(link_names)
-        self.continent = continent
-        self.player = player
-        self.card = None
-        self.armies = 0
+        self._card_type = type     # for Deck construction
+        self._link_names = list(link_names) # used by World._resolve_link_names
+        self.continent = continent # used by World.production
+        self.player = player       # who owns this territory
+        self.armies = 0            # number of occupying armies
     def __str__(self):
         if self.short_name == self.name:
             return self.name
         return '%s (%s)' % (self.name, self.short_name)
-    def __repr__(self):
-        return self.__str__()
     def borders(self, other):
         for t in self:
             if id(t) == id(other):
                 return True
         return False
 
-class Continent (list, ID_CmpMixin):
+class Continent (NameMixin, ID_CmpMixin, list):
+    """A group of Territories.
+
+    Stores the army-production bonus if it's owned by a single player.
+    """
     def __init__(self, name, production, territories=[]):
-        list.__init__(self, territories)
+        NameMixin.__init__(self, name)
         ID_CmpMixin.__init__(self)
-        self.name = name
+        list.__init__(self, territories)
         self.production = production
     def append(self, territory):
+        """Add a new territory (setting the territory's .continent
+        attribute).
+        """
         list.append(self, territory)
         territory.continent = self
     def territory_by_name(self, name):
+        """Find a Territory instance by name (long or short, case
+        insensitive).
+        """
         for t in self:
             if name.lower() in [t.short_name.lower(), t.name.lower()]:
-                #assert self.contains_territory(t), t
                 return t
         raise KeyError(name)
-    def contains_territory(self, territory):
-        for t in self:
-            if t == territory:
-                return True
-        return False
     def single_player(self):
+        """Is the continent owned by a single player?
+        """
         p = self[0].player
         for territory in self:
             if territory.player != p:
                 return False
         return True
 
-class World (list, ID_CmpMixin):
+class World (NameMixin, ID_CmpMixin, list):
+    """Store the world map and current world state.
+
+    Holds list of Continents.  Also controls territory-based army
+    production (via production).
+    """
     def __init__(self, name, continents=[]):
-        list.__init__(self, continents)
+        NameMixin.__init__(self, name)
         ID_CmpMixin.__init__(self)
-        self.name = name
+        list.__init__(self, continents)
         self.initial_armies = { # num_players:num_armies
             2: 40, 3:35, 4:30, 5:25, 6:20
                 }
     def territories(self):
+        """Iterate through all the territories in the world.
+        """
         for continent in self:
             for territory in continent:
                 yield territory
     def territory_by_name(self, name):
+        """Find a Territory instance by name (long or short, case
+        insensitive).
+        """
         for continent in self:
             try:
                 return continent.territory_by_name(name)
             except KeyError:
                 pass
         raise KeyError(name)
-    def contains_territory(self, territory):
-        for continent in self:
-            if continent.contains_territory(territory):
-                return True
-        return False
     def continent_by_name(self, name):
+        """Find a Continent instance by name (case insensitive).
+        """
         for continent in self:
             if continent.name.lower() == name.lower():
                 return continent
         raise KeyError(name)
     def _resolve_link_names(self):
+        """Initialize Territory links.
+
+        The Territory class doesn't actually link to neighbors after
+        initialization, but one of each linked pair has the others
+        name in _link_names.  This method goes through the territories,
+        looks up the referenced link target, and joins the pair.
+        """
         self._check_short_names()
         for territory in self.territories():
             for name in territory._link_names:
@@ -116,14 +147,19 @@ class World (list, ID_CmpMixin):
                 if not other.borders(territory):
                     other.append(territory)
     def _check_short_names(self):
+        """Ensure there are no short_name collisions.
+        """
         ts = {}
         for t in self.territories():
-            if t.short_name not in ts:
-                ts[t.short_name] = t
+            if t.short_name.lower() not in ts:
+                ts[t.short_name.lower()] = t
             else:
                 raise ValueError('%s shared by %s and %s'
-                                 % (t.short_name, ts[t.short_name], t))
+                    % (t.short_name.lower(), ts[t.short_name.lower()], t))
     def production(self, player):
+        """Calculate the number of armies a player should earn based
+        on territory occupation.
+        """
         ts = list(player.territories(self))
         production = max(3, len(ts) / 3)
         continents = set([t.continent.name for t in ts])
@@ -133,11 +169,18 @@ class World (list, ID_CmpMixin):
                 production += c.production
         return (production, {})
     def place_territory_production(self, territory_production):
+        """Place armies based on {territory_name: num_armies, ...}.
+        """
         for territory_name,production in territory_production.items():
             t = self.territory_by_name(territory_name)
             t.armies += production
 
 class Card (ID_CmpMixin):
+    """Represent a territory card (or wild)
+
+    Nothing exciting going on here, just a class for pretty-printing
+    card names.
+    """
     def __init__(self, deck, type_, territory=None):
         ID_CmpMixin.__init__(self)
         self.deck = deck
@@ -153,14 +196,23 @@ class Card (ID_CmpMixin):
         return self.__str__()
 
 class Deck (list):
-    def __init__(self, territories=[]):
+    """All the cards yet to be handed out in a given game.
+
+    Controls the type branding (via type_names) and army production
+    values for scoring sets (via production_value).
+    """
+    def __init__(self, territories=[], num_wilds=2,
+                 type_names=['Wild', 'Infantry', 'Cavalry', 'Artillery']):
         list.__init__(self, [Card(self, t._card_type, t) for t in territories])
-        random.shuffle(self)
-        self.type_names = ['Wild', 'Infantry', 'Cavalry', 'Artillery', 'Wild']
-        for i in range(2):
+        self.type_names = type_names
+        for i in range(num_wilds):
             self.append(Card(self, 0))
         self._production_sequence = [4, 6, 8, 10, 12, 15]
         self._production_index = 0
+    def shuffle(self):
+        """Shuffle the remaining cards in the deck.
+        """
+        random.shuffle(self)
     def production_value(self, index):
         """
         >>> d = Deck()
@@ -182,8 +234,7 @@ class Deck (list):
         ...                  Card(d, 1, Territory('b'))])
         Traceback (most recent call last):
           ...
-        PlayerError: You must play cards in groups of 3, not 2
-        ([<Card a Infantry>, <Card b Infantry>])
+        PlayerError: [<Card a Infantry>, <Card b Infantry>] is not a scoring set
         >>> d.production(a, [Card(d, 1, Territory('a', player=a)),
         ...                  Card(d, 1, Territory('b', player=b)),
         ...                  Card(d, 1, Territory('c'))])
@@ -198,11 +249,8 @@ class Deck (list):
         """
         if cards == None:
             return (0, {})
-        if len(cards) != 3:
-            raise PlayerError('You must play cards in groups of 3, not %d\n(%s)'
-                              % (len(cards), cards))
         h = Hand(cards)
-        if h.set() or h.run():
+        if h.scores():
             p = self.production_value(self._production_index)
             self._production_index += 1
             territory_production = {}
@@ -210,42 +258,137 @@ class Deck (list):
                 if c.territory != None and c.territory.player == player:
                     territory_production[c.territory.name] = 1
             return (p, territory_production)
-        raise PlayerError('%s is neither a set nor a run' % cards)
+        raise PlayerError('%s is not a scoring set' % h)
 
 class Hand (list):
+    """Represent a hand of cards.
+
+    This is the place to override the set of allowed scoring
+    combinations.  You should override one of
+
+    * set
+    * run
+    * scores
+
+    Adding additional scoring methods as needed (e.g. flush).
+    """
     def __init__(self, cards=[]):
         list.__init__(self, cards)
     def set(self):
+        if len(self) != 3:
+            return False
         s = sorted(set([card.type for card in self]))
         if len(s) == 1 \
                 or (len(s) == 2 and s[0] == 0):
             return True
         return False
     def run(self):
+        if len(self) != 3:
+            return False
         if len(set([card.type for card in self])) == 3:
             return True
         return False
+    def scores(self):
+        """The hand is any valid scoring combination.
+        """
+        return self.set() or self.run()
+    def subhands(self, lengths=None):
+        """Return all possible subhands.
+
+        Lengths can either be a list of allowed subhand lengths or
+        None.  If None, all possible subhand lengths are allowed.
+
+        >>> d = Deck()
+        >>> h = Hand([Card(d, 1, Territory('a')),
+        ...           Card(d, 1, Territory('b')),
+        ...           Card(d, 1, Territory('c')),
+        ...           Card(d, 1, Territory('d'))])
+        >>> for hand in h.subhands():
+        ...     print hand
+        [<Card a Infantry>]
+        [<Card b Infantry>]
+        [<Card c Infantry>]
+        [<Card d Infantry>]
+        [<Card a Infantry>, <Card b Infantry>]
+        [<Card a Infantry>, <Card c Infantry>]
+        [<Card a Infantry>, <Card d Infantry>]
+        [<Card b Infantry>, <Card c Infantry>]
+        [<Card b Infantry>, <Card d Infantry>]
+        [<Card c Infantry>, <Card d Infantry>]
+        [<Card a Infantry>, <Card b Infantry>, <Card c Infantry>]
+        [<Card a Infantry>, <Card b Infantry>, <Card d Infantry>]
+        [<Card a Infantry>, <Card c Infantry>, <Card d Infantry>]
+        [<Card b Infantry>, <Card c Infantry>, <Card d Infantry>]
+        [<Card a Infantry>, <Card b Infantry>, <Card c Infantry>, <Card d Infantry>]
+        """
+        for i in range(len(self)):
+            i += 1 # check all sub-hands of length i
+            if lengths != None and i not in lengths:
+                continue # don't check this length
+            indices = range(i)
+            stop = range(len(self)-i, len(self))
+            while indices != stop:
+                yield Hand([self[i] for i in indices])
+                indices = self._increment(indices, stop)
+            yield Hand([self[i] for i in indices])
+    def _increment(self, indices, stop):
+        """
+        >>> d = Deck()
+        >>> h = Hand([Card(d, 1, Territory('a'))])
+        >>> h._increment([0, 1, 2], [2, 3, 4])
+        [0, 1, 3]
+        >>> h._increment([0, 1, 3], [2, 3, 4])
+        [0, 1, 4]
+        >>> h._increment([0, 1, 4], [2, 3, 4])
+        [0, 2, 3]
+        """
+        moveable = [i for i,m in zip(indices, stop) if i < m]
+        assert len(moveable) > 0, 'At stop? indices: %s, stop: %s' % (indices, stop)
+        key = indices.index(moveable[-1])
+        new = indices[key] + 1
+        for i in range(key, len(indices)):
+            indices[i] = new + i-key
+        return indices
     def possible(self):
-        if len(self) >= 3:
-            for i,c1 in enumerate(self[:-2]):
-                for j,c2 in enumerate(self[i+1:-1]):
-                    for c3 in self[i+j+2:]:
-                        h = Hand([c1, c2, c3])
-                        if h.set() or h.run():
-                            yield h
-
-class Player (ID_CmpMixin):
+        """Return a list of all possible scoring subhands.
+        """
+        for h in self.subhands():
+            if h.scores():
+                yield h
+
+class Player (NameMixin, ID_CmpMixin):
+    """Represent a risk player.
+
+    This class implements a very basic AI player.  Subclasses should
+    consider overriding the "action-required" methods:
+
+    * select_territory
+    * play_cards
+    * place_armies
+    * attack_and_fortify
+    * support_attack
+
+    And the "report" methods:
+    
+    * report
+    * draw
+    """
     def __init__(self, name):
-        self.name = name
+        NameMixin.__init__(self, name)
         ID_CmpMixin.__init__(self)
         self.alive = True
         self.hand = Hand()
         self._message_index = 0
     def territories(self, world):
+        """Iterate through all territories owned by this player.
+        """
         for t in world.territories():
             if t.player == self:
                 yield t
     def border_territories(self, world):
+        """Iterate through all territories owned by this player which
+        border another player's territories.
+        """
         for t in self.territories(world):
             for neighbor in t:
                 if neighbor.player != self:
@@ -256,32 +399,50 @@ class Player (ID_CmpMixin):
 
         These events mark the end of contact and require no change in
         player status or response, so they get a special command
-        seperate from the usual phase_* family.  The phase_* commands
-        in Player subclasses can notify the player (possibly by
-        calling report internally) if they feel so inclined.
+        seperate from the usual action family.  The action commands in
+        Player subclasses can notify the player (possibly by calling
+        report internally) if they feel so inclined.
+        
+        See also
+        --------
+        draw - another notification-only method
         """
         print 'Reporting for %s:\n  %s' \
             % (self, '\n  '.join(log[self._message_index:]))
         self._message_index = len(log)
-    def phase_select_territory(self, world, log):
+    def draw(self, world, log, cards=[]):
+        """Only called if you earned a new card (or cards).
+
+        See also
+        --------
+        report - another notification-only method
+        """
+        pass
+    def select_territory(self, world, log):
         """Return the selected territory
         """
         free_territories = [t for t in world.territories() if t.player == None]
         return random.sample(free_territories, 1)[0]
-    def phase_play_cards(self, world, log, play_required=True):
+    def play_cards(self, world, log, play_required=True):
+        """Decide whether or not to turn in a set of cards.
+
+        Return a list of cards to turn in or None.  If play_required
+        is True, you *must* play.
+        """
         if play_required == True:
             return random.sample(list(self.hand.possible()), 1)[0]
-    def phase_place_armies(self, world, log, remaining=1, this_round=1):
+    def place_armies(self, world, log, remaining=1, this_round=1):
         """Both during setup and before each turn.
 
         Return {territory_name: num_armies, ...}
         """
         t = random.sample(list(self.border_territories(world)), 1)[0]
         return {t.name: this_round}
-    def phase_attack(self, world, log):
+    def attack_and_fortify(self, world, log, mode='attack'):
         """Return list of (source, target, armies) tuples.  Place None
         in the list to end this phase.
         """
+        assert mode != 'fortify', mode
         possible_attacks = []
         for t in self.border_territories(world):
             if t.armies <= 3: #1: # be more conservative, only attack with 3 dice
@@ -290,20 +451,24 @@ class Player (ID_CmpMixin):
             for tg in targets:
                 possible_attacks.append((t.name, tg.name, min(3, t.armies-1)))
         if len(possible_attacks) == 0:
-            return [None]
+            return [None, None] # stop attack phase, then stop fortification phase
         return random.sample(possible_attacks, 1) # + [None]
-    def phase_support_attack(self, world, log, source, target):
-        return source.armies-1
-    def phase_fortify(self, world, log):
-        """Return list of (source, target, armies) tuples.  Place None
-        in the list to end this phase.
+    def support_attack(self, world, log, source, target):
+        """Follow up on a conquest by moving additional armies.
         """
-        return [None]
-    def phase_draw(self, world, log, cards=[]):
-        """Only called if you earned a new card (or cards)"""
-        self.hand.extend(cards)
+        return source.armies-1
 
 class Engine (ID_CmpMixin):
+    """Drive the game.
+
+    Basic usage will be along the lines of
+
+    >>> world = generate_earth()
+    >>> players = [Player('Alice'), Player('Bob'), Player('Charlie')]
+    >>> e = Engine(world, players)
+    >>> e.run() # doctest: +ELLIPSIS
+    ...
+    """
     def __init__(self, world, players, deck_class=Deck, logger_class=Logger):
         ID_CmpMixin.__init__(self)
         self.world = world
@@ -315,41 +480,61 @@ class Engine (ID_CmpMixin):
     def __repr__(self):
         return self.__str__()
     def run(self):
+        """The main entry point.
+        """
         self.setup()
         self.play()
-        self.log('Game over.')
-        for p in self.players:
-            p.report(self.world, self.log)
+        self.game_over()
     def setup(self):
+        """Setup phase.  Pick territories, place initial armies, and
+        deal initial hands.
+        """
         for p in self.players:
             p.alive = True
         random.shuffle(self.players)
+        deck.shuffle()
         self.select_territories()
         self.place_initial_armies()
-        self.deal()
+        for p in self.players:
+            self.deal(p, 3)
     def play(self):
+        """Main gameplay phase.  Take turns until only one Player survives.
+        """
         turn = 0
         active_player = 0
         living = len(self.living_players())
         while living > 1:
             self.play_turn(self.players[active_player])
             living = len(self.living_players())
-            active_player = (active_player + 1) % living
+            active_player = (active_player + 1) % len(self.players)
+            if living > 1:
+                while self.players[active_player].alive == False:
+                    active_player = (active_player + 1) % len(self.players)
             turn += 1
+    def game_over(self):
+        """The end of the game.
+
+        Currently just a notification hook.
+        """
+        self.log('Game over.')
+        for p in self.players:
+            p.report(self.world, self.log)
     def play_turn(self, player):
+        """Work through the phases of player's turn.
+        """
         self.log("%s's turn (territory score: %s)"
                  % (player, [(p,len(list(p.territories(self.world))))
                              for p in self.players]))
         self.play_cards_and_place_armies(player)
-        captures = self.attack_phase(player)
+        captures = self.attack_and_fortify(player)
         if captures > 0 and len(self.deck) > 0 and len(self.living_players()) > 1:
-            player.phase_draw(self.world, self.log, [self.deck.pop()])
+            self.deal(player, 1)
     def select_territories(self):
         for t in self.world.territories():
             t.player = None
         for i in range(len(list(self.world.territories()))):
             p = self.players[i % len(self.players)]
-            t = p.phase_select_territory(self.world, self.log)
+            t = p.select_territory(self.world, self.log)
             if t.player != None:
                 raise PlayerError('Cannot select %s owned by %s'
                                   % (t, t.player))
@@ -371,7 +556,7 @@ class Engine (ID_CmpMixin):
                 self.player_place_armies(p, remaining, 1)
             remaining -= 1
     def player_place_armies(self, player, remaining=1, this_round=1):
-        placements = player.phase_place_armies(self.world, self.log, remaining, this_round)
+        placements = player.place_armies(self.world, self.log, remaining, this_round)
         if sum(placements.values()) != this_round:
             raise PlayerError('Placing more than %d armies' % this_round)
         for ter_name,armies in placements.items():
@@ -383,19 +568,19 @@ class Engine (ID_CmpMixin):
                 raise PlayerError('Placing a negative number of armies (%d) in %s'
                                   % (armies, t))
         self.log('%s places %s' % (player, placements))
-        for ter_name,armies in placements.items():
-            t = self.world.territory_by_name(ter_name)
+        for terr_name,armies in placements.items():
+            t = self.world.territory_by_name(terr_name)
             t.armies += armies
-    def deal(self):
-        for p in self.players:
-            cards = []
-            for i in range(3):
-                cards.append(self.deck.pop())
-            p.phase_draw(self.world, self.log, cards)
-        self.log('Initial hands dealt')
+    def deal(self, player, number):
+        cards = []
+        for i in range(number):
+            cards.append(self.deck.pop())
+        player.hand.extend(cards)
+        player.draw(self.world, self.log, cards)
+        self.log('%s dealt %d cards' % (player, number))
     def play_cards_and_place_armies(self, player, additional_armies=0):
         cards_required = len(player.hand) >= 5
-        cards = player.phase_play_cards(
+        cards = player.play_cards(
             self.world, self.log, play_required=cards_required)
         if cards_required == True and cards == None:
             raise PlayerError('You have %d >= 5 cards in your hand, you must play'
@@ -419,22 +604,32 @@ class Engine (ID_CmpMixin):
             self.log('%s was required to place %s' % (player, w_terr_prod))
         armies = w_prod + c_prod
         self.player_place_armies(player, armies, armies)
-    def attack_phase(self, player):
+    def attack_and_fortify(self, player):
         captures = 0
+        mode = 'attack'
         while True:
-            attacks = player.phase_attack(self.world, self.log)
-            for attack in attacks:
-                if attack == None:
-                    return captures
-                source_name,target_name,armies = attack
+            actions = player.attack_and_fortify(self.world, self.log, mode)
+            for action in actions:
+                if action == None:
+                    if mode == 'attack':
+                        mode = 'fortify'
+                        continue
+                    else:
+                        assert mode == 'fortify', mode
+                        return captures
+                source_name,target_name,armies = action
                 source = self.world.territory_by_name(source_name)
                 target = self.world.territory_by_name(target_name)
-                tplayer = target.player
-                capture = self.attack(source, target, armies)
-                if capture == True:
-                    captures += 1
-                    if len(list(tplayer.territories(self.world))) == 0:
-                        self.player_killed(tplayer, killer=player)
+                if mode == 'attack':
+                    tplayer = target.player
+                    capture = self.attack(source, target, armies)
+                    if capture == True:
+                        captures += 1
+                        if len(list(tplayer.territories(self.world))) == 0:
+                            self.player_killed(tplayer, killer=player)
+                else:
+                    assert mode == 'fortify', mode
+                    self.fortify(source, target, armies)
     def attack(self, source, target, armies):
         if source.player == target.player:
             raise PlayerError('%s attacking %s, but you own both.'
@@ -474,7 +669,7 @@ class Engine (ID_CmpMixin):
         source.armies -= remaining_attackers
         target.armies += remaining_attackers
         target.player = source.player
-        support = source.player.phase_support_attack(self.world, self.log, source, target)
+        support = source.player.support_attack(self.world, self.log, source, target)
         if support < 0 or support >= source.armies:
             raise PlayerError('Cannot support from %s to %s with %d armies, only %d available'
                               % (source, target, support, source.armies-1))
@@ -561,8 +756,8 @@ def generate_earth():
     return w
 
 def test():
-    import doctest
-    failures,tests = doctest.testmod()
+    import doctest, sys
+    failures,tests = doctest.testmod(sys.modules[__name__])
     return failures
 
 def random_game():