1 # Copyright (C) 2010 W. Trevor King <wking@drexel.edu>
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation, either version 3 of the License, or
6 # (at your option) any later version.
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11 # GNU General Public License for more details.
13 # You should have received a copy of the GNU General Public License
14 # along with this program. If not, see <http://www.gnu.org/licenses/>.
16 # The author may be contacted at <wking@drexel.edu> on the Internet, or
17 # write to Trevor King, Drexel University, Physics Dept., 3141 Chestnut St.,
18 # Philadelphia PA 19104, USA.
20 """Functions for running external commands on other hosts.
22 mpi4py_ is a Python wrapper around MPI.
24 .. _mpi4py: http://mpi4py.scipy.org/
26 The MPIManager data flow is a bit complicated, so I've made a
29 [original intracommunicator] [intercom] [spawned intracommunicator]
31 ,->--(spawn-thread)-->--+-->------------->-(worker-0)-.
32 (main-thread) | `--->-(worker-1)-|
33 `-<-(receive-thread)-<--+--<-----. `->-(worker-2)-|
34 | `-------<-----------'
38 ============== ============== ===============================
39 Source Target Connection
40 ============== ============== ===============================
41 main-thread spawn-thread spawn_queue
42 spawn-thread worker-* SpawnThread.comm(SPAWN_TAG)
43 worker-* receive-thread ReceiveThread.comm(RECEIVE_TAG)
44 receive-thread main-thread receive_queue
45 ============== ============== ===============================
47 There is also a `free_queue` running from `receive-thread` to
48 `spawn-thread` to mark job completion so `spawn-thread` knows which
49 nodes are free (and therefore ready to receive new jobs).
53 from Queue import Queue, Empty
55 from threading import Thread
58 from mpi4py import MPI
# Record whether mpi4py is usable.  `_DISABLING_ERROR` stays None when MPI
# imported cleanly; otherwise the ImportError is stored so later code can
# re-raise it.  `_SKIP` switches the doctest examples below to SKIP when
# MPI is unavailable (or, presumably, on non-root ranks).
# NOTE(review): the enclosing try:/else: framing is on lines missing from
# this view -- the bare `except` clause below proves a `try:` exists above.
_DISABLING_ERROR = None
if MPI.COMM_WORLD.Get_rank() == 0:
    # NOTE(review): lines between the `if` and this assignment are missing;
    # rank 0 most likely sets `_SKIP = ''` and non-root ranks get SKIP --
    # confirm against VCS.
    _SKIP = ' # doctest: +SKIP'
except ImportError, _DISABLING_ERROR:  # Python 2 `except E, name` syntax
    _SKIP = ' # doctest: +SKIP'
71 from .thread import CLOSE_MESSAGE, ThreadManager
# Sentinel passed through the queues and MPI messages to tell workers and
# manager threads to shut down.
# NOTE(review): this rebinds the CLOSE_MESSAGE imported from .thread above;
# possibly a fallback on a code path not visible in this view -- confirm.
CLOSE_MESSAGE = "close"
def MPI_worker_death():
    # Called by spawned worker scripts so that non-root ranks stop here
    # instead of falling through into manager-only code.
    # NOTE(review): several lines are missing from this view; control flow
    # below is partially reconstructed -- confirm against VCS.
    if MPI.COMM_WORLD.Get_rank() != 0:
        # body not visible in this view (presumably exits the process)

    # NOTE(review): the statements below appear to belong to a *different*
    # function whose `def` line is missing from this view -- they verify
    # that only the manager (rank 0) process reaches this point.
    # Presumably re-raises the ImportError recorded at import time when
    # mpi4py is unusable -- the guarding condition is not visible here.
    raise _DISABLING_ERROR
    rank = MPI.COMM_WORLD.Get_rank()
    # NOTE(review): this string is most likely the message of an assert
    # (e.g. `assert rank == 0, (...)`) whose surrounding lines are missing.
    'process %d should have been killed by an MPI_worker_death() call'
class WorkerProcess (object):
    # One spawned MPI process: receives `Job` instances from the manager's
    # spawn-thread (SPAWN_TAG) and sends completed jobs back to the
    # manager's receive-thread (RECEIVE_TAG).
    # NOTE(review): method headers (__init__/run/teardown/...) are missing
    # from this view; the statement grouping below is reconstructed --
    # confirm against VCS.

        # --- initialization (enclosing method header not visible) ---
        self.comm = MPI.Comm.Get_parent() # intercommunicator
        self.rank = self.comm.Get_rank() # *intracom* rank?
        self.name = 'worker-%d' % self.rank
        log().debug('%s started' % self.name)

        # --- teardown (enclosing method header not visible) ---
        # only one worker needs to disconnect from the intercommunicator.
        self.comm.Disconnect()

        # --- main receive loop (enclosing loop/method not visible) ---
            # Block until the manager's spawn-thread sends a message; `s`
            # (an MPI.Status, constructed on a line not visible here)
            # records the actual source/tag.
            msg = self.comm.recv(source=self.manager, tag=SPAWN_TAG, status=s)
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
            # anything that is not the close sentinel must be a Job
            assert isinstance(msg, Job), msg
            log().debug('%s running job %s' % (self.name, msg))
            # return the job (presumably executed on a line not visible
            # here) to the manager's receive-thread
            self.comm.send(msg, dest=self.manager, tag=RECEIVE_TAG)
        # forward close message to receive-thread
        self.comm.send(CLOSE_MESSAGE, dest=self.manager, tag=RECEIVE_TAG)
class ManagerThread (Thread):
    # Common base for SpawnThread and ReceiveThread: stores the queues and
    # MPI intercommunicator shared by both manager-side helper threads.
    def __init__(self, job_queue, free_queue, comm, rank, size,
        # NOTE(review): the rest of this signature (presumably
        # `*args, **kwargs):`) is on a line missing from this view.
        super(ManagerThread, self).__init__(*args, **kwargs)
        self.job_queue = job_queue    # Job traffic to/from the main thread
        self.free_queue = free_queue  # ranks of idle workers
        # NOTE(review): assignments for comm/rank/size are on lines missing
        # from this view.
        self.name = self.getName() # work around Pythons < 2.6
        log().debug('%s starting' % self.name)
class SpawnThread (ManagerThread):
    # Feeds `Job` messages from the main thread's job_queue to free workers
    # over the MPI intercommunicator (SPAWN_TAG); also drives shutdown.
    # NOTE(review): method headers are missing from this view; statement
    # grouping below is reconstructed -- confirm against VCS.

        # --- teardown (enclosing method header not visible) ---
        # tell every worker to close
        for i in range(self.size):
            self.comm.send(CLOSE_MESSAGE, dest=i, tag=SPAWN_TAG)
        # wait until every worker has checked back in as free
        # NOTE(review): `free` is initialized on a line not visible here.
        while len(free) < self.size:
            free.append(self.free_queue.get())
        # close receive-thread via worker-0
        self.comm.send(CLOSE_MESSAGE, dest=0, tag=SPAWN_TAG)

        # --- main dispatch loop (enclosing loop/method not visible) ---
            msg = self.job_queue.get()
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
            # anything that is not the close sentinel must be a Job
            assert isinstance(msg, Job), msg
            rank = self.free_queue.get()  # blocks until a worker is idle
            log().debug('%s sending job %s to %d' % (self.name, msg, rank))
            self.comm.send(msg, dest=rank, tag=SPAWN_TAG)
class ReceiveThread (ManagerThread):
    # Collects completed jobs from workers (RECEIVE_TAG), marks the sending
    # worker as free again, and forwards each job to the main thread.
    # NOTE(review): the method header and enclosing loop are on lines
    # missing from this view; `s` is an MPI.Status constructed on a line
    # not visible here -- confirm against VCS.

            # block until any worker sends a completed job (or the close
            # sentinel forwarded by worker-0)
            msg = self.comm.recv(
                source=MPI.ANY_SOURCE, tag=RECEIVE_TAG, status=s)
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                # drop the intercommunicator on shutdown
                self.comm.Disconnect()
            rank = s.Get_source()       # which worker sent this job
            self.free_queue.put(rank)   # that worker is now idle
            log().debug('%s got job %s from %d' % (self.name, msg, rank))
            assert isinstance(msg, Job), msg
            self.job_queue.put(msg)     # hand the finished job to main-thread
class MPIManager (ThreadManager):
    # NOTE(review): ":mod:`pbs`" in the docstring below looks like a
    # copy-paste from the pbs manager module -- this class manages
    # execution via MPI/mpi4py.  Left untouched here because `__doc__` is
    # built from a runtime string literal (doctests included).
    __doc__ = """Manage asynchronous `Job` execution via :mod:`pbs`.

    >>> from math import sqrt
    >>> m = MPIManager()%(skip)s
    >>> for i in range(10):
    ...     group_A.append(m.async_invoke(Job(target=sqrt, args=[i])))%(skip)s
    >>> for i in range(10):
    ...     group_B.append(m.async_invoke(Job(target=sqrt, args=[i],
    ...                 blocks_on=[j.id for j in group_A])))%(skip)s
    >>> jobs = m.wait(ids=[j.id for j in group_A[5:8]])%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    [<Job 5>, <Job 6>, <Job 7>]
    >>> jobs = m.wait()%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    ... # doctest: +NORMALIZE_WHITESPACE
    [<Job 0>, <Job 1>, <Job 2>, <Job 3>, <Job 4>, <Job 8>, <Job 9>, <Job 10>,
     <Job 11>, <Job 12>, <Job 13>, <Job 14>, <Job 15>, <Job 16>, <Job 17>,
    >>> m.teardown()%(skip)s
    """ % {'skip': _SKIP}

    def __init__(self, worker_pool=None):
        # `worker_pool`: number of MPI workers to spawn; None defers the
        # decision to _spawn_workers ($WORKER_POOL or the MPI world size).
        super(MPIManager, self).__init__(worker_pool)

    def _spawn_workers(self, worker_pool):
        # Build a one-liner Python script that runs a WorkerProcess in
        # each spawned MPI process.
        spawn_script = ';'.join([
            'from %s import WorkerProcess' % __name__,
            'w = WorkerProcess()',
        # NOTE(review): the remaining list items and the closing bracket of
        # this join are on lines missing from this view.
        if worker_pool is None:
            # default pool size: $WORKER_POOL, else the MPI world size
            worker_pool = int(os.environ.get('WORKER_POOL',
                MPI.COMM_WORLD.Get_size()))
        comm = MPI.COMM_SELF.Spawn( # locks with mpich2 if no mpd running
            sys.executable, args=['-c', spawn_script], maxprocs=worker_pool)
        rank = comm.Get_rank()
        assert rank == 0, rank
        # `comm` connects `COMM_WORLD` with the spawned group
        # (intercommunicator), so `comm.Get_size() == 1` regardless of
        # `worker_pool`.  We want to know the size of the worker pool,
        # NOTE(review): the definitions of `size` and `free_queue` are on
        # lines missing from this view.
        # every worker starts out idle
        for worker_rank in range(size):
            free_queue.put(worker_rank)
        # start the two manager-side helper threads
        for worker in [SpawnThread(self._spawn_queue, free_queue,
                           name='spawn-thread'),
                       ReceiveThread(self._receive_queue, free_queue,
                           name='receive-thread'),
            # NOTE(review): additional constructor arguments (presumably
            # comm/rank/size) and this list's closing bracket are on lines
            # missing from this view.
            log().debug('start %s' % worker.name)
            self._workers.append(worker)