Correct mpi4py URL.
sawsim.git: pysawsim/manager/mpi.py
# Copyright (C) 2010  W. Trevor King <wking@drexel.edu>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
# The author may be contacted at <wking@drexel.edu> on the Internet, or
# write to Trevor King, Drexel University, Physics Dept., 3141 Chestnut St.,
# Philadelphia PA 19104, USA.

"""Functions for running external commands on other hosts.

mpi4py_ is a Python wrapper around MPI.

.. _mpi4py: http://mpi4py.scipy.org/

The MPIManager data flow is a bit complicated, so I've made a
diagram::

    [original intracommunicator]    [intercom]   [spawned intracommunicator]
                                        |
                ,->--(spawn-thread)-->--+-->------------->-(worker-0)-.
    (main-thread)                       |            `--->-(worker-1)-|
                `-<-(receive-thread)-<--+--<-----.     `->-(worker-2)-|
                                        |         `-------<-----------'

The connections are:

==============  ==============  ===============================
Source          Target          Connection
==============  ==============  ===============================
main-thread     spawn-thread    spawn_queue
spawn-thread    worker-*        SpawnThread.comm(SPAWN_TAG)
worker-*        receive-thread  ReceiveThread.comm(RECEIVE_TAG)
receive-thread  main-thread     receive_queue
==============  ==============  ===============================

There is also a `free_queue` running from `receive-thread` to
`spawn-thread` to mark job completion so `spawn-thread` knows which
nodes are free (and therefore ready to receive new jobs).
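
For reference, the bare mpi4py pattern that this module builds on
(dynamic process spawning plus tagged point-to-point messages over the
resulting intercommunicator) looks roughly like the sketch below.  The
file name, tags, and payload are illustrative only::

    # parent process
    from mpi4py import MPI
    import sys
    comm = MPI.COMM_SELF.Spawn(
        sys.executable, args=['worker.py'], maxprocs=1)
    comm.send('ping', dest=0, tag=1)       # -> worker 0
    reply = comm.recv(source=0, tag=2)     # <- worker 0
    comm.Disconnect()

    # worker.py (the spawned side)
    from mpi4py import MPI
    comm = MPI.Comm.Get_parent()
    msg = comm.recv(source=0, tag=1)         # <- parent
    comm.send(msg + '/pong', dest=0, tag=2)  # -> parent
    comm.Disconnect()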
50 """
51
52 from Queue import Queue, Empty
53 import sys
54 from threading import Thread
55
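# Try to import mpi4py.  If the import fails, remember the error so
# _manager_check() can re-raise it later, and mark the doctests below
# for skipping.  With mpi4py available, only rank 0 runs the doctests.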
try:
    from mpi4py import MPI
    _ENABLED = True
    _DISABLING_ERROR = None
    if MPI.COMM_WORLD.Get_rank() == 0:
        _SKIP = ''
    else:
        _SKIP = '  # doctest: +SKIP'
except ImportError, _DISABLING_ERROR:
    _ENABLED = False
    _SKIP = '  # doctest: +SKIP'

from .. import log
from . import Job
from .thread import CLOSE_MESSAGE, ThreadManager


CLOSE_MESSAGE = "close"
SPAWN_TAG = 100
RECEIVE_TAG = 101


def MPI_worker_death():
    if not _ENABLED:
        return
    if MPI.COMM_WORLD.Get_rank() != 0:
        sys.exit(0)


def _manager_check():
    if not _ENABLED:
        raise _DISABLING_ERROR
    rank = MPI.COMM_WORLD.Get_rank()
    assert rank == 0, (
        'process %d should have been killed by an MPI_worker_death() call'
        % rank)


class WorkerProcess (object):
    def __init__(self):
        self.comm = MPI.Comm.Get_parent()  # intercommunicator
        self.rank = self.comm.Get_rank()   # rank within the spawned (local) group
        self.manager = 0
        self.name = 'worker-%d' % self.rank
        log().debug('%s started' % self.name)

    def teardown(self):
        if self.rank == 0:
            # only one worker needs to disconnect from the intercommunicator.
            self.comm.Disconnect()

    def run(self):
        s = MPI.Status()
        while True:
            msg = self.comm.recv(source=self.manager, tag=SPAWN_TAG, status=s)
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                break
            assert isinstance(msg, Job), msg
            log().debug('%s running job %s' % (self.name, msg))
            msg.run()
            self.comm.send(msg, dest=self.manager, tag=RECEIVE_TAG)
        if self.rank == 0:
            # forward close message to receive-thread
            self.comm.send(CLOSE_MESSAGE, dest=self.manager, tag=RECEIVE_TAG)


class ManagerThread (Thread):
    def __init__(self, job_queue, free_queue, comm, rank, size,
                 *args, **kwargs):
        super(ManagerThread, self).__init__(*args, **kwargs)
        self.job_queue = job_queue
        self.free_queue = free_queue
        self.comm = comm
        self.rank = rank
        self.size = size
        self.name = self.getName()  # work around Python versions < 2.6
        log().debug('%s starting' % self.name)


class SpawnThread (ManagerThread):
    def teardown(self):
        for i in range(self.size):
            if i != 0:
                self.comm.send(CLOSE_MESSAGE, dest=i, tag=SPAWN_TAG)
        free = []
        while len(free) < self.size:
            free.append(self.free_queue.get())
        # close receive-thread via worker-0
        self.comm.send(CLOSE_MESSAGE, dest=0, tag=SPAWN_TAG)

    def run(self):
        while True:
            msg = self.job_queue.get()
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                self.teardown()
                break
            assert isinstance(msg, Job), msg
            rank = self.free_queue.get()
            log().debug('%s sending job %s to %d' % (self.name, msg, rank))
            self.comm.send(msg, dest=rank, tag=SPAWN_TAG)


class ReceiveThread (ManagerThread):
    def run(self):
        s = MPI.Status()
        while True:
            msg = self.comm.recv(
                source=MPI.ANY_SOURCE, tag=RECEIVE_TAG, status=s)
            if msg == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                self.comm.Disconnect()
                break
            rank = s.Get_source()
            self.free_queue.put(rank)
            log().debug('%s got job %s from %d' % (self.name, msg, rank))
            assert isinstance(msg, Job), msg
            self.job_queue.put(msg)


class MPIManager (ThreadManager):
    __doc__ = """Manage asynchronous `Job` execution via :mod:`mpi4py`.

    >>> from math import sqrt
    >>> m = MPIManager()%(skip)s
    >>> group_A = []
    >>> for i in range(10):
    ...     group_A.append(m.async_invoke(Job(target=sqrt, args=[i])))%(skip)s
    >>> group_B = []
    >>> for i in range(10):
    ...     group_B.append(m.async_invoke(Job(target=sqrt, args=[i],
    ...                 blocks_on=[j.id for j in group_A])))%(skip)s
    >>> jobs = m.wait(ids=[j.id for j in group_A[5:8]])%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    [<Job 5>, <Job 6>, <Job 7>]
    >>> jobs = m.wait()%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    ... # doctest: +NORMALIZE_WHITESPACE
    [<Job 0>, <Job 1>, <Job 2>, <Job 3>, <Job 4>, <Job 8>, <Job 9>, <Job 10>,
     <Job 11>, <Job 12>, <Job 13>, <Job 14>, <Job 15>, <Job 16>, <Job 17>,
     <Job 18>, <Job 19>]
    >>> m.teardown()%(skip)s
    """ % {'skip': _SKIP}

    def __init__(self, worker_pool=None):
        _manager_check()
        super(MPIManager, self).__init__(worker_pool)

    def _spawn_workers(self, worker_pool):
        spawn_script = ';'.join([
                'from %s import WorkerProcess' % __name__,
                'w = WorkerProcess()',
                'w.run()',
                'w.teardown()',
                ])
        if worker_pool is None:
            worker_pool = MPI.COMM_WORLD.Get_size()
        comm = MPI.COMM_SELF.Spawn(  # locks with mpich2 if no mpd running
            sys.executable, args=['-c', spawn_script], maxprocs=worker_pool)
        rank = comm.Get_rank()
        assert rank == 0, rank
        # `comm` connects this process (`COMM_SELF`) with the spawned
        # group (intercommunicator), so `comm.Get_size() == 1`
        # regardless of `worker_pool`.  We want to know the size of the
        # worker pool, so we just use:
        size = worker_pool
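        # (comm.Get_remote_size() should give the same count, if needed.)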
        free_queue = Queue()
        for worker_rank in range(size):
            free_queue.put(worker_rank)

        self._workers = []
        for worker in [SpawnThread(self._spawn_queue, free_queue,
                                   comm, rank, size,
                                   name='spawn-thread'),
                       ReceiveThread(self._receive_queue, free_queue,
                                     comm, rank, size,
                                     name='receive-thread'),
                       ]:
            log().debug('start %s' % worker.name)
            worker.start()
            self._workers.append(worker)