# pysawsim/manager/mpi.py
1 # Copyright (C) 2010  W. Trevor King <wking@drexel.edu>
2 #
3 # This program is free software: you can redistribute it and/or modify
4 # it under the terms of the GNU General Public License as published by
5 # the Free Software Foundation, either version 3 of the License, or
6 # (at your option) any later version.
7 #
8 # This program is distributed in the hope that it will be useful,
9 # but WITHOUT ANY WARRANTY; without even the implied warranty of
10 # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
11 # GNU General Public License for more details.
12 #
13 # You should have received a copy of the GNU General Public License
14 # along with this program.  If not, see <http://www.gnu.org/licenses/>.
15 #
16 # The author may be contacted at <wking@drexel.edu> on the Internet, or
17 # write to Trevor King, Drudge's University, Physics Dept., 3141 Chestnut St.,
18 # Philadelphia PA 19104, USA.
19
20 """Functions for running external commands on other hosts.
21
22 mpi4py_ is a Python wrapper around MPI.
23
24 .. _mpi4py: http://mpi4py.scipy.org/
25
26 The MPIManager data flow is a bit complicated, so I've made a
27 diagram::
28
29     [original intracommunicator]    [intercom]   [spawned intracommunicator]
30                                         |
31                 ,->--(spawn-thread)-->--+-->------------->-(worker-0)-.
32     (main-thread)                       |            `--->-(worker-1)-|
33                 `-<-(receive-thread)-<--+--<-----.     `->-(worker-2)-|
34                                         |         `-------<-----------'
35
36 The connections are:
37
38 ==============  ==============  ===============================
39 Source          Target          Connection
40 ==============  ==============  ===============================
41 main-thread     spawn-thread    spawn_queue
42 spawn-thread    worker-*        SpawnThread.comm(SPAWN_TAG)
43 worker-*        receive-thread  ReceiveThread.comm(RECEIVE_TAG)
44 receive-thread  main-thread     receive_queue
45 ==============  ==============  ===============================
46
47 There is also a `free_queue` running from `receive-thread` to
48 `spawn-thread` to mark job completion so `spawn-thread` knows which
49 nodes are free (and therefore ready to receive new jobs).
50 """
51
52 import os
53 from Queue import Queue, Empty
54 import sys
55 from threading import Thread
56
# Probe for mpi4py.  If the import fails we remember the exception so
# _manager_check() can re-raise it later, and mark the doctests below
# for skipping.  On worker ranks (rank != 0) the doctests are skipped
# too, since only the master process drives the manager.
try:
    from mpi4py import MPI
    _ENABLED = True
    _DISABLING_ERROR = None
    if MPI.COMM_WORLD.Get_rank() == 0:
        _SKIP = ''
    else:
        _SKIP = '  # doctest: +SKIP'
except ImportError, _DISABLING_ERROR:
    _ENABLED = False
    _SKIP = '  # doctest: +SKIP'
68
69 from .. import log
70 from . import Job
71 from .thread import CLOSE_MESSAGE, ThreadManager
72
73
# NOTE(review): this rebinding shadows the CLOSE_MESSAGE imported from
# .thread above -- presumably the two values are identical; confirm
# against pysawsim/manager/thread.py before removing either one.
CLOSE_MESSAGE = "close"
# MPI message tags separating manager->worker job traffic (SPAWN_TAG)
# from worker->manager result traffic (RECEIVE_TAG).
SPAWN_TAG = 100
RECEIVE_TAG = 101
77
78
def MPI_worker_death():
    """Exit immediately on any non-master MPI rank.

    Call this early in program startup so that, under `mpirun`, only
    rank 0 continues into the main program; the other ranks exit
    cleanly.  A no-op when mpi4py is unavailable.
    """
    if not _ENABLED:
        return
    if MPI.COMM_WORLD.Get_rank() != 0:
        sys.exit(0)
