Add pysawsim.manager.mpi with an mpi4py-based manager.
[sawsim.git] / pysawsim / manager / mpi.py
# Copyright (C) 2010  W. Trevor King <wking@drexel.edu>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# The author may be contacted at <wking@drexel.edu> on the Internet, or
# write to Trevor King, Drexel University, Physics Dept., 3141 Chestnut St.,
# Philadelphia PA 19104, USA.

"""Functions for running external commands on other hosts.

mpi4py_ is a Python wrapper around MPI.

.. _mpi4py: http://mpi4py.scipy.org/

The MPIManager data flow is a bit complicated, so I've made a
diagram::

    [original intracommunicator]    [intercom]   [spawned intracommunicator]
                                        |
                ,->--(spawn-thread)-->--+-->------------->-(worker-0)-.
    (main-thread)                       |            `--->-(worker-1)-|
                `-<-(receive-thread)-<--+--<-----.     `->-(worker-2)-|
                                        |         `-------<-----------'

The connections are:

==============  ==============  ===============================
Source          Target          Connection
==============  ==============  ===============================
main-thread     spawn-thread    spawn_queue
spawn-thread    worker-*        SpawnThread.comm(SPAWN_TAG)
worker-*        receive-thread  ReceiveThread.comm(RECEIVE_TAG)
receive-thread  main-thread     receive_queue
==============  ==============  ===============================

There is also a `free_queue` running from `receive-thread` to
`spawn-thread` to mark job completion so `spawn-thread` knows which
nodes are free (and therefore ready to receive new jobs).
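
A rough sketch of the two middle connections in terms of mpi4py
point-to-point calls (see the worker and thread classes below for the
real versions; `job` and `worker_rank` here are just stand-in names)::

    # spawn-thread side
    comm.send(job, dest=worker_rank, tag=SPAWN_TAG)
    # worker side
    job = comm.recv(source=0, tag=SPAWN_TAG)
    job.run()
    comm.send(job, dest=0, tag=RECEIVE_TAG)
    # receive-thread side
    job = comm.recv(source=MPI.ANY_SOURCE, tag=RECEIVE_TAG)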
"""

from Queue import Queue, Empty
import sys
from threading import Thread

try:
    from mpi4py import MPI
    if MPI.COMM_WORLD.Get_rank() == 0:
        _SKIP = ''
    else:
        _SKIP = '  # doctest: +SKIP'
except ImportError, MPI_error:
    MPI = None
    _SKIP = '  # doctest: +SKIP'

from .. import log
from . import Job
from .thread import CLOSE_MESSAGE, ThreadManager


SPAWN_TAG = 100
RECEIVE_TAG = 101


def MPI_worker_death():
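    """Exit non-root ranks of `MPI.COMM_WORLD` (a no-op without mpi4py).

    Only rank 0 survives to act as the manager; `_manager_check()`
    below asserts that the other ranks have been killed by a call to
    this function.
    """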
    if MPI == None:
        return
    if MPI.COMM_WORLD.Get_rank() != 0:
        sys.exit(0)

def _manager_check():
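    """Assert that mpi4py is available and that this process is rank 0.

    `MPIManager` should only be constructed from the sole surviving
    (rank 0) process, after `MPI_worker_death()` has culled the rest.
    """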
    assert MPI != None, MPI_error
    rank = MPI.COMM_WORLD.Get_rank()
    assert rank == 0, (
        'process %d should have been killed by an MPI_worker_death() call'
        % rank)


class WorkerProcess (object):
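    """Worker process spawned by `MPIManager._spawn_workers`.

    Receives `Job` instances from the manager over the parent
    intercommunicator (tag `SPAWN_TAG`), runs them, and sends the
    completed jobs back (tag `RECEIVE_TAG`) until it receives
    `CLOSE_MESSAGE`.
    """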
    def __init__(self):
        self.comm = MPI.Comm.Get_parent()  # intercommunicator
        self.rank = self.comm.Get_rank()   # *intracom* rank?
        self.manager = 0
        self.name = 'worker-%d' % self.rank
        log().debug('%s started' % self.name)

    def teardown(self):
        if self.rank == 0:
            # only one worker needs to disconnect from the intercommunicator.
            self.comm.Disconnect()

    def run(self):
        s = MPI.Status()
        while True:
            msg = self.comm.recv(source=self.manager, tag=SPAWN_TAG, status=s)
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                break
            assert isinstance(msg, Job), msg
            log().debug('%s running job %s' % (self.name, msg))
            msg.run()
            self.comm.send(msg, dest=self.manager, tag=RECEIVE_TAG)
        if self.rank == 0:
            # forward close message to receive-thread
            self.comm.send(CLOSE_MESSAGE, dest=self.manager, tag=RECEIVE_TAG)


class ManagerThread (Thread):
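    """Base class for the manager-side spawn and receive threads.

    Stores the job and free queues, the intercommunicator to the
    spawned workers, and the worker-pool size shared by both
    subclasses.
    """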
    def __init__(self, job_queue, free_queue, comm, rank, size,
                 *args, **kwargs):
        super(ManagerThread, self).__init__(*args, **kwargs)
        self.job_queue = job_queue
        self.free_queue = free_queue
        self.comm = comm
        self.rank = rank
        self.size = size
        self.name = self.getName()  # work around Pythons < 2.6
        log().debug('%s starting' % self.name)


class SpawnThread (ManagerThread):
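    """Send jobs from the main thread out to free workers.

    Pops jobs off `job_queue` (the manager's spawn queue), picks a
    free worker rank from `free_queue`, and sends the job over the
    intercommunicator with `SPAWN_TAG`.  On `CLOSE_MESSAGE`, shuts
    down the workers and, via worker-0, the receive-thread.
    """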
    def teardown(self):
        for i in range(self.size):
            if i != 0:
                self.comm.send(CLOSE_MESSAGE, dest=i, tag=SPAWN_TAG)
        free = []
        while len(free) < self.size:
            free.append(self.free_queue.get())
        # close receive-thread via worker-0
        self.comm.send(CLOSE_MESSAGE, dest=0, tag=SPAWN_TAG)

    def run(self):
        while True:
            msg = self.job_queue.get()
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                self.teardown()
                break
            assert isinstance(msg, Job), msg
            rank = self.free_queue.get()
            log().debug('%s sending job %s to %d' % (self.name, msg, rank))
            self.comm.send(msg, dest=rank, tag=SPAWN_TAG)


class ReceiveThread (ManagerThread):
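    """Collect completed jobs from the workers.

    Receives finished jobs over the intercommunicator (`RECEIVE_TAG`),
    returns the sending worker's rank to `free_queue`, and passes the
    jobs on to the main thread via `job_queue` (the manager's receive
    queue).
    """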
    def run(self):
        s = MPI.Status()
        while True:
            msg = self.comm.recv(
                source=MPI.ANY_SOURCE, tag=RECEIVE_TAG, status=s)
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                self.comm.Disconnect()
                break
            rank = s.Get_source()
            self.free_queue.put(rank)
            log().debug('%s got job %s from %d' % (self.name, msg, rank))
            assert isinstance(msg, Job), msg
            self.job_queue.put(msg)


class MPIManager (ThreadManager):
    __doc__ = """Manage asynchronous `Job` execution via :mod:`mpi4py`.

    >>> from math import sqrt
    >>> m = MPIManager()%(skip)s
    >>> group_A = []
    >>> for i in range(10):
    ...     group_A.append(m.async_invoke(Job(target=sqrt, args=[i])))%(skip)s
    >>> group_B = []
    >>> for i in range(10):
    ...     group_B.append(m.async_invoke(Job(target=sqrt, args=[i],
    ...                 blocks_on=[j.id for j in group_A])))%(skip)s
    >>> jobs = m.wait(ids=[j.id for j in group_A[5:8]])%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    [<Job 5>, <Job 6>, <Job 7>]
    >>> jobs = m.wait()%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    ... # doctest: +NORMALIZE_WHITESPACE
    [<Job 0>, <Job 1>, <Job 2>, <Job 3>, <Job 4>, <Job 8>, <Job 9>, <Job 10>,
     <Job 11>, <Job 12>, <Job 13>, <Job 14>, <Job 15>, <Job 16>, <Job 17>,
     <Job 18>, <Job 19>]
    >>> m.teardown()%(skip)s
    """ % {'skip': _SKIP}

    def __init__(self, worker_pool=None):
        _manager_check()
        super(MPIManager, self).__init__(worker_pool)

    def _spawn_workers(self, worker_pool):
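        """Spawn `worker_pool` Python processes running `WorkerProcess`.

        Also starts the spawn-thread and receive-thread that shuttle
        jobs between the main thread and the spawned workers.  If
        `worker_pool` is `None`, it defaults to the size of
        `MPI.COMM_WORLD`.
        """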
        spawn_script = ';'.join([
                'from %s import WorkerProcess' % __name__,
                'w = WorkerProcess()',
                'w.run()',
                'w.teardown()',
                ])
        if worker_pool == None:
            worker_pool = MPI.COMM_WORLD.Get_size()
        comm = MPI.COMM_SELF.Spawn(
            sys.executable, args=['-c', spawn_script], maxprocs=worker_pool)
        rank = comm.Get_rank()
        assert rank == 0, rank
        # `comm` connects this process (`COMM_SELF`) with the spawned
        # group (intercommunicator), so `comm.Get_size() == 1` regardless
        # of `worker_pool`.  We want to know the size of the worker pool,
        # so we just use:
        size = worker_pool
        free_queue = Queue()
        for worker_rank in range(size):
            free_queue.put(worker_rank)

        self._workers = []
        for worker in [SpawnThread(self._spawn_queue, free_queue,
                                   comm, rank, size,
                                   name='spawn-thread'),
                       ReceiveThread(self._receive_queue, free_queue,
                                     comm, rank, size,
                                     name='receive-thread'),
                       ]:
            log().debug('start %s' % worker.name)
            worker.start()
            self._workers.append(worker)