Add pysawsim.manager.mpi with an mpi4py-based manager.
author    W. Trevor King <wking@drexel.edu>
          Wed, 20 Oct 2010 17:34:48 +0000 (13:34 -0400)
committer W. Trevor King <wking@drexel.edu>
          Wed, 20 Oct 2010 17:34:48 +0000 (13:34 -0400)
Test with:
  mpdboot -1 -n 1 -f <(hostname)
  mpiexec -n 5 nosetests --with-doctest --doctest-tests pysawsim/manager/mpi.py
  mpdallexit

I still need to find a way to skip the doctests when mpi4py is
installed but the tests are not being run from an `mpiexec`ed
environment.

pysawsim/manager/mpi.py [new file with mode: 0644]

diff --git a/pysawsim/manager/mpi.py b/pysawsim/manager/mpi.py
new file mode 100644 (file)
index 0000000..d99f415
--- /dev/null
@@ -0,0 +1,233 @@
+# Copyright (C) 2010  W. Trevor King <wking@drexel.edu>
+#
+# This program is free software: you can redistribute it and/or modify
+# it under the terms of the GNU General Public License as published by
+# the Free Software Foundation, either version 3 of the License, or
+# (at your option) any later version.
+#
+# This program is distributed in the hope that it will be useful,
+# but WITHOUT ANY WARRANTY; without even the implied warranty of
+# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
+# GNU General Public License for more details.
+#
+# You should have received a copy of the GNU General Public License
+# along with this program.  If not, see <http://www.gnu.org/licenses/>.
+#
+# The author may be contacted at <wking@drexel.edu> on the Internet, or
+# write to Trevor King, Drexel University, Physics Dept., 3141 Chestnut St.,
+# Philadelphia PA 19104, USA.
+
+"""Functions for running external commands on other hosts.
+
+mpi4py_ is a Python wrapper around MPI.
+
+.. _mpi4py: http://mpi4py.scipy.org/
+
+The MPIManager data flow is a bit complicated, so I've made a
+diagram::
+
+    [original intracommunicator]    [intercom]   [spawned intracommunicator]
+                                        |
+                ,->--(spawn-thread)-->--+-->------------->-(worker-0)-.
+    (main-thread)                       |            `--->-(worker-1)-|
+                `-<-(receive-thread)-<--+--<-----.     `->-(worker-2)-|
+                                        |         `-------<-----------'
+
+The connections are:
+
+==============  ==============  ===============================
+Source          Target          Connection
+==============  ==============  ===============================
+main-thread     spawn-thread    spawn_queue
+spawn-thread    worker-*        SpawnThread.comm(SPAWN_TAG)
+worker-*        receive-thread  ReceiveThread.comm(RECEIVE_TAG)
+receive-thread  main-thread     receive_queue
+==============  ==============  ===============================
+
+There is also a `free_queue` running from `receive-thread` to
+`spawn-thread` to mark job completion so `spawn-thread` knows which
+nodes are free (and therefore ready to receive new jobs).
+"""
+
+from Queue import Queue
+import sys
+from threading import Thread
+
+try:
+    from mpi4py import MPI
+    if MPI.COMM_WORLD.Get_rank() == 0:
+        _SKIP = ''
+    else:
+        _SKIP = '  # doctest: +SKIP'
+except ImportError, MPI_error:
+    MPI = None
+    _SKIP = '  # doctest: +SKIP'
+
+from .. import log
+from . import Job
+from .thread import CLOSE_MESSAGE, ThreadManager
+
+
+CLOSE_MESSAGE = "close"
+SPAWN_TAG = 100
+RECEIVE_TAG = 101
+
+
+def MPI_worker_death():
+    if MPI is None:
+        return
+    if MPI.COMM_WORLD.Get_rank() != 0:
+        sys.exit(0)
+
+def _manager_check():
+    assert MPI is not None, MPI_error
+    rank = MPI.COMM_WORLD.Get_rank()
+    assert rank == 0, (
+        'process %d should have been killed by an MPI_worker_death() call'
+        % rank)
+
+
+class WorkerProcess (object):
+    def __init__(self):
+        self.comm = MPI.Comm.Get_parent()  # intercommunicator
+        self.rank = self.comm.Get_rank()   # rank within the local (spawned) group
+        self.manager = 0
+        self.name = 'worker-%d' % self.rank
+        log().debug('%s started' % self.name)
+
+    def teardown(self):
+        if self.rank == 0:
+            # only one worker needs to disconnect from the intercommunicator.
+            self.comm.Disconnect()
+
+    def run(self):
+        s = MPI.Status()
+        while True:
+            msg = self.comm.recv(source=self.manager, tag=SPAWN_TAG, status=s)
+            if msg == CLOSE_MESSAGE:
+                log().debug('%s closing' % self.name)
+                break
+            assert isinstance(msg, Job), msg
+            log().debug('%s running job %s' % (self.name, msg))
+            msg.run()
+            self.comm.send(msg, dest=self.manager, tag=RECEIVE_TAG)
+        if self.rank == 0:
+            # forward close message to receive-thread
+            self.comm.send(CLOSE_MESSAGE, dest=self.manager, tag=RECEIVE_TAG)
+
+
+class ManagerThread (Thread):
+    def __init__(self, job_queue, free_queue, comm, rank, size,
+                 *args, **kwargs):
+        super(ManagerThread, self).__init__(*args, **kwargs)
+        self.job_queue = job_queue
+        self.free_queue = free_queue
+        self.comm = comm
+        self.rank = rank
+        self.size = size
+        self.name = self.getName()  # work around Pythons < 2.6
+        log().debug('%s starting' % self.name)
+
+
+class SpawnThread (ManagerThread):
+    def teardown(self):
+        for i in range(self.size):
+            if i != 0:
+                self.comm.send(CLOSE_MESSAGE, dest=i, tag=SPAWN_TAG)
+        free = []
+        while len(free) < self.size:
+            free.append(self.free_queue.get())
+        # close receive-thread via worker-0
+        self.comm.send(CLOSE_MESSAGE, dest=0, tag=SPAWN_TAG)
+
+    def run(self):
+        while True:
+            msg = self.job_queue.get()
+            if msg == CLOSE_MESSAGE:
+                log().debug('%s closing' % self.name)
+                self.teardown()
+                break
+            assert isinstance(msg, Job), msg
+            rank = self.free_queue.get()
+            log().debug('%s sending job %s to %d' % (self.name, msg, rank))
+            self.comm.send(msg, dest=rank, tag=SPAWN_TAG)
+
+
+class ReceiveThread (ManagerThread):
+    def run(self):
+        s = MPI.Status()
+        while True:
+            msg = self.comm.recv(
+                source=MPI.ANY_SOURCE, tag=RECEIVE_TAG, status=s)
+            if msg == CLOSE_MESSAGE:
+                log().debug('%s closing' % self.name)
+                self.comm.Disconnect()
+                break
+            rank = s.Get_source()
+            self.free_queue.put(rank)
+            log().debug('%s got job %s from %d' % (self.name, msg, rank))
+            assert isinstance(msg, Job), msg
+            self.job_queue.put(msg)
+
+
+class MPIManager (ThreadManager):
+    __doc__ = """Manage asynchronous `Job` execution via :mod:`mpi4py`.
+
+    >>> from math import sqrt
+    >>> m = MPIManager()%(skip)s
+    >>> group_A = []
+    >>> for i in range(10):
+    ...     group_A.append(m.async_invoke(Job(target=sqrt, args=[i])))%(skip)s
+    >>> group_B = []
+    >>> for i in range(10):
+    ...     group_B.append(m.async_invoke(Job(target=sqrt, args=[i],
+    ...                 blocks_on=[j.id for j in group_A])))%(skip)s
+    >>> jobs = m.wait(ids=[j.id for j in group_A[5:8]])%(skip)s
+    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
+    [<Job 5>, <Job 6>, <Job 7>]
+    >>> jobs = m.wait()%(skip)s
+    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
+    ... # doctest: +NORMALIZE_WHITESPACE
+    [<Job 0>, <Job 1>, <Job 2>, <Job 3>, <Job 4>, <Job 8>, <Job 9>, <Job 10>,
+     <Job 11>, <Job 12>, <Job 13>, <Job 14>, <Job 15>, <Job 16>, <Job 17>,
+     <Job 18>, <Job 19>]
+    >>> m.teardown()%(skip)s
+    """ % {'skip': _SKIP}
+
+    def __init__(self, worker_pool=None):
+        _manager_check()
+        super(MPIManager, self).__init__(worker_pool)
+
+    def _spawn_workers(self, worker_pool):
+        spawn_script = ';'.join([
+                'from %s import WorkerProcess' % __name__,
+                'w = WorkerProcess()',
+                'w.run()',
+                'w.teardown()',
+                ])
+        if worker_pool is None:
+            worker_pool = MPI.COMM_WORLD.Get_size()
+        comm = MPI.COMM_SELF.Spawn(
+            sys.executable, args=['-c', spawn_script], maxprocs=worker_pool)
+        rank = comm.Get_rank()
+        assert rank == 0, rank
+        # `comm` connects `COMM_WORLD` with the spawned group
+        # (intercommunicator), so `comm.Get_size() == 1` regardless of
+        # `worker_pool`.  We want to know the size of the worker pool,
+        # so we just use:
+        size = worker_pool
+        free_queue = Queue()
+        for worker_rank in range(size):
+            free_queue.put(worker_rank)
+
+        self._workers = []
+        for worker in [SpawnThread(self._spawn_queue, free_queue,
+                                   comm, rank, size,
+                                   name='spawn-thread'),
+                       ReceiveThread(self._receive_queue, free_queue,
+                                     comm, rank, size,
+                                     name='receive-thread'),
+                       ]:
+            log().debug('start %s' % worker.name)
+            worker.start()
+            self._workers.append(worker)