84
def _manager_check():
    """Verify that MPI management is possible from this process.

    Raises the original `ImportError` when mpi4py failed to import,
    and asserts that we are rank 0 (all other ranks should already
    have exited via an `MPI_worker_death()` call).
    """
    if not _ENABLED:
        raise _DISABLING_ERROR
    rank = MPI.COMM_WORLD.Get_rank()
    assert rank == 0, (
        'process %d should have been killed by an MPI_worker_death() call'
        % rank)
92
93
class WorkerProcess (object):
    """Spawned MPI process that executes `Job`s sent by the manager.

    Receives jobs tagged `SPAWN_TAG` from the manager over the
    intercommunicator, runs them, and sends the completed jobs back
    tagged `RECEIVE_TAG`.
    """
    def __init__(self):
        # Get_parent() yields the intercommunicator connecting this
        # spawned group back to the manager process.
        self.comm = MPI.Comm.Get_parent()  # intercommunicator
        self.rank = self.comm.Get_rank()   # *intracom* rank?
        self.manager = 0
        self.name = 'worker-%d' % self.rank
        log().debug('%s started' % self.name)

    def teardown(self):
        # Only one worker should disconnect the shared
        # intercommunicator, so every rank but 0 returns immediately.
        if self.rank != 0:
            return
        self.comm.Disconnect()

    def run(self):
        status = MPI.Status()
        while True:
            message = self.comm.recv(
                source=self.manager, tag=SPAWN_TAG, status=status)
            if message == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                break
            assert isinstance(message, Job), message
            log().debug('%s running job %s' % (self.name, message))
            message.run()
            self.comm.send(message, dest=self.manager, tag=RECEIVE_TAG)
        if self.rank == 0:
            # Forward the close message so the manager's receive-thread
            # also shuts down.
            self.comm.send(CLOSE_MESSAGE, dest=self.manager, tag=RECEIVE_TAG)
121
122
class ManagerThread (Thread):
    """Base class for the manager-side helper threads.

    Stores the shared MPI intercommunicator plus the queues used to
    ferry jobs and free-worker markers between threads.
    """
    def __init__(self, job_queue, free_queue, comm, rank, size,
                 *args, **kwargs):
        super(ManagerThread, self).__init__(*args, **kwargs)
        self.comm = comm
        self.rank = rank
        self.size = size
        self.job_queue = job_queue
        self.free_queue = free_queue
        self.name = self.getName()  # work around Pythons < 2.6
        log().debug('%s starting' % self.name)
134
135
class SpawnThread (ManagerThread):
    """Feed queued `Job`s to whichever worker is currently free."""

    def teardown(self):
        # Tell every worker except rank 0 to exit.
        for worker in range(1, self.size):
            self.comm.send(CLOSE_MESSAGE, dest=worker, tag=SPAWN_TAG)
        # Drain `size` free markers so we know every outstanding job
        # has completed before the final shutdown message goes out.
        free = []
        while len(free) < self.size:
            free.append(self.free_queue.get())
        # Close receive-thread via worker-0.
        self.comm.send(CLOSE_MESSAGE, dest=0, tag=SPAWN_TAG)

    def run(self):
        while True:
            task = self.job_queue.get()
            if task == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                self.teardown()
                break
            assert isinstance(task, Job), task
            # Block until some worker reports free, then hand it the job.
            target = self.free_queue.get()
            log().debug('%s sending job %s to %d' % (self.name, task, target))
            self.comm.send(task, dest=target, tag=SPAWN_TAG)
158
159
class ReceiveThread (ManagerThread):
    """Collect completed `Job`s from workers and pass them inward."""

    def run(self):
        status = MPI.Status()
        while True:
            result = self.comm.recv(
                source=MPI.ANY_SOURCE, tag=RECEIVE_TAG, status=status)
            if result == CLOSE_MESSAGE:
                log().debug('%s closing' % self.name)
                self.comm.Disconnect()
                break
            worker = status.Get_source()
            # That worker is idle again; let spawn-thread reuse it.
            self.free_queue.put(worker)
            log().debug('%s got job %s from %d' % (self.name, result, worker))
            assert isinstance(result, Job), result
            self.job_queue.put(result)
175
176
class MPIManager (ThreadManager):
    # The docstring is assembled at import time so that the doctests
    # are marked `+SKIP` (via `%(skip)s`) when mpi4py is unavailable or
    # when we are not running on rank 0.
    # NOTE(review): ":mod:`pbs`" below looks like a copy/paste from the
    # pbs manager module -- should probably read :mod:`mpi4py`; confirm.
    __doc__ = """Manage asynchronous `Job` execution via :mod:`pbs`.

    >>> from math import sqrt
    >>> m = MPIManager()%(skip)s
    >>> group_A = []
    >>> for i in range(10):
    ...     group_A.append(m.async_invoke(Job(target=sqrt, args=[i])))%(skip)s
    >>> group_B = []
    >>> for i in range(10):
    ...     group_B.append(m.async_invoke(Job(target=sqrt, args=[i],
    ...                 blocks_on=[j.id for j in group_A])))%(skip)s
    >>> jobs = m.wait(ids=[j.id for j in group_A[5:8]])%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    [<Job 5>, <Job 6>, <Job 7>]
    >>> jobs = m.wait()%(skip)s
    >>> print sorted(jobs.values(), key=lambda j: j.id)%(skip)s
    ... # doctest: +NORMALIZE_WHITESPACE
    [<Job 0>, <Job 1>, <Job 2>, <Job 3>, <Job 4>, <Job 8>, <Job 9>, <Job 10>,
     <Job 11>, <Job 12>, <Job 13>, <Job 14>, <Job 15>, <Job 16>, <Job 17>,
     <Job 18>, <Job 19>]
    >>> m.teardown()%(skip)s
    """ % {'skip': _SKIP}

    def __init__(self, worker_pool=None):
        # Fail fast unless mpi4py imported cleanly and we are rank 0;
        # the worker ranks should already have exited via
        # MPI_worker_death().
        _manager_check()
        super(MPIManager, self).__init__(worker_pool)

    def _spawn_workers(self, worker_pool):
        # One-liner each spawned interpreter executes: build a
        # WorkerProcess and service jobs until it receives a close
        # message, then tear down.
        spawn_script = ';'.join([
                'from %s import WorkerProcess' % __name__,
                'w = WorkerProcess()',
                'w.run()',
                'w.teardown()',
                ])
        if worker_pool is None:
            # Default pool size: the WORKER_POOL environment variable,
            # falling back to the size of COMM_WORLD.
            worker_pool = int(os.environ.get('WORKER_POOL',
                                             MPI.COMM_WORLD.Get_size()))
        comm = MPI.COMM_SELF.Spawn(  # locks with mpich2 if no mpd running
            sys.executable, args=['-c', spawn_script], maxprocs=worker_pool)
        rank = comm.Get_rank()
        assert rank == 0, rank
        # `comm` connects `COMM_WORLD` with the spawned group
        # (intercommunicator), so `comm.Get_size() == 1` regardless of
        # `worker_pool`.  We want to know the size of the worker pool,
        # so we just use:
        size = worker_pool
        # Every worker starts out free (ready to accept a job).
        free_queue = Queue()
        for worker_rank in range(size):
            free_queue.put(worker_rank)

        # Start the two manager-side threads: spawn-thread pushes jobs
        # out to free workers, receive-thread collects finished jobs.
        self._workers = []
        for worker in [SpawnThread(self._spawn_queue, free_queue,
                                   comm, rank, size,
                                   name='spawn-thread'),
                       ReceiveThread(self._receive_queue, free_queue,
                                     comm, rank, size,
                                     name='receive-thread'),
                       ]:
            log().debug('start %s' % worker.name)
            worker.start()
            self._workers.append(worker